""" Specifies the inference interfaces for Text-To-Speech (TTS) modules.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import re
import logging
import torch
import torchaudio
import random
import speechbrain
from speechbrain.utils.fetching import fetch
from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.text_to_sequence import text_to_sequence
from speechbrain.inference.text import GraphemeToPhoneme
from speechbrain.inference.encoders import MelSpectrogramEncoder
from speechbrain.inference.classifiers import EncoderClassifier
# Module-level logger; FastSpeech2InternalAlignment uses it to report g2p problems.
logger = logging.getLogger(__name__)
class Tacotron2(Pretrained):
    """
    A ready-to-use wrapper for Tacotron2 (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)

    Example
    -------
    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
    >>> tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts)
    >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
    >>> items = [
    ...   "A quick brown fox jumped over the lazy dog",
    ...   "How much wood would a woodchuck chuck?",
    ...   "Never odd or even"
    ... ]
    >>> mel_outputs, mel_lengths, alignments = tacotron2.encode_batch(items)
    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
    >>> # Initialize the Vocoder (HiFIGAN)
    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
    >>> from speechbrain.inference.vocoders import HIFIGAN
    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder)
    >>> # Running the TTS
    >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb")
    >>> # Running Vocoder (spectrogram-to-waveform)
    >>> waveforms = hifi_gan.decode_batch(mel_output)
    """

    HPARAMS_NEEDED = ["model", "text_to_sequence"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cleaners passed to ``hparams.text_to_sequence``; the hparams may
        # override the default ``["english_cleaners"]``.
        self.text_cleaners = getattr(
            self.hparams, "text_cleaners", ["english_cleaners"]
        )
        self.infer = self.hparams.model.infer

    def text_to_seq(self, txt):
        """Encodes raw text with the custom text-to-sequence function
        provided by the hparams.

        Arguments
        ---------
        txt : str
            The text to encode.

        Returns
        -------
        sequence
            The value returned by ``hparams.text_to_sequence``.
        int
            The length of that sequence.
        """
        sequence = self.hparams.text_to_sequence(txt, self.text_cleaners)
        return sequence, len(sequence)

    def encode_batch(self, texts):
        """Computes mel-spectrograms for a list of texts.

        Texts must be sorted in decreasing order of their encoded lengths.

        Arguments
        ---------
        texts : List[str]
            Texts to be encoded into spectrograms.

        Returns
        -------
        mel_outputs_postnet, mel_lengths, alignments
            The three values returned by ``hparams.model.infer``: output
            spectrograms, output lengths and alignments.

        Raises
        ------
        AssertionError
            If the encoded lengths are not in decreasing order.
        """
        with torch.no_grad():
            # Encode every text exactly once: both the sequence and its length
            # are needed, and text_to_seq may run non-trivial text cleaning.
            encoded = [self.text_to_seq(item) for item in texts]
            lens = [length for _, length in encoded]
            # Validate before doing any tensor work.
            assert lens == sorted(
                lens, reverse=True
            ), "input lengths must be sorted in decreasing order"
            inputs = speechbrain.dataio.batch.PaddedBatch(
                [
                    {"text_sequences": torch.tensor(seq, device=self.device)}
                    for seq, _ in encoded
                ]
            )
            input_lengths = torch.tensor(lens, device=self.device)
            mel_outputs_postnet, mel_lengths, alignments = self.infer(
                inputs.text_sequences.data, input_lengths
            )
        return mel_outputs_postnet, mel_lengths, alignments

    def encode_text(self, text):
        """Runs inference for a single text string.

        Arguments
        ---------
        text : str
            The text to synthesize.

        Returns
        -------
        The same values as ``encode_batch`` for a batch of one text.
        """
        return self.encode_batch([text])

    def forward(self, texts):
        """Encodes the input texts (alias of ``encode_batch``)."""
        return self.encode_batch(texts)
class MSTacotron2(Pretrained):
    """
    A ready-to-use wrapper for Zero-Shot Multi-Speaker Tacotron2.
    For voice cloning: (text, reference_audio) -> (mel_spec).
    For generating a random speaker voice: (text) -> (mel_spec).

    Example
    -------
    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
    >>> mstacotron2 = MSTacotron2.from_hparams(source="speechbrain/tts-mstacotron2-libritts", savedir=tmpdir_tts) # doctest: +SKIP
    >>> # Sample rate of the reference audio must be greater than or equal to the sample rate of the speaker embedding model
    >>> reference_audio_path = "tests/samples/single-mic/example1.wav"
    >>> input_text = "Mary had a little lamb."
    >>> mel_output, mel_length, alignment = mstacotron2.clone_voice(input_text, reference_audio_path) # doctest: +SKIP
    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
    >>> # Initialize the Vocoder (HiFIGAN)
    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
    >>> from speechbrain.inference.vocoders import HIFIGAN
    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-libritts-22050Hz", savedir=tmpdir_vocoder) # doctest: +SKIP
    >>> # Running the TTS
    >>> mel_output, mel_length, alignment = mstacotron2.clone_voice(input_text, reference_audio_path) # doctest: +SKIP
    >>> # Running Vocoder (spectrogram-to-waveform)
    >>> waveforms = hifi_gan.decode_batch(mel_output) # doctest: +SKIP
    >>> # For generating a random speaker voice, use the following
    >>> mel_output, mel_length, alignment = mstacotron2.generate_random_voice(input_text) # doctest: +SKIP
    """

    HPARAMS_NEEDED = ["model"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_cleaners = ["english_cleaners"]
        self.infer = self.hparams.model.infer
        self.custom_mel_spec_encoder = self.hparams.custom_mel_spec_encoder
        self.g2p = GraphemeToPhoneme.from_hparams(
            self.hparams.g2p, run_opts={"device": self.device}
        )
        # The speaker-embedding model is either a mel-spectrogram encoder or
        # a generic speaker classifier, selected by an hparams flag.
        encoder_cls = (
            MelSpectrogramEncoder
            if self.custom_mel_spec_encoder
            else EncoderClassifier
        )
        self.spk_emb_encoder = encoder_cls.from_hparams(
            source=self.hparams.spk_emb_encoder,
            run_opts={"device": self.device},
        )

    def __text_to_seq(self, txt):
        """Encodes raw text with ``text_to_sequence`` and this model's
        cleaners, returning the sequence and its length.
        """
        sequence = text_to_sequence(txt, self.text_cleaners)
        return sequence, len(sequence)

    def _texts_to_phoneme_seqs(self, texts):
        """Converts text(s) to phoneme strings for ``__encode_batch``.

        Each text is run through g2p; its phonemes are joined with spaces
        and wrapped in curly braces before being handed to
        ``text_to_sequence``.

        Arguments
        ---------
        texts : str or list
            A single text or a list of texts.

        Returns
        -------
        List[str]
            One brace-wrapped phoneme string per input text.
        """
        if isinstance(texts, str):
            texts = [texts]
        return ["{" + " ".join(phonemes) + "}" for phonemes in self.g2p(texts)]

    def clone_voice(self, texts, audio_path):
        """
        Generates mel-spectrograms using input text and a reference audio.

        Arguments
        ---------
        texts : str or list
            Input text
        audio_path : str
            Reference audio

        Returns
        -------
        tensors of output spectrograms, output lengths and alignments.
        When several texts are given, outputs are ordered by decreasing
        encoded length, not by input order.
        """
        # torchaudio.load returns (channels, time). Down-mix multi-channel
        # audio so the speaker encoder sees one waveform; otherwise each
        # channel is treated as a batch item and the embedding reshaping
        # below fails.
        ref_signal, signal_sr = torchaudio.load(audio_path)
        if ref_signal.shape[0] > 1:
            ref_signal = ref_signal.mean(dim=0, keepdim=True)

        # Resamples the audio if required
        if signal_sr != self.hparams.spk_emb_sample_rate:
            ref_signal = torchaudio.functional.resample(
                ref_signal, signal_sr, self.hparams.spk_emb_sample_rate
            )
        ref_signal = ref_signal.to(self.device)

        # Computes speaker embedding
        if self.custom_mel_spec_encoder:
            spk_emb = self.spk_emb_encoder.encode_waveform(ref_signal)
        else:
            spk_emb = self.spk_emb_encoder.encode_batch(ref_signal)
        spk_emb = spk_emb.squeeze(0)

        phoneme_seqs = self._texts_to_phoneme_seqs(texts)
        # One copy of the speaker embedding per input text
        spk_embs = spk_emb.repeat(len(phoneme_seqs), 1)
        return self.__encode_batch(phoneme_seqs, spk_embs)

    def generate_random_voice(self, texts):
        """
        Generates mel-spectrograms using input text and a random speaker voice.

        Arguments
        ---------
        texts : str or list
            Input text

        Returns
        -------
        tensors of output spectrograms, output lengths and alignments.
        When several texts are given, outputs are ordered by decreasing
        encoded length, not by input order.
        """
        spk_emb = self.__sample_random_speaker().float().to(self.device)
        phoneme_seqs = self._texts_to_phoneme_seqs(texts)
        # One copy of the speaker embedding per input text
        spk_embs = spk_emb.repeat(len(phoneme_seqs), 1)
        return self.__encode_batch(phoneme_seqs, spk_embs)

    def __encode_batch(self, texts, spk_embs):
        """Computes mel-spectrograms for a list of texts.

        The texts are sorted by decreasing encoded length before inference.
        Their speaker embeddings are permuted the same way so each text
        keeps its own embedding. NOTE(review): outputs are returned in this
        sorted order, which may differ from the order of ``texts``.

        Arguments
        ---------
        texts : List[str]
            Texts to be encoded into spectrograms.
        spk_embs : torch.Tensor
            Speaker embeddings, one row per text.

        Returns
        -------
        tensors of output spectrograms, output lengths and alignments
        """
        with torch.no_grad():
            seqs = [
                torch.tensor(self.__text_to_seq(item)[0], device=self.device)
                for item in texts
            ]
            # Stable sort by decreasing length (equal lengths keep input order).
            order = sorted(
                range(len(seqs)), key=lambda i: seqs[i].size(0), reverse=True
            )
            spk_embs = spk_embs[order]
            lens = [seqs[i].size(0) for i in order]
            inputs = speechbrain.dataio.batch.PaddedBatch(
                [{"text_sequences": seqs[i]} for i in order]
            )
            input_lengths = torch.tensor(lens, device=self.device)
            mel_outputs_postnet, mel_lengths, alignments = self.infer(
                inputs.text_sequences.data, spk_embs, input_lengths
            )
        return mel_outputs_postnet, mel_lengths, alignments

    def __sample_random_speaker(self):
        """Samples a random speaker embedding from a pretrained GMM.

        Returns
        -------
        x : torch.Tensor
            A randomly sampled speaker embedding of shape (1, emb_dim).
        """
        # Fetches and loads the GMM trained on speaker embeddings.
        speaker_gmm_local_path = fetch(
            filename=self.hparams.random_speaker_sampler,
            source=self.hparams.random_speaker_sampler_source,
            savedir=self.hparams.pretrainer.collect_in,
        )
        # map_location="cpu" lets a GMM saved from GPU load on CPU-only
        # hosts; callers move the sample to the model device afterwards.
        # NOTE: torch.load unpickles its input; only load trusted files.
        random_speaker_gmm = torch.load(
            speaker_gmm_local_path, map_location="cpu"
        )
        gmm_n_components = random_speaker_gmm["gmm_n_components"]
        gmm_means = random_speaker_gmm["gmm_means"]
        gmm_covariances = random_speaker_gmm["gmm_covariances"]

        # Picks one mixture component uniformly at random and draws a single
        # sample from its (full-covariance) Gaussian.
        k = random.randint(0, gmm_n_components - 1)
        component = torch.distributions.multivariate_normal.MultivariateNormal(
            gmm_means[k], gmm_covariances[k]
        )
        return component.sample().unsqueeze(0)
class FastSpeech2(Pretrained):
    """
    A ready-to-use wrapper for Fastspeech2 (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)

    Example
    -------
    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
    >>> fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir=tmpdir_tts) # doctest: +SKIP
    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
    >>> items = [
    ...   "A quick brown fox jumped over the lazy dog",
    ...   "How much wood would a woodchuck chuck?",
    ...   "Never odd or even"
    ... ]
    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items) # doctest: +SKIP
    >>>
    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
    >>> # Initialize the Vocoder (HiFIGAN)
    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
    >>> from speechbrain.inference.vocoders import HIFIGAN
    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder) # doctest: +SKIP
    >>> # Running the TTS
    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
    >>> # Running Vocoder (spectrogram-to-waveform)
    >>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP
    """

    HPARAMS_NEEDED = ["spn_predictor", "model", "input_encoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # "@@" is added ahead of the lexicon entries; presumably this reserves
        # the first index (used as padding) — TODO confirm.
        lexicon = ["@@"] + self.hparams.lexicon
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
        self.input_encoder.add_unk()
        # Run g2p on the same device as the TTS model.
        self.g2p = GraphemeToPhoneme.from_hparams(
            "speechbrain/soundchoice-g2p", run_opts={"device": self.device}
        )
        # Encoded id of the silence ("spn") token inserted between words.
        self.spn_token_encoded = (
            self.input_encoder.encode_sequence_torch(["spn"]).int().item()
        )

    def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
        """Computes mel-spectrograms for a list of texts.

        Arguments
        ---------
        texts : List[str] or str
            Texts to be converted to spectrograms (a single str is treated
            as a one-item list).
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        The values returned by ``encode_batch``: mel-spectrograms,
        durations, pitch and energy.
        """
        if isinstance(texts, str):
            texts = [texts]
        token_seqs = [
            self._insert_silences(*self._text_to_phonemes(label))
            for label in texts
        ]
        return self.encode_batch(
            self._pad_token_sequences(token_seqs),
            pace=pace,
            pitch_rate=pitch_rate,
            energy_rate=energy_rate,
        )

    def _text_to_phonemes(self, text):
        """Runs g2p word by word and records per-phoneme boundary flags.

        Arguments
        ---------
        text : str
            The input text.

        Returns
        -------
        phonemes : List[str]
            Non-whitespace phonemes of the text.
        last_phonemes : List[int]
            1 where the phoneme is the last one of a word, else 0.
        punc_positions : List[int]
            1 where the phoneme ends a word whose last character is one of
            ``:;-,.!?`` (a silence is added there), else 0.
        """
        phonemes, last_phonemes, punc_positions = [], [], []
        words = text.split()
        for word, word_phonemes in zip(words, self.g2p(words)):
            for phoneme in word_phonemes:
                if not phoneme.isspace():
                    phonemes.append(phoneme)
                    last_phonemes.append(0)
                    punc_positions.append(0)
            if not phonemes:
                # Nothing emitted yet (e.g. a leading word with no phonemes):
                # there is no phoneme to attach the boundary flags to.
                continue
            last_phonemes[-1] = 1
            if word[-1] in ":;-,.!?":
                punc_positions[-1] = 1
        return phonemes, last_phonemes, punc_positions

    def _insert_silences(self, phonemes, last_phonemes, punc_positions):
        """Encodes phonemes and inserts the silence token after each position
        chosen by the silence predictor or marked as punctuation.

        Returns
        -------
        torch.Tensor
            1-D LongTensor of token ids with silence tokens inserted.
        """
        token_seq = (
            self.input_encoder.encode_sequence_torch(phonemes)
            .int()
            .to(self.device)
        )
        last_phonemes = torch.LongTensor(last_phonemes).to(self.device)
        spn_preds = (
            self.hparams.modules["spn_predictor"]
            .infer(token_seq.unsqueeze(0), last_phonemes.unsqueeze(0))
            .int()
        )
        # Flatten the single-item batch BEFORE nonzero: on a 2-D tensor
        # nonzero returns (row, col) pairs, and flattening those would leak
        # the row index 0 in as a bogus token position.
        spn_positions = set(
            torch.nonzero(spn_preds.reshape(-1)).reshape(-1).tolist()
        )
        spn_positions.update(
            j for j, is_punc in enumerate(punc_positions) if is_punc == 1
        )
        tokens_with_spn = []
        for token_idx, token in enumerate(token_seq.tolist()):
            tokens_with_spn.append(token)
            if token_idx in spn_positions:
                tokens_with_spn.append(self.spn_token_encoded)
        return torch.LongTensor(tokens_with_spn).to(self.device)

    def _pad_token_sequences(self, sequences):
        """Right-pads 1-D token tensors with zeros into a (batch, max_len)
        LongTensor on ``self.device``.

        Raises
        ------
        ValueError
            If ``sequences`` is empty.
        """
        if not sequences:
            raise ValueError("at least one input sequence is required")
        max_seq_len = max(seq.shape[-1] for seq in sequences)
        tokens_padded = torch.zeros(
            len(sequences), max_seq_len, dtype=torch.long, device=self.device
        )
        for seq_idx, seq in enumerate(sequences):
            tokens_padded[seq_idx, : len(seq)] = seq
        return tokens_padded

    def encode_phoneme(
        self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
    ):
        """Computes mel-spectrograms for a list of phoneme sequences.

        Arguments
        ---------
        phonemes : List[List[str]]
            phonemes to be converted to spectrogram
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        The values returned by ``encode_batch``: mel-spectrograms,
        durations, pitch and energy.
        """
        token_seqs = [
            self.input_encoder.encode_sequence_torch(phoneme)
            .int()
            .to(self.device)
            for phoneme in phonemes
        ]
        return self.encode_batch(
            self._pad_token_sequences(token_seqs),
            pace=pace,
            pitch_rate=pitch_rate,
            energy_rate=energy_rate,
        )

    def encode_batch(
        self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
    ):
        """Batch inference for a tensor of phoneme sequences.

        Arguments
        ---------
        tokens_padded : torch.Tensor
            A sequence of encoded phonemes to be converted to spectrogram
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        post_mel_outputs : torch.Tensor
            Post-net mel-spectrograms, with dims 1 and -1 swapped for the
            vocoder.
        durations, pitch, energy
            The corresponding predictions returned by the model.
        """
        with torch.no_grad():
            (
                _,
                post_mel_outputs,
                durations,
                pitch,
                _,
                energy,
                _,
                _,
            ) = self.hparams.model(
                tokens_padded,
                pace=pace,
                pitch_rate=pitch_rate,
                energy_rate=energy_rate,
            )
            # Transposes to make it compliant with the HiFi-GAN input format
            post_mel_outputs = post_mel_outputs.transpose(-1, 1)
        return post_mel_outputs, durations, pitch, energy

    def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
        """Synthesizes the mel-spectrogram for a single text.

        Arguments
        ---------
        text : str
            A text to be converted to spectrogram
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        The values returned by ``encode_text`` for a batch of one text.
        """
        return self.encode_text(
            [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
        )
class FastSpeech2InternalAlignment(Pretrained):
    """
    A ready-to-use wrapper for Fastspeech2 with internal alignment (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)

    Example
    -------
    >>> tmpdir_tts = getfixture('tmpdir') / "tts"
    >>> fastspeech2 = FastSpeech2InternalAlignment.from_hparams(source="speechbrain/tts-fastspeech2-internal-alignment-ljspeech", savedir=tmpdir_tts) # doctest: +SKIP
    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
    >>> items = [
    ...   "A quick brown fox jumped over the lazy dog",
    ...   "How much wood would a woodchuck chuck?",
    ...   "Never odd or even"
    ... ]
    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(items) # doctest: +SKIP
    >>> # One can combine the TTS model with a vocoder (that generates the final waveform)
    >>> # Initialize the Vocoder (HiFIGAN)
    >>> tmpdir_vocoder = getfixture('tmpdir') / "vocoder"
    >>> from speechbrain.inference.vocoders import HIFIGAN
    >>> hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder) # doctest: +SKIP
    >>> # Running the TTS
    >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text(["Mary had a little lamb."]) # doctest: +SKIP
    >>> # Running Vocoder (spectrogram-to-waveform)
    >>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP
    """

    HPARAMS_NEEDED = ["model", "input_encoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # "@@" is added ahead of the lexicon entries; presumably this reserves
        # the first index (used as padding) — TODO confirm.
        lexicon = ["@@"] + self.hparams.lexicon
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
        self.input_encoder.add_unk()
        # Run g2p on the same device as the TTS model.
        self.g2p = GraphemeToPhoneme.from_hparams(
            "speechbrain/soundchoice-g2p", run_opts={"device": self.device}
        )

    def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
        """Computes mel-spectrograms for a list of texts.

        Arguments
        ---------
        texts : List[str] or str
            Texts to be converted to spectrograms (a single str is treated
            as a one-item list).
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        The values returned by ``encode_batch``: mel-spectrograms,
        durations, pitch and energy.
        """
        if isinstance(texts, str):
            texts = [texts]
        token_seqs = []
        for label in texts:
            # Phonemes interleaved with the punctuation/spaces of the text.
            phonemes_with_punc = self._g2p_keep_punctuations(self.g2p, label)
            token_seqs.append(
                self.input_encoder.encode_sequence_torch(phonemes_with_punc)
                .int()
                .to(self.device)
            )
        return self.encode_batch(
            self._pad_token_sequences(token_seqs),
            pace=pace,
            pitch_rate=pitch_rate,
            energy_rate=energy_rate,
        )

    def _g2p_keep_punctuations(self, g2p_model, text):
        """Runs grapheme-to-phoneme conversion while keeping the punctuation
        and spaces between words.

        Arguments
        ---------
        g2p_model : GraphemeToPhoneme
            The g2p model to use.
        text : str
            The input text.

        Returns
        -------
        List[str]
            Phonemes interleaved with the inter-word punctuation marks and
            spaces of ``text``; empty strings are removed.

        Raises
        ------
        RuntimeError
            Re-raised if ``g2p_model`` fails on ``text``.
        """
        inter_word_punctuation = "-!'(),.:;? "
        # Words with "-", "'", "." or ":" in the middle (e.g. "don't").
        special_words = re.findall(r"\w+[-':\.][-':\.\w]*\w+", text)
        # Remove intra-word punctuation; this does not change the output of
        # the speechbrain g2p.
        strip_intra_word = str.maketrans("", "", "-':.")
        for special_word in special_words:
            text = text.replace(
                special_word, special_word.translate(strip_intra_word)
            )
        # Words and the inter-word punctuation/spaces to keep, in order.
        tokens = re.findall(r"[\w]+|[-!'(),.:;? ]", text)
        try:
            phonemes = g2p_model(text)
        except RuntimeError:
            # Library code must not terminate the interpreter; let the
            # caller decide how to handle the failure.
            logger.error("error with text: %s", text)
            raise

        # One "-"-joined phoneme string per word (words are split on " ").
        word_phonemes = "-".join(phonemes).split(" ")
        phonemes_with_punc = []
        count = 0
        try:
            # Works if the g2p model splits the words correctly.
            for token in tokens:
                if token not in inter_word_punctuation:
                    phonemes_with_punc.extend(word_phonemes[count].split("-"))
                    count += 1
                else:
                    phonemes_with_punc.append(token)
        except IndexError:
            # Sometimes the g2p model cannot split the words correctly.
            logger.warning(
                "Do g2p word by word because of unexpected ouputs from g2p for text: %s",
                text,
            )
            # Start over so partial output from the failed pass is not
            # duplicated.
            phonemes_with_punc = []
            for token in tokens:
                if token not in inter_word_punctuation:
                    phonemes_with_punc.extend(
                        p for p in g2p_model.g2p(token) if p != " "
                    )
                else:
                    phonemes_with_punc.append(token)
        return [p for p in phonemes_with_punc if p != ""]

    def _pad_token_sequences(self, sequences):
        """Right-pads 1-D token tensors with zeros into a (batch, max_len)
        LongTensor on ``self.device``.

        Raises
        ------
        ValueError
            If ``sequences`` is empty.
        """
        if not sequences:
            raise ValueError("at least one input sequence is required")
        max_seq_len = max(seq.shape[-1] for seq in sequences)
        tokens_padded = torch.zeros(
            len(sequences), max_seq_len, dtype=torch.long, device=self.device
        )
        for seq_idx, seq in enumerate(sequences):
            tokens_padded[seq_idx, : len(seq)] = seq
        return tokens_padded

    def encode_phoneme(
        self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0
    ):
        """Computes mel-spectrograms for a list of phoneme sequences.

        Arguments
        ---------
        phonemes : List[List[str]]
            phonemes to be converted to spectrogram
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        The values returned by ``encode_batch``: mel-spectrograms,
        durations, pitch and energy.
        """
        token_seqs = [
            self.input_encoder.encode_sequence_torch(phoneme)
            .int()
            .to(self.device)
            for phoneme in phonemes
        ]
        return self.encode_batch(
            self._pad_token_sequences(token_seqs),
            pace=pace,
            pitch_rate=pitch_rate,
            energy_rate=energy_rate,
        )

    def encode_batch(
        self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0
    ):
        """Batch inference for a tensor of phoneme sequences.

        Arguments
        ---------
        tokens_padded : torch.Tensor
            A sequence of encoded phonemes to be converted to spectrogram
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        post_mel_outputs : torch.Tensor
            Post-net mel-spectrograms, with dims 1 and -1 swapped for the
            vocoder.
        durations, pitch, energy
            The corresponding predictions returned by the model.
        """
        with torch.no_grad():
            (
                _,
                post_mel_outputs,
                durations,
                pitch,
                _,
                energy,
                _,
                _,
                _,
                _,
                _,
                _,
            ) = self.hparams.model(
                tokens_padded,
                pace=pace,
                pitch_rate=pitch_rate,
                energy_rate=energy_rate,
            )
            # Transposes to make it compliant with the HiFi-GAN input format
            post_mel_outputs = post_mel_outputs.transpose(-1, 1)
        return post_mel_outputs, durations, pitch, energy

    def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0):
        """Synthesizes the mel-spectrogram for a single text.

        Arguments
        ---------
        text : str
            A text to be converted to spectrogram
        pace : float
            pace for the speech synthesis
        pitch_rate : float
            scaling factor for phoneme pitches
        energy_rate : float
            scaling factor for phoneme energies

        Returns
        -------
        The values returned by ``encode_text`` for a batch of one text.
        """
        return self.encode_text(
            [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate
        )