""" from https://github.com/keithito/tacotron """
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import re
import logging
logger = logging.getLogger(__name__)
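# ARPAbet phoneme symbols (as used in CMUdict), including the 0/1/2 stress
# variants of the vowels: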
valid_symbols = [
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH",
]
"""
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same
# as uppercase letters):
_arpabet = ["@" + s for s in valid_symbols]
# Export all symbols:
symbols = (
[_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
)
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
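# For illustration: with the ordering above, the padding symbol comes first
# (_symbol_to_id["_"] == 0) and every ARPAbet entry is keyed with its "@"
# prefix (e.g. _symbol_to_id["@AA1"]), so phonemes never collide with letters.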
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
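# For example, matching "Turn left on {HH AW1 S} Street." yields the groups
# ("Turn left on ", "HH AW1 S", " Street."): leading text, the braced ARPAbet
# span, and the remainder, which text_to_sequence processes on the next pass.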
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
def expand_abbreviations(text):
"""expand abbreviations pre-defined
"""
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
# def expand_numbers(text):
# return normalize_numbers(text)
def lowercase(text):
"""lowercase the text
"""
return text.lower()
def collapse_whitespace(text):
"""Replaces whitespace by " " in the text
"""
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
"""Converts text to ascii
"""
text_encoded = text.encode("ascii", "ignore")
return text_encoded.decode()
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration.
"""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def german_cleaners(text):
"""Pipeline for German text, that collapses whitespace without transliteration.
"""
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII.
"""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion.
"""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text
def text_to_sequence(text, cleaner_names):
"""Returns a list of integers corresponding to the symbols in the text.
Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Arguments
---------
text : str
string to convert to a sequence
cleaner_names : list
names of the cleaner functions to run the text through
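
    Example
    -------
    A minimal round trip (sequence_to_text is defined later in this module):

    >>> seq = text_to_sequence("{HH AY1} there", ["english_cleaners"])
    >>> sequence_to_text(seq)
    '{HH AY1} there'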
"""
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
return sequence
def sequence_to_text(sequence):
"""Converts a sequence of IDs back to a string
"""
result = ""
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == "@":
s = "{%s}" % s[1:]
result += s
return result.replace("}{", " ")
def _clean_text(text, cleaner_names):
    """Apply the cleaning pipelines named in cleaner_names to the text."""
    for name in cleaner_names:
        cleaner = None
        if name == "english_cleaners":
            cleaner = english_cleaners
        elif name == "transliteration_cleaners":
            cleaner = transliteration_cleaners
        elif name == "basic_cleaners":
            cleaner = basic_cleaners
        elif name == "german_cleaners":
            cleaner = german_cleaners
        if cleaner is None:
            raise ValueError("Unknown cleaner: %s" % name)
        text = cleaner(text)
    return text
def _symbols_to_sequence(symbols):
"""convert symbols to sequence
"""
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
"""Prepend "@" to ensure uniqueness
"""
return _symbols_to_sequence(["@" + s for s in text.split()])
def _should_keep_symbol(s):
"""whether to keep a certain symbol
"""
return s in _symbol_to_id and s != "_" and s != "~"
def _g2p_keep_punctuations(g2p_model, text):
"""do grapheme to phoneme and keep the punctuations between the words
Arguments
---------
g2p_model: speechbrain.inference.text g2p model
text: string
the input text
Example
-------
>>> from speechbrain.inference.text import GraphemeToPhoneme
>>> g2p_model = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p") # doctest: +SKIP
>>> from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations # doctest: +SKIP
>>> text = "Hi, how are you?" # doctest: +SKIP
>>> _g2p_keep_punctuations(g2p_model, text) # doctest: +SKIP
['HH', 'AY', ',', ' ', 'HH', 'AW', ' ', 'AA', 'R', ' ', 'Y', 'UW', '?']
"""
# find the words where a "-" or "'" or "." or ":" appears in the middle
special_words = re.findall(r"\w+[-':\.][-':\.\w]*\w+", text)
# remove intra-word punctuations ("-':."), this does not change the output of speechbrain g2p
for special_word in special_words:
rmp = special_word.replace("-", "")
rmp = rmp.replace("'", "")
rmp = rmp.replace(":", "")
rmp = rmp.replace(".", "")
text = text.replace(special_word, rmp)
# keep inter-word punctuations
all_ = re.findall(r"[\w]+|[-!'(),.:;? ]", text)
    try:
        phonemes = g2p_model(text)
    except RuntimeError:
        logger.error(f"G2P conversion failed for text: {text}")
        raise
word_phonemes = "-".join(phonemes).split(" ")
phonemes_with_punc = []
count = 0
try:
# if the g2p model splits the words correctly
for i in all_:
if i not in "-!'(),.:;? ":
phonemes_with_punc.extend(word_phonemes[count].split("-"))
count += 1
else:
phonemes_with_punc.append(i)
except IndexError:
# sometimes the g2p model cannot split the words correctly
        logger.warning(
            f"Falling back to word-by-word G2P because of unexpected outputs for text: {text}"
        )
for i in all_:
if i not in "-!'(),.:;? ":
                p = g2p_model.g2p(i)
                p_without_space = [phn for phn in p if phn != " "]
                phonemes_with_punc.extend(p_without_space)
else:
phonemes_with_punc.append(i)
while "" in phonemes_with_punc:
phonemes_with_punc.remove("")
return phonemes_with_punc