"""
Neural network modules for the Zero-Shot Multi-Speaker Tacotron2 end-to-end neural
Text-to-Speech (TTS) model

Authors
* Georges Abous-Rjeili 2021
* Artem Ploujnikov 2021
* Pradnya Kandarkar 2023
"""

# This code uses a significant portion of the NVidia implementation, even though it
# has been modified and enhanced

# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py
# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

from math import sqrt
from speechbrain.nnet.loss.guidedattn_loss import GuidedAttentionLoss
import torch
from torch import nn
from torch.nn import functional as F
from collections import namedtuple
import pickle
from speechbrain.lobes.models.Tacotron2 import (
    LinearNorm,
    Postnet,
    Encoder,
    Decoder,
    get_mask_from_lengths,
)


class Tacotron2(nn.Module):
    """The Tacotron2 text-to-speech model, based on the NVIDIA implementation.

    This class is the main entry point for the model, which is responsible
    for instantiating all submodules, which, in turn, manage the individual
    neural network layers.

    Simplified STRUCTURE:
    phoneme input -> token embedding -> encoder ->
    (encoder output + speaker embedding) -> attention ->
    decoder (+ prenet) -> postnet -> output

    The prenet (whose input is the decoder output of the previous time step)
    produces an output that is fed to the decoder, concatenated with the
    attention output.

    Arguments
    ---------
    spk_emb_size: int
        Speaker embedding size
    mask_padding: bool
        whether or not to mask the padded outputs of tacotron

    # mel generation parameter in data io
    n_mel_channels: int
        number of mel channels for constructing the spectrogram

    # symbols
    n_symbols: int
        number of accepted char symbols defined in textToSequence
    symbols_embedding_dim: int
        number of embedding dimensions for the symbols fed to nn.Embedding

    # Encoder parameters
    encoder_kernel_size: int
        size of the kernel processing the embeddings
    encoder_n_convolutions: int
        number of convolution layers in the encoder
    encoder_embedding_dim: int
        number of kernels in the encoder; this is also the dimension
        of the bidirectional LSTM in the encoder

    # Attention parameters
    attention_rnn_dim: int
        input dimension
    attention_dim: int
        number of hidden representations in the attention

    # Location Layer parameters
    attention_location_n_filters: int
        number of 1-D convolution filters in the attention
    attention_location_kernel_size: int
        length of the 1-D convolution filters

    # Decoder parameters
    n_frames_per_step: int=1
        only 1 generated mel frame per step is supported for the decoder
        as of now.
    decoder_rnn_dim: int
        number of units of the 2 unidirectional stacked LSTMs
    prenet_dim: int
        dimension of the linear prenet layers
    max_decoder_steps: int
        maximum number of steps/frames the decoder generates before stopping
    p_attention_dropout: float
        attention dropout probability
    p_decoder_dropout: float
        decoder dropout probability
    gate_threshold: int
        cut-off level; any output probability above it is considered
        complete and stops generation, so we have variable-length outputs
    decoder_no_early_stopping: bool
        determines early stopping of the decoder along with gate_threshold.
        The logical inverse of this is fed to the decoder.

    # Mel-post processing network parameters
    postnet_embedding_dim: int
        number of postnet filters
    postnet_kernel_size: int
        1d size of the postnet kernel
    postnet_n_convolutions: int
        number of convolution layers in the postnet

    Example
    -------
    >>> import torch
    >>> _ = torch.manual_seed(213312)
    >>> from speechbrain.lobes.models.MSTacotron2 import Tacotron2
    >>> model = Tacotron2(
    ...    spk_emb_size=192,
    ...    mask_padding=True,
    ...    n_mel_channels=80,
    ...    n_symbols=148,
    ...    symbols_embedding_dim=512,
    ...    encoder_kernel_size=5,
    ...    encoder_n_convolutions=3,
    ...    encoder_embedding_dim=512,
    ...    attention_rnn_dim=1024,
    ...    attention_dim=128,
    ...    attention_location_n_filters=32,
    ...    attention_location_kernel_size=31,
    ...    n_frames_per_step=1,
    ...    decoder_rnn_dim=1024,
    ...    prenet_dim=256,
    ...    max_decoder_steps=32,
    ...    gate_threshold=0.5,
    ...    p_attention_dropout=0.1,
    ...    p_decoder_dropout=0.1,
    ...    postnet_embedding_dim=512,
    ...    postnet_kernel_size=5,
    ...    postnet_n_convolutions=5,
    ...    decoder_no_early_stopping=False
    ... )
    >>> _ = model.eval()
    >>> inputs = torch.tensor([
    ...     [13, 12, 31, 14, 19],
    ...     [31, 16, 30, 31, 0],
    ... ])
    >>> input_lengths = torch.tensor([5, 4])
    >>> spk_embs = torch.randn(2, 192)
    >>> outputs, output_lengths, alignments = model.infer(
    ...     inputs, spk_embs, input_lengths
    ... )
    >>> outputs.shape, output_lengths.shape, alignments.shape
    (torch.Size([2, 80, 1]), torch.Size([2]), torch.Size([2, 1, 5]))
    """

    def __init__(
        self,
        spk_emb_size,
        mask_padding=True,
        n_mel_channels=80,
        n_symbols=148,
        symbols_embedding_dim=512,
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        attention_rnn_dim=1024,
        attention_dim=128,
        attention_location_n_filters=32,
        attention_location_kernel_size=31,
        n_frames_per_step=1,
        decoder_rnn_dim=1024,
        prenet_dim=256,
        max_decoder_steps=1000,
        gate_threshold=0.5,
        p_attention_dropout=0.1,
        p_decoder_dropout=0.1,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
        decoder_no_early_stopping=False,
    ):
        super().__init__()
        self.mask_padding = mask_padding
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
        std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(
            encoder_n_convolutions, encoder_embedding_dim, encoder_kernel_size
        )
        self.decoder = Decoder(
            n_mel_channels,
            n_frames_per_step,
            encoder_embedding_dim,
            attention_dim,
            attention_location_n_filters,
            attention_location_kernel_size,
            attention_rnn_dim,
            decoder_rnn_dim,
            prenet_dim,
            max_decoder_steps,
            gate_threshold,
            p_attention_dropout,
            p_decoder_dropout,
            not decoder_no_early_stopping,
        )
        self.postnet = Postnet(
            n_mel_channels,
            postnet_embedding_dim,
            postnet_kernel_size,
            postnet_n_convolutions,
        )

        # Additions for Zero-Shot Multi-Speaker TTS
        # FiLM (Feature-wise Linear Modulation) layers for injecting the
        # speaker embeddings into the TTS pipeline
        self.ms_film_hidden_size = int(
            (spk_emb_size + encoder_embedding_dim) / 2
        )
        self.ms_film_hidden = LinearNorm(
            spk_emb_size, self.ms_film_hidden_size
        )
        self.ms_film_h = LinearNorm(
            self.ms_film_hidden_size, encoder_embedding_dim
        )
        self.ms_film_g = LinearNorm(
            self.ms_film_hidden_size, encoder_embedding_dim
        )

    def parse_output(self, outputs, output_lengths, alignments_dim=None):
        """Masks the padded part of the output

        Arguments
        ---------
        outputs: list
            a list of tensors - raw outputs
        output_lengths: torch.Tensor
            a tensor representing the lengths of all outputs
        alignments_dim: int
            the desired dimension of the alignments along the last axis
            Optional but needed for data-parallel training

        Returns
        -------
        result: tuple
            a (mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
            output_lengths) tuple with the original outputs - with the mask
            applied
        """
        mel_outputs, mel_outputs_postnet, gate_outputs, alignments = outputs

        if self.mask_padding and output_lengths is not None:
            mask = get_mask_from_lengths(
                output_lengths, max_len=mel_outputs.size(-1)
            )
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_outputs.data.masked_fill_(mask, 0.0)
            mel_outputs_postnet.data.masked_fill_(mask, 0.0)
            gate_outputs.data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        if alignments_dim is not None:
            alignments = F.pad(
                alignments, (0, alignments_dim - alignments.size(-1))
            )

        return (
            mel_outputs,
            mel_outputs_postnet,
            gate_outputs,
            alignments,
            output_lengths,
        )

    def forward(self, inputs, spk_embs, alignments_dim=None):
        """Decoder forward pass for training

        Arguments
        ---------
        inputs: tuple
            batch object
        spk_embs: torch.Tensor
            Speaker embeddings corresponding to the inputs
        alignments_dim: int
            the desired dimension of the alignments along the last axis
            Optional but needed for data-parallel training

        Returns
        -------
        mel_outputs: torch.Tensor
            mel outputs from the decoder
        mel_outputs_postnet: torch.Tensor
            mel outputs from the postnet
        gate_outputs: torch.Tensor
            gate outputs from the decoder
        alignments: torch.Tensor
            sequence of attention weights from the decoder
        output_lengths: torch.Tensor
            length of the output without padding
        """
        inputs, input_lengths, targets, max_len, output_lengths = inputs
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, input_lengths)

        # Inject speaker embeddings into the encoder output via FiLM:
        # a multiplicative term (ms_film_h) and an additive term (ms_film_g)
        spk_embs_shared = F.relu(self.ms_film_hidden(spk_embs))

        spk_embs_h = self.ms_film_h(spk_embs_shared)
        spk_embs_h = torch.unsqueeze(spk_embs_h, 1).repeat(
            1, encoder_outputs.shape[1], 1
        )
        encoder_outputs = encoder_outputs * spk_embs_h

        spk_embs_g = self.ms_film_g(spk_embs_shared)
        spk_embs_g = torch.unsqueeze(spk_embs_g, 1).repeat(
            1, encoder_outputs.shape[1], 1
        )
        encoder_outputs = encoder_outputs + spk_embs_g

        # Pass the encoder output combined with speaker embeddings to the next layers
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths
        )

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths,
            alignments_dim,
        )

    def infer(self, inputs, spk_embs, input_lengths):
        """Produces outputs at inference time

        Arguments
        ---------
        inputs: torch.Tensor
            text or phonemes converted to a tensor of indices
        spk_embs: torch.Tensor
            Speaker embeddings corresponding to the inputs
        input_lengths: torch.Tensor
            the lengths of the input sequences

        Returns
        -------
        mel_outputs_postnet: torch.Tensor
            final mel output of tacotron 2
        mel_lengths: torch.Tensor
            length of the mels
        alignments: torch.Tensor
            sequence of attention weights
        """
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.infer(embedded_inputs, input_lengths)

        # Inject speaker embeddings into the encoder output
        # (same FiLM conditioning as in forward())
        spk_embs_shared = F.relu(self.ms_film_hidden(spk_embs))

        spk_embs_h = self.ms_film_h(spk_embs_shared)
        spk_embs_h = torch.unsqueeze(spk_embs_h, 1).repeat(
            1, encoder_outputs.shape[1], 1
        )
        encoder_outputs = encoder_outputs * spk_embs_h

        spk_embs_g = self.ms_film_g(spk_embs_shared)
        spk_embs_g = torch.unsqueeze(spk_embs_g, 1).repeat(
            1, encoder_outputs.shape[1], 1
        )
        encoder_outputs = encoder_outputs + spk_embs_g

        # Pass the encoder output combined with speaker embeddings to the next layers
        (
            mel_outputs,
            gate_outputs,
            alignments,
            mel_lengths,
        ) = self.decoder.infer(encoder_outputs, input_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        BS = mel_outputs_postnet.size(0)
        alignments = alignments.unfold(1, BS, BS).transpose(0, 2)

        return mel_outputs_postnet, mel_lengths, alignments
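
# ---------------------------------------------------------------------------
# Usage sketch (left as a comment so importing this module stays side-effect
# free): zero-shot inference with a reference speaker embedding. The
# embedding size of 192 and the randomly initialised model and embedding are
# illustrative assumptions; in practice the weights come from a trained
# checkpoint and spk_embs from a speaker encoder run on reference audio.
#
#   import torch
#   from speechbrain.lobes.models.MSTacotron2 import Tacotron2
#
#   model = Tacotron2(spk_emb_size=192).eval()
#   phoneme_seqs = torch.tensor([[13, 12, 31, 14, 19]])
#   seq_lengths = torch.tensor([5])
#   spk_embs = torch.randn(1, 192)  # one embedding per utterance (random here)
#   with torch.no_grad():
#       mel_post, mel_lengths, alignments = model.infer(
#           phoneme_seqs, spk_embs, seq_lengths
#       )
#   # mel_post: (batch, n_mel_channels, n_frames), ready for a vocoder
# ---------------------------------------------------------------------------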

LossStats = namedtuple(
    "TacotronLoss",
    "loss mel_loss spk_emb_loss gate_loss attn_loss attn_weight",
)
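
# Illustrative sketch of how the LossStats tuple is typically consumed in a
# training step; `loss_fn`, `model_output`, `targets`, `spk_embs` and `epoch`
# are hypothetical names standing in for objects produced elsewhere in a recipe.
#
#   stats = loss_fn(
#       model_output, targets, input_lengths, target_lengths, spk_embs, epoch
#   )
#   stats.loss.backward()  # the total loss drives the optimisation
#   train_log = {
#       "mel": stats.mel_loss.item(),
#       "gate": stats.gate_loss.item(),
#       "attn": stats.attn_loss.item(),
#   }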


class Loss(nn.Module):
    """The Tacotron loss implementation

    The loss consists of an MSE loss on the spectrogram, a BCE gate loss
    and a guided attention loss (if enabled) that attempts to make the
    attention matrix diagonal.

    The output of the module is a LossStats tuple, which includes both the
    total loss and the individual loss components.

    Arguments
    ---------
    guided_attention_sigma: float
        The guided attention sigma factor, controlling the "width" of
        the mask
    gate_loss_weight: float
        The constant by which the gate loss will be multiplied
    mel_loss_weight: float
        The constant by which the mel loss will be multiplied
    spk_emb_loss_weight: float
        The constant by which the speaker embedding loss will be multiplied
        - placeholder for future work
    spk_emb_loss_type: str
        Type of the speaker embedding loss - placeholder for future work
    guided_attention_weight: float
        The weight for the guided attention
    guided_attention_scheduler: callable
        The scheduler class for the guided attention loss
    guided_attention_hard_stop: int
        The number of epochs after which guided attention will be
        completely turned off

    Example
    -------
    >>> import torch
    >>> _ = torch.manual_seed(42)
    >>> from speechbrain.lobes.models.MSTacotron2 import Loss
    >>> loss = Loss(guided_attention_sigma=0.2)
    >>> mel_target = torch.randn(2, 80, 861)
    >>> gate_target = torch.randn(1722, 1)
    >>> mel_out = torch.randn(2, 80, 861)
    >>> mel_out_postnet = torch.randn(2, 80, 861)
    >>> gate_out = torch.randn(2, 861)
    >>> alignments = torch.randn(2, 861, 173)
    >>> pred_mel_lens = torch.randn(2)
    >>> targets = mel_target, gate_target
    >>> model_outputs = mel_out, mel_out_postnet, gate_out, alignments, pred_mel_lens
    >>> input_lengths = torch.tensor([173, 91])
    >>> target_lengths = torch.tensor([861, 438])
    >>> spk_embs = None
    >>> loss(model_outputs, targets, input_lengths, target_lengths, spk_embs, 1)
    TacotronLoss(loss=tensor([4.8566]), mel_loss=tensor(4.0097), spk_emb_loss=tensor([0.]), gate_loss=tensor(0.8460), attn_loss=tensor(0.0010), attn_weight=tensor(1.))
    """

    def __init__(
        self,
        guided_attention_sigma=None,
        gate_loss_weight=1.0,
        mel_loss_weight=1.0,
        spk_emb_loss_weight=1.0,
        spk_emb_loss_type=None,
        guided_attention_weight=1.0,
        guided_attention_scheduler=None,
        guided_attention_hard_stop=None,
    ):
        super().__init__()
        if guided_attention_weight == 0:
            guided_attention_weight = None
        self.guided_attention_weight = guided_attention_weight
        self.gate_loss_weight = gate_loss_weight
        self.mel_loss_weight = mel_loss_weight
        self.spk_emb_loss_weight = spk_emb_loss_weight
        self.spk_emb_loss_type = spk_emb_loss_type

        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.guided_attention_loss = GuidedAttentionLoss(
            sigma=guided_attention_sigma
        )

        self.cos_sim = nn.CosineSimilarity()
        self.triplet_loss = torch.nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
        )
        self.cos_emb_loss = nn.CosineEmbeddingLoss()

        self.guided_attention_scheduler = guided_attention_scheduler
        self.guided_attention_hard_stop = guided_attention_hard_stop

    def forward(
        self,
        model_output,
        targets,
        input_lengths,
        target_lengths,
        spk_embs,
        epoch,
    ):
        """Computes the loss

        Arguments
        ---------
        model_output: tuple
            the output of the model's forward():
            (mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
            pred_mel_lens)
        targets: tuple
            the targets
        input_lengths: torch.Tensor
            a (batch, length) tensor of input lengths
        target_lengths: torch.Tensor
            a (batch, length) tensor of target (spectrogram) lengths
        spk_embs: torch.Tensor
            Speaker embedding input for the loss computation - placeholder
            for future work
        epoch: int
            the current epoch number (used for the scheduling of the guided
            attention loss). A StepScheduler is typically used.

        Returns
        -------
        result: LossStats
            the total loss - and individual losses (mel and gate)
        """
        mel_target, gate_target = targets[0], targets[1]
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        gate_target = gate_target.view(-1, 1)

        (
            mel_out,
            mel_out_postnet,
            gate_out,
            alignments,
            pred_mel_lens,
        ) = model_output
        gate_out = gate_out.view(-1, 1)
        mel_loss = self.mse_loss(mel_out, mel_target) + self.mse_loss(
            mel_out_postnet, mel_target
        )
        mel_loss = self.mel_loss_weight * mel_loss

        gate_loss = self.gate_loss_weight * self.bce_loss(
            gate_out, gate_target
        )
        attn_loss, attn_weight = self.get_attention_loss(
            alignments, input_lengths, target_lengths, epoch
        )

        # Speaker embedding loss placeholder - for future work
        spk_emb_loss = torch.Tensor([0]).to(mel_loss.device)

        if self.spk_emb_loss_type == "scl_loss":
            target_spk_embs, preds_spk_embs = spk_embs
            cos_sim_scores = self.cos_sim(preds_spk_embs, target_spk_embs)
            spk_emb_loss = -torch.div(
                torch.sum(cos_sim_scores), len(cos_sim_scores)
            )

        if self.spk_emb_loss_type == "cos_emb_loss":
            target_spk_embs, preds_spk_embs = spk_embs
            spk_emb_loss = self.cos_emb_loss(
                target_spk_embs,
                preds_spk_embs,
                torch.ones(len(target_spk_embs)).to(target_spk_embs.device),
            )

        if self.spk_emb_loss_type == "triplet_loss":
            anchor_spk_embs, pos_spk_embs, neg_spk_embs = spk_embs
            if anchor_spk_embs is not None:
                spk_emb_loss = self.triplet_loss(
                    anchor_spk_embs, pos_spk_embs, neg_spk_embs
                )

        spk_emb_loss = self.spk_emb_loss_weight * spk_emb_loss

        total_loss = mel_loss + spk_emb_loss + gate_loss + attn_loss
        return LossStats(
            total_loss,
            mel_loss,
            spk_emb_loss,
            gate_loss,
            attn_loss,
            attn_weight,
        )

    def get_attention_loss(
        self, alignments, input_lengths, target_lengths, epoch
    ):
        """Computes the attention loss

        Arguments
        ---------
        alignments: torch.Tensor
            the alignment matrix from the model
        input_lengths: torch.Tensor
            a (batch, length) tensor of input lengths
        target_lengths: torch.Tensor
            a (batch, length) tensor of target (spectrogram) lengths
        epoch: int
            the current epoch number (used for the scheduling of the guided
            attention loss). A StepScheduler is typically used.

        Returns
        -------
        attn_loss: torch.Tensor
            the attention loss value
        attn_weight: torch.Tensor
            the weight applied to the attention loss
        """
        zero_tensor = torch.tensor(0.0, device=alignments.device)
        if (
            self.guided_attention_weight is None
            or self.guided_attention_weight == 0
        ):
            attn_weight, attn_loss = zero_tensor, zero_tensor
        else:
            hard_stop_reached = (
                self.guided_attention_hard_stop is not None
                and epoch > self.guided_attention_hard_stop
            )
            if hard_stop_reached:
                attn_weight, attn_loss = zero_tensor, zero_tensor
            else:
                attn_weight = self.guided_attention_weight
                if self.guided_attention_scheduler is not None:
                    _, attn_weight = self.guided_attention_scheduler(epoch)
                attn_weight = torch.tensor(
                    attn_weight, device=alignments.device
                )
                attn_loss = attn_weight * self.guided_attention_loss(
                    alignments, input_lengths, target_lengths
                )
        return attn_loss, attn_weight
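
# Sketch: any callable following the ``_, weight = scheduler(epoch)`` contract
# used by get_attention_loss() can anneal the guided-attention weight. The
# decay rule below is a hypothetical example, not a SpeechBrain scheduler
# class, and the chosen constants are illustrative only.
#
#   def halve_every_10_epochs(epoch, initial_weight=50.0):
#       weight = initial_weight * (0.5 ** (epoch // 10))
#       return None, weight  # (previous_value, current_value) pair
#
#   loss_fn = Loss(
#       guided_attention_sigma=0.2,
#       guided_attention_weight=50.0,
#       guided_attention_scheduler=halve_every_10_epochs,
#       guided_attention_hard_stop=85,
#   )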


class TextMelCollate:
    """Zero-pads model inputs and targets based on the number of frames per step

    Arguments
    ---------
    speaker_embeddings_pickle : str
        Path to the file containing speaker embeddings
    n_frames_per_step : int
        The number of output frames per step

    Returns
    -------
    result: tuple
        a tuple of inputs/targets
        (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            len_x,
            labels,
            wavs,
            spk_embs,
            spk_ids
        )
    """

    def __init__(
        self,
        speaker_embeddings_pickle,
        n_frames_per_step=1,
    ):
        self.n_frames_per_step = n_frames_per_step
        self.speaker_embeddings_pickle = speaker_embeddings_pickle

    # TODO: Make this more intuitive, use the pipeline
    def __call__(self, batch):
        """Collates a training batch from normalized text and mel-spectrograms

        Arguments
        ---------
        batch: list
            [text_normalized, mel_normalized]
        """
        # TODO: Remove for loops and this dirty hack
        raw_batch = list(batch)
        for i in range(
            len(batch)
        ):  # the pipeline returns a dictionary with one element
            batch[i] = batch[i]["mel_text_pair"]

        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
        )
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            text_padded[i, : text.size(0)] = text

        # Right zero-pad mel-spec
        num_mels = batch[0][1].size(0)
        max_target_len = max([x[1].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += (
                self.n_frames_per_step
                - max_target_len % self.n_frames_per_step
            )
            assert max_target_len % self.n_frames_per_step == 0

        # include mel padded and gate padded
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(len(batch), max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        labels, wavs, spk_embs_list, spk_ids = [], [], [], []

        # Load the precomputed speaker embeddings keyed by utterance ID
        with open(
            self.speaker_embeddings_pickle, "rb"
        ) as speaker_embeddings_file:
            speaker_embeddings = pickle.load(speaker_embeddings_file)

        for i in range(len(ids_sorted_decreasing)):
            idx = ids_sorted_decreasing[i]
            mel = batch[idx][1]
            mel_padded[i, :, : mel.size(1)] = mel
            gate_padded[i, mel.size(1) - 1 :] = 1
            output_lengths[i] = mel.size(1)
            labels.append(raw_batch[idx]["label"])
            wavs.append(raw_batch[idx]["wav"])

            spk_emb = speaker_embeddings[raw_batch[idx]["uttid"]]
            spk_embs_list.append(spk_emb)

            spk_ids.append(raw_batch[idx]["uttid"].split("_")[0])

        spk_embs = torch.stack(spk_embs_list)

        # count number of items - characters in text
        len_x = [x[2] for x in batch]
        len_x = torch.Tensor(len_x)
        return (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            len_x,
            labels,
            wavs,
            spk_embs,
            spk_ids,
        )
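
# ---------------------------------------------------------------------------
# Usage sketch (comment only): plugging TextMelCollate into a DataLoader.
# `train_set` and `spk_emb_path` are placeholder names; the dataset items are
# assumed to follow the recipe conventions used by __call__ above, i.e. each
# item is a dict with "mel_text_pair", "label", "wav" and "uttid" keys.
#
#   from torch.utils.data import DataLoader
#
#   collate_fn = TextMelCollate(
#       speaker_embeddings_pickle=spk_emb_path, n_frames_per_step=1
#   )
#   train_loader = DataLoader(
#       train_set, batch_size=16, shuffle=True, collate_fn=collate_fn
#   )
#   for (text_padded, input_lengths, mel_padded, gate_padded, output_lengths,
#        len_x, labels, wavs, spk_embs, spk_ids) in train_loader:
#       ...  # forward pass, loss computation, etc.
# ---------------------------------------------------------------------------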