""" Specifies the inference interfaces for Text-To-Speech (TTS) modules.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import random
import re
import torch
import torchaudio
import speechbrain
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.inference.encoders import MelSpectrogramEncoder
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme
from speechbrain.utils.fetching import fetch
from speechbrain.utils.logger import get_logger
from speechbrain.utils.text_to_sequence import text_to_sequence
logger = get_logger(__name__)
class Tacotron2(Pretrained):
"""
A ready-to-use wrapper for Tacotron2 (text -> mel_spec).
Arguments
---------
*args : tuple
**kwargs : dict
Arguments are forwarded to ``Pretrained`` parent class.
Example
-------
>>> tmpdir_tts = getfixture('tmpdir') / "tts"
>>> tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts)
>>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
>>> items = [
... "A quick brown fox jumped over the lazy dog",
... "How much wood would a woodchuck chuck?",
... "Never odd or even"
... ]
>>> mel_outputs, mel_lengths, alignments = tacotron2.encode_batch(items)
>>> # One can combine the TTS model with a vocoder (that generates the final waveform)
>>> # Initialize the Vocoder (HiFIGAN)
>>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
>>> from speechbrain.inference.vocoders import HIFIGAN
>>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
>>> # Running the TTS
>>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
>>> # Running Vocoder (spectrogram-to-waveform)
>>> waveforms = hifi_gan.decode_batch(mel_output)
"""
HPARAMS_NEEDED = ["model", "text_to_sequence"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.text_cleaners = getattr(
self.hparams, "text_cleaners", ["english_cleaners"]
)
self.infer = self.hparams.model.infer
def text_to_seq(self, txt):
"""Encodes raw text into a tensor with a customer text-to-sequence function"""
sequence = self.hparams.text_to_sequence(txt, self.text_cleaners)
return sequence, len(sequence)
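    # A minimal sketch of the return value (token ids are illustrative; they
    # depend on the configured text_to_sequence function and text cleaners):
    #   seq, seq_len = tacotron2.text_to_seq("Hi")
    #   # seq -> a list of symbol ids, e.g. [45, 46]
    #   # seq_len -> len(seq)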
def encode_batch(self, texts):
"""Computes mel-spectrogram for a list of texts
Texts must be sorted in decreasing order on their lengths
Arguments
---------
texts: List[str]
texts to be encoded into spectrogram
Returns
-------
tensors of output spectrograms, output lengths and alignments
"""
with torch.no_grad():
inputs = [
{
"text_sequences": torch.tensor(
self.text_to_seq(item)[0], device=self.device
)
}
for item in texts
]
inputs = speechbrain.dataio.batch.PaddedBatch(inputs)
lens = [self.text_to_seq(item)[1] for item in texts]
assert lens == sorted(
lens, reverse=True
), "input lengths must be sorted in decreasing order"
input_lengths = torch.tensor(lens, device=self.device)
mel_outputs_postnet, mel_lengths, alignments = self.infer(
inputs.text_sequences.data, input_lengths
)
return mel_outputs_postnet, mel_lengths, alignments
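    # A minimal usage sketch: encode_batch asserts that the inputs arrive
    # sorted by decreasing encoded length, so pre-sort them first
    # (``tacotron2`` as in the class example above):
    #   texts = ["Never odd or even", "How much wood would a woodchuck chuck?"]
    #   texts = sorted(texts, key=lambda t: tacotron2.text_to_seq(t)[1], reverse=True)
    #   mel_outputs, mel_lengths, alignments = tacotron2.encode_batch(texts)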
def encode_text(self, text):
"""Runs inference for a single text str"""
return self.encode_batch([text])
def forward(self, texts):
"Encodes the input texts."
return self.encode_batch(texts)
class MSTacotron2(Pretrained):
"""
A ready-to-use wrapper for Zero-Shot Multi-Speaker Tacotron2.
For voice cloning: (text, reference_audio) -> (mel_spec).
For generating a random speaker voice: (text) -> (mel_spec).
Arguments
---------
*args : tuple
**kwargs : dict
Arguments are forwarded to ``Pretrained`` parent class.
Example
-------
>>> tmpdir_tts = getfixture('tmpdir') / "tts"
>>> mstacotron2 = MSTacotron2.from_hparams(source="speechbrain/tts-mstacotron2-libritts", savedir=tmpdir_tts) # doctest: +SKIP
>>> # Sample rate of the reference audio must be greater or equal to the sample rate of the speaker embedding model
>>> reference_audio_path = "tests/samples/single-mic/example1.wav"
>>> input_text = "Mary had a little lamb."
>>> mel_output, mel_length, alignment = mstacotron2.clone_voice(input_text, reference_audio_path) # doctest: +SKIP
>>> # One can combine the TTS model with a vocoder (that generates the final waveform)
>>> # Initialize the Vocoder (HiFIGAN)
>>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
>>> from speechbrain.inference.vocoders import HIFIGAN
>>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-libritts-22050Hz", savedir=tmpdir_vocoder) # doctest: +SKIP
>>> # Running the TTS
>>> mel_output, mel_length, alignment = mstacotron2.clone_voice(input_text, reference_audio_path) # doctest: +SKIP
>>> # Running Vocoder (spectrogram-to-waveform)
>>> waveforms = hifi_gan.decode_batch(mel_output) # doctest: +SKIP
>>> # For generating a random speaker voice, use the following
>>> mel_output, mel_length, alignment = mstacotron2.generate_random_voice(input_text) # doctest: +SKIP
"""
HPARAMS_NEEDED = ["model"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.text_cleaners = ["english_cleaners"]
self.infer = self.hparams.model.infer
self.custom_mel_spec_encoder = self.hparams.custom_mel_spec_encoder
self.g2p = GraphemeToPhoneme.from_hparams(
self.hparams.g2p, run_opts={"device": self.device}
)
self.spk_emb_encoder = None
if self.custom_mel_spec_encoder:
self.spk_emb_encoder = MelSpectrogramEncoder.from_hparams(
source=self.hparams.spk_emb_encoder,
run_opts={"device": self.device},
)
else:
self.spk_emb_encoder = EncoderClassifier.from_hparams(
source=self.hparams.spk_emb_encoder,
run_opts={"device": self.device},
)
def __text_to_seq(self, txt):
"""Encodes raw text into a tensor with a customer text-to-sequence function"""
sequence = text_to_sequence(txt, self.text_cleaners)
return sequence, len(sequence)
def clone_voice(self, texts, audio_path):
"""
Generates mel-spectrogram using input text and reference audio
Arguments
---------
        texts : str or list
            Input text or a list of texts
        audio_path : str
            Path to the reference audio
Returns
-------
tensors of output spectrograms, output lengths and alignments
"""
# Loads audio
ref_signal, signal_sr = torchaudio.load(audio_path)
# Resamples the audio if required
if signal_sr != self.hparams.spk_emb_sample_rate:
ref_signal = torchaudio.functional.resample(
ref_signal, signal_sr, self.hparams.spk_emb_sample_rate
)
ref_signal = ref_signal.to(self.device)
# Computes speaker embedding
if self.custom_mel_spec_encoder:
spk_emb = self.spk_emb_encoder.encode_waveform(ref_signal)
else:
spk_emb = self.spk_emb_encoder.encode_batch(ref_signal)
spk_emb = spk_emb.squeeze(0)
# Converts input texts into the corresponding phoneme sequences
if isinstance(texts, str):
texts = [texts]
phoneme_seqs = self.g2p(texts)
for i in range(len(phoneme_seqs)):
phoneme_seqs[i] = " ".join(phoneme_seqs[i])
phoneme_seqs[i] = "{" + phoneme_seqs[i] + "}"
# Repeats the speaker embedding to match the number of input texts
spk_embs = spk_emb.repeat(len(texts), 1)
# Calls __encode_batch to generate the mel-spectrograms
return self.__encode_batch(phoneme_seqs, spk_embs)
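    # A minimal end-to-end sketch, assuming the 22050 Hz LibriTTS vocoder from
    # the class example (paths and filenames are illustrative):
    #   from speechbrain.inference.vocoders import HIFIGAN
    #   hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-libritts-22050Hz")
    #   mel_output, mel_length, alignment = mstacotron2.clone_voice(
    #       "Mary had a little lamb.", "tests/samples/single-mic/example1.wav"
    #   )
    #   waveforms = hifi_gan.decode_batch(mel_output)
    #   torchaudio.save("cloned.wav", waveforms.squeeze(1).cpu(), 22050)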
def generate_random_voice(self, texts):
"""
Generates mel-spectrogram using input text and a random speaker voice
Arguments
---------
        texts : str or list
            Input text or a list of texts
Returns
-------
tensors of output spectrograms, output lengths and alignments
"""
spk_emb = self.__sample_random_speaker().float()
spk_emb = spk_emb.to(self.device)
# Converts input texts into the corresponding phoneme sequences
if isinstance(texts, str):
texts = [texts]
phoneme_seqs = self.g2p(texts)
for i in range(len(phoneme_seqs)):
phoneme_seqs[i] = " ".join(phoneme_seqs[i])
phoneme_seqs[i] = "{" + phoneme_seqs[i] + "}"
# Repeats the speaker embedding to match the number of input texts
spk_embs = spk_emb.repeat(len(texts), 1)
# Calls __encode_batch to generate the mel-spectrograms
return self.__encode_batch(phoneme_seqs, spk_embs)
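    # Speaker sampling draws from Python's ``random`` module and a torch
    # distribution, so seeding both makes the sampled voice reproducible
    # (a sketch, not part of the public API):
    #   random.seed(0); torch.manual_seed(0)
    #   mel_output, mel_length, alignment = mstacotron2.generate_random_voice("Hi!")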
def __encode_batch(self, texts, spk_embs):
"""Computes mel-spectrograms for a list of texts
Texts are sorted in decreasing order on their lengths
Arguments
---------
texts: List[str]
texts to be encoded into spectrogram
spk_embs: torch.Tensor
speaker embeddings
Returns
-------
tensors of output spectrograms, output lengths and alignments
"""
with torch.no_grad():
inputs = [
{
"text_sequences": torch.tensor(
self.__text_to_seq(item)[0], device=self.device
)
}
for item in texts
]
inputs = sorted(
inputs,
key=lambda x: x["text_sequences"].size()[0],
reverse=True,
)
lens = [entry["text_sequences"].size()[0] for entry in inputs]
inputs = speechbrain.dataio.batch.PaddedBatch(inputs)
assert lens == sorted(
lens, reverse=True
), "input lengths must be sorted in decreasing order"
input_lengths = torch.tensor(lens, device=self.device)
mel_outputs_postnet, mel_lengths, alignments = self.infer(
inputs.text_sequences.data, spk_embs, input_lengths
)
return mel_outputs_postnet, mel_lengths, alignments
def __sample_random_speaker(self):
"""Samples a random speaker embedding from a pretrained GMM
Returns
-------
x: torch.Tensor
A randomly sampled speaker embedding
"""
        # Fetches and loads a GMM trained on speaker embeddings
speaker_gmm_local_path = fetch(
filename=self.hparams.random_speaker_sampler,
source=self.hparams.random_speaker_sampler_source,
savedir=self.hparams.pretrainer.collect_in,
)
random_speaker_gmm = torch.load(speaker_gmm_local_path)
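        # The checkpoint is expected to be a dict with the keys read below;
        # shapes are illustrative for a full-covariance GMM over D-dimensional
        # speaker embeddings:
        #   {"gmm_n_components": K,
        #    "gmm_means": Tensor of shape [K, D],
        #    "gmm_covariances": Tensor of shape [K, D, D]}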
gmm_n_components = random_speaker_gmm["gmm_n_components"]
gmm_means = random_speaker_gmm["gmm_means"]
gmm_covariances = random_speaker_gmm["gmm_covariances"]
# Randomly selects a speaker
counts = torch.zeros(gmm_n_components)
counts[random.randint(0, gmm_n_components - 1)] = 1
x = torch.empty(0, device=counts.device)
# Samples an embedding for the speaker
for k in torch.arange(gmm_n_components)[counts > 0]:
# Considers full covariance type
d_k = torch.distributions.multivariate_normal.MultivariateNormal(
gmm_means[k], gmm_covariances[k]
)
x_k = torch.stack([d_k.sample() for _ in range(int(counts[k]))])
x = torch.cat((x, x_k), dim=0)
return x
class FastSpeech2(Pretrained):
"""
    A ready-to-use wrapper for FastSpeech2 (text -> mel_spec).
Arguments
---------
*args : tuple
**kwargs : dict
Arguments are forwarded to ``Pretrained`` parent class.
Example
-------
>>> tmpdir_tts = getfixture('tmpdir') / "tts"
>>> fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir=tmpdir_tts) # doctest: +SKIP
>>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
>>> items = [
... "A quick brown fox jumped over the lazy dog",
... "How much wood would a woodchuck chuck?",
... "Never odd or even"
... ]
>>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items) # doctest: +SKIP
>>>
>>> # One can combine the TTS model with a vocoder (that generates the final waveform)
>>> # Initialize the Vocoder (HiFIGAN)
>>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
>>> from speechbrain.inference.vocoders import HIFIGAN
>>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder) # doctest: +SKIP
>>> # Running the TTS
>>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
>>> # Running Vocoder (spectrogram-to-waveform)
>>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP
"""
HPARAMS_NEEDED = ["spn_predictor", "model", "input_encoder"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
lexicon = self.hparams.lexicon
lexicon = ["@@"] + lexicon
self.input_encoder = self.hparams.input_encoder
self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
self.input_encoder.add_unk()
self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
self.spn_token_encoded = (
self.input_encoder.encode_sequence_torch(["spn"]).int().item()
)
def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
"""Computes mel-spectrogram for a list of texts
Arguments
---------
texts: List[str]
texts to be converted to spectrogram
pace: float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
        tensors of output spectrograms, durations, pitches and energies
"""
# Preprocessing required at the inference time for the input text
# "label" below contains input text
# "phoneme_labels" contain the phoneme sequences corresponding to input text labels
# "last_phonemes_combined" is used to indicate whether the index position is for a last phoneme of a word
# "punc_positions" is used to add back the silence for punctuations
phoneme_labels = list()
last_phonemes_combined = list()
punc_positions = list()
for label in texts:
phoneme_label = list()
last_phonemes = list()
punc_position = list()
words = label.split()
words = [word.strip() for word in words]
words_phonemes = self.g2p(words)
for i in range(len(words_phonemes)):
words_phonemes_seq = words_phonemes[i]
for phoneme in words_phonemes_seq:
if not phoneme.isspace():
phoneme_label.append(phoneme)
last_phonemes.append(0)
punc_position.append(0)
last_phonemes[-1] = 1
if words[i][-1] in ":;-,.!?":
punc_position[-1] = 1
phoneme_labels.append(phoneme_label)
last_phonemes_combined.append(last_phonemes)
punc_positions.append(punc_position)
# Inserts silent phonemes in the input phoneme sequence
all_tokens_with_spn = list()
max_seq_len = -1
for i in range(len(phoneme_labels)):
phoneme_label = phoneme_labels[i]
token_seq = (
self.input_encoder.encode_sequence_torch(phoneme_label)
.int()
.to(self.device)
)
last_phonemes = torch.LongTensor(last_phonemes_combined[i]).to(
self.device
)
# Runs the silent phoneme predictor
spn_preds = (
self.hparams.modules["spn_predictor"]
.infer(token_seq.unsqueeze(0), last_phonemes.unsqueeze(0))
.int()
)
spn_to_add = torch.nonzero(spn_preds).reshape(-1).tolist()
for j in range(len(punc_positions[i])):
if punc_positions[i][j] == 1:
spn_to_add.append(j)
tokens_with_spn = list()
for token_idx in range(token_seq.shape[0]):
tokens_with_spn.append(token_seq[token_idx].item())
if token_idx in spn_to_add:
tokens_with_spn.append(self.spn_token_encoded)
tokens_with_spn = torch.LongTensor(tokens_with_spn).to(self.device)
all_tokens_with_spn.append(tokens_with_spn)
if max_seq_len < tokens_with_spn.shape[-1]:
max_seq_len = tokens_with_spn.shape[-1]
# "tokens_with_spn_tensor" holds the input phoneme sequence with silent phonemes
tokens_with_spn_tensor_padded = torch.LongTensor(
len(texts), max_seq_len
).to(self.device)
tokens_with_spn_tensor_padded.zero_()
for seq_idx, seq in enumerate(all_tokens_with_spn):
tokens_with_spn_tensor_padded[seq_idx, : len(seq)] = seq
return self.encode_batch(
tokens_with_spn_tensor_padded,
pace=pace,
pitch_rate=pitch_rate,
energy_rate=energy_rate,
)
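    # A minimal prosody-control sketch: pace, pitch_rate and energy_rate scale
    # the predicted durations, pitches and energies (values are illustrative;
    # ``fastspeech2`` as in the class example above):
    #   mel_outputs, durations, pitch, energy = fastspeech2.encode_text(
    #       ["Mary had a little lamb."], pace=1.2, pitch_rate=0.9, energy_rate=1.0
    #   )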
def encode_phoneme(
self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
):
"""Computes mel-spectrogram for a list of phoneme sequences
Arguments
---------
phonemes: List[List[str]]
phonemes to be converted to spectrogram
pace: float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
        tensors of output spectrograms, durations, pitches and energies
"""
all_tokens = []
max_seq_len = -1
for phoneme in phonemes:
token_seq = (
self.input_encoder.encode_sequence_torch(phoneme)
.int()
.to(self.device)
)
if max_seq_len < token_seq.shape[-1]:
max_seq_len = token_seq.shape[-1]
all_tokens.append(token_seq)
tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to(
self.device
)
tokens_padded.zero_()
for seq_idx, seq in enumerate(all_tokens):
tokens_padded[seq_idx, : len(seq)] = seq
return self.encode_batch(
tokens_padded,
pace=pace,
pitch_rate=pitch_rate,
energy_rate=energy_rate,
)
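    # A minimal phoneme-input sketch, assuming ARPABET-style symbols as produced
    # by the SoundChoice g2p model (the valid inventory comes from
    # ``hparams.lexicon``):
    #   phonemes = [["HH", "AH", "L", "OW"]]
    #   mel_outputs, durations, pitch, energy = fastspeech2.encode_phoneme(phonemes)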
def encode_batch(
self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
):
"""Batch inference for a tensor of phoneme sequences
Arguments
---------
tokens_padded : torch.Tensor
A sequence of encoded phonemes to be converted to spectrogram
pace : float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
post_mel_outputs : torch.Tensor
durations : torch.Tensor
pitch : torch.Tensor
energy : torch.Tensor
"""
with torch.no_grad():
(
_,
post_mel_outputs,
durations,
pitch,
_,
energy,
_,
_,
) = self.hparams.model(
tokens_padded,
pace=pace,
pitch_rate=pitch_rate,
energy_rate=energy_rate,
)
            # Transposes to make it compliant with the expected HiFiGAN input format
post_mel_outputs = post_mel_outputs.transpose(-1, 1)
return post_mel_outputs, durations, pitch, energy
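    # Note: after the transpose above, post_mel_outputs has shape
    # [batch, n_mels, time], which is the layout HIFIGAN.decode_batch expects:
    #   waveforms = hifi_gan.decode_batch(post_mel_outputs)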
def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
"""Batch inference for a tensor of phoneme sequences
Arguments
---------
text : str
A text to be converted to spectrogram
pace : float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
        tensors of output spectrograms, durations, pitches and energies
"""
return self.encode_text(
[text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
)
class FastSpeech2InternalAlignment(Pretrained):
"""
    A ready-to-use wrapper for FastSpeech2 with internal alignment (text -> mel_spec).
Arguments
---------
*args : tuple
**kwargs : dict
Arguments are forwarded to ``Pretrained`` parent class.
Example
-------
>>> tmpdir_tts = getfixture('tmpdir') / "tts"
>>> fastspeech2 = FastSpeech2InternalAlignment.from_hparams(source="speechbrain/tts-fastspeech2-internal-alignment-ljspeech", savedir=tmpdir_tts) # doctest: +SKIP
>>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
>>> items = [
... "A quick brown fox jumped over the lazy dog",
... "How much wood would a woodchuck chuck?",
... "Never odd or even"
... ]
>>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items) # doctest: +SKIP
>>> # One can combine the TTS model with a vocoder (that generates the final waveform)
>>> # Initialize the Vocoder (HiFIGAN)
>>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
>>> from speechbrain.inference.vocoders import HIFIGAN
>>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder) # doctest: +SKIP
>>> # Running the TTS
>>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
>>> # Running Vocoder (spectrogram-to-waveform)
>>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP
"""
HPARAMS_NEEDED = ["model", "input_encoder"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
lexicon = self.hparams.lexicon
lexicon = ["@@"] + lexicon
self.input_encoder = self.hparams.input_encoder
self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
self.input_encoder.add_unk()
self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
"""Computes mel-spectrogram for a list of texts
Arguments
---------
texts: List[str]
texts to be converted to spectrogram
pace: float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
        tensors of output spectrograms, durations, pitches and energies
"""
# Preprocessing required at the inference time for the input text
# "label" below contains input text
# "phoneme_labels" contain the phoneme sequences corresponding to input text labels
phoneme_labels = list()
max_seq_len = -1
for label in texts:
phonemes_with_punc = self._g2p_keep_punctuations(self.g2p, label)
if max_seq_len < len(phonemes_with_punc):
max_seq_len = len(phonemes_with_punc)
token_seq = (
self.input_encoder.encode_sequence_torch(phonemes_with_punc)
.int()
.to(self.device)
)
phoneme_labels.append(token_seq)
tokens_padded = torch.LongTensor(len(texts), max_seq_len).to(
self.device
)
tokens_padded.zero_()
for seq_idx, seq in enumerate(phoneme_labels):
tokens_padded[seq_idx, : len(seq)] = seq
return self.encode_batch(
tokens_padded,
pace=pace,
pitch_rate=pitch_rate,
energy_rate=energy_rate,
)
def _g2p_keep_punctuations(self, g2p_model, text):
"""do grapheme to phoneme and keep the punctuations between the words"""
# find the words where a "-" or "'" or "." or ":" appears in the middle
special_words = re.findall(r"\w+[-':\.][-':\.\w]*\w+", text)
# remove intra-word punctuations ("-':."), this does not change the output of speechbrain g2p
for special_word in special_words:
rmp = special_word.replace("-", "")
rmp = rmp.replace("'", "")
rmp = rmp.replace(":", "")
rmp = rmp.replace(".", "")
text = text.replace(special_word, rmp)
# keep inter-word punctuations
all_ = re.findall(r"[\w]+|[-!'(),.:;? ]", text)
        try:
            phonemes = g2p_model(text)
        except RuntimeError:
            # A library should not kill the interpreter: log and re-raise
            # instead of calling quit()
            logger.error(f"error with text: {text}")
            raise
word_phonemes = "-".join(phonemes).split(" ")
phonemes_with_punc = []
count = 0
try:
# if the g2p model splits the words correctly
for i in all_:
if i not in "-!'(),.:;? ":
phonemes_with_punc.extend(word_phonemes[count].split("-"))
count += 1
else:
phonemes_with_punc.append(i)
except IndexError:
# sometimes the g2p model cannot split the words correctly
logger.warning(
f"Do g2p word by word because of unexpected outputs from g2p for text: {text}"
)
for i in all_:
if i not in "-!'(),.:;? ":
p = g2p_model.g2p(i)
p_without_space = [i for i in p if i != " "]
phonemes_with_punc.extend(p_without_space)
else:
phonemes_with_punc.append(i)
while "" in phonemes_with_punc:
phonemes_with_punc.remove("")
return phonemes_with_punc
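    # Illustrative behavior (the exact phoneme symbols depend on the g2p model;
    # word-separating spaces are kept as tokens):
    #   _g2p_keep_punctuations(self.g2p, "Hello, world!")
    #   -> ['HH', 'EH', 'L', 'OW', ',', ' ', 'W', 'ER', 'L', 'D', '!']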
def encode_phoneme(
self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
):
"""Computes mel-spectrogram for a list of phoneme sequences
Arguments
---------
phonemes: List[List[str]]
phonemes to be converted to spectrogram
pace: float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
        tensors of output spectrograms, durations, pitches and energies
"""
all_tokens = []
max_seq_len = -1
for phoneme in phonemes:
token_seq = (
self.input_encoder.encode_sequence_torch(phoneme)
.int()
.to(self.device)
)
if max_seq_len < token_seq.shape[-1]:
max_seq_len = token_seq.shape[-1]
all_tokens.append(token_seq)
tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to(
self.device
)
tokens_padded.zero_()
for seq_idx, seq in enumerate(all_tokens):
tokens_padded[seq_idx, : len(seq)] = seq
return self.encode_batch(
tokens_padded,
pace=pace,
pitch_rate=pitch_rate,
energy_rate=energy_rate,
)
def encode_batch(
self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
):
"""Batch inference for a tensor of phoneme sequences
Arguments
---------
tokens_padded : torch.Tensor
A sequence of encoded phonemes to be converted to spectrogram
pace : float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
post_mel_outputs : torch.Tensor
durations : torch.Tensor
pitch : torch.Tensor
energy : torch.Tensor
"""
with torch.no_grad():
(
_,
post_mel_outputs,
durations,
pitch,
_,
energy,
_,
_,
_,
_,
_,
_,
) = self.hparams.model(
tokens_padded,
pace=pace,
pitch_rate=pitch_rate,
energy_rate=energy_rate,
)
            # Transposes to make it compliant with the expected HiFiGAN input format
post_mel_outputs = post_mel_outputs.transpose(-1, 1)
return post_mel_outputs, durations, pitch, energy
def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
"""Batch inference for a tensor of phoneme sequences
Arguments
---------
text : str
A text to be converted to spectrogram
pace : float
pace for the speech synthesis
pitch_rate : float
scaling factor for phoneme pitches
energy_rate : float
scaling factor for phoneme energies
Returns
-------
        tensors of output spectrograms, durations, pitches and energies
"""
return self.encode_text(
[text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
)