"""
Neural network modules for the FastSpeech 2: Fast and High-Quality End-to-End Text to Speech
synthesis model
Authors
* Sathvik Udupa 2022
* Pradnya Kandarkar 2023
* Yingzhi Wang 2023
"""
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.loss import _Loss
from speechbrain.lobes.models.transformer.Transformer import (
PositionalEncoding,
TransformerEncoder,
get_key_padding_mask,
get_mask_from_lengths,
)
from speechbrain.nnet import CNN, linear
from speechbrain.nnet.embedding import Embedding
from speechbrain.nnet.losses import bce_loss
from speechbrain.nnet.normalization import LayerNorm
class EncoderPreNet(nn.Module):
"""Embedding layer for tokens
Arguments
---------
n_vocab: int
size of the dictionary of embeddings
blank_id: int
padding index
out_channels: int
the size of each embedding vector
Example
-------
>>> from speechbrain.nnet.embedding import Embedding
>>> from speechbrain.lobes.models.FastSpeech2 import EncoderPreNet
>>> encoder_prenet_layer = EncoderPreNet(n_vocab=40, blank_id=0, out_channels=384)
>>> x = torch.rand(3, 5)
>>> y = encoder_prenet_layer(x)
>>> y.shape
torch.Size([3, 5, 384])
"""
def __init__(self, n_vocab, blank_id, out_channels=512):
super().__init__()
self.token_embedding = Embedding(
num_embeddings=n_vocab,
embedding_dim=out_channels,
blank_id=blank_id,
)
def forward(self, x):
"""Computes the forward pass
Arguments
---------
x: torch.Tensor
a (batch, tokens) input tensor
Returns
-------
output: torch.Tensor
the embedding layer output
"""
self.token_embedding = self.token_embedding.to(x.device)
x = self.token_embedding(x)
return x
class PostNet(nn.Module):
"""
FastSpeech2 Conv Postnet
Arguments
---------
n_mel_channels: int
input feature dimension for convolution layers
postnet_embedding_dim: int
output feature dimension for convolution layers
postnet_kernel_size: int
postnet convolution kernel size
postnet_n_convolutions: int
number of convolution layers
postnet_dropout: float
dropout probability for postnet
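A minimal usage sketch is shown below; the input shape is illustrative and the default configuration is assumed.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import PostNet
>>> postnet = PostNet()
>>> x = torch.randn(3, 10, 80)  # illustrative (batch, time_steps, n_mel_channels)
>>> y = postnet(x)
>>> y.shape
torch.Size([3, 10, 80])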
"""
def __init__(
self,
n_mel_channels=80,
postnet_embedding_dim=512,
postnet_kernel_size=5,
postnet_n_convolutions=5,
postnet_dropout=0.5,
):
super(PostNet, self).__init__()
self.conv_pre = CNN.Conv1d(
in_channels=n_mel_channels,
out_channels=postnet_embedding_dim,
kernel_size=postnet_kernel_size,
padding="same",
)
self.convs_intermediate = nn.ModuleList()
for i in range(1, postnet_n_convolutions - 1):
self.convs_intermediate.append(
CNN.Conv1d(
in_channels=postnet_embedding_dim,
out_channels=postnet_embedding_dim,
kernel_size=postnet_kernel_size,
padding="same",
),
)
self.conv_post = CNN.Conv1d(
in_channels=postnet_embedding_dim,
out_channels=n_mel_channels,
kernel_size=postnet_kernel_size,
padding="same",
)
self.tanh = nn.Tanh()
self.ln1 = nn.LayerNorm(postnet_embedding_dim)
self.ln2 = nn.LayerNorm(postnet_embedding_dim)
self.ln3 = nn.LayerNorm(n_mel_channels)
self.dropout1 = nn.Dropout(postnet_dropout)
self.dropout2 = nn.Dropout(postnet_dropout)
self.dropout3 = nn.Dropout(postnet_dropout)
def forward(self, x):
"""Computes the forward pass
Arguments
---------
x: torch.Tensor
a (batch, time_steps, features) input tensor
Returns
-------
output: torch.Tensor
the spectrogram predicted
"""
x = self.conv_pre(x)
x = self.ln1(x).to(x.dtype)
x = self.tanh(x)
x = self.dropout1(x)
for i in range(len(self.convs_intermediate)):
x = self.convs_intermediate[i](x)
x = self.ln2(x).to(x.dtype)
x = self.tanh(x)
x = self.dropout2(x)
x = self.conv_post(x)
x = self.ln3(x).to(x.dtype)
x = self.dropout3(x)
return x
class DurationPredictor(nn.Module):
"""Duration predictor layer
Arguments
---------
in_channels: int
input feature dimension for convolution layers
out_channels: int
output feature dimension for convolution layers
kernel_size: int
duration predictor convolution kernel size
dropout: float
dropout probability, 0 by default
n_units: int
number of output units in the final linear layer, 1 by default
Example
-------
>>> from speechbrain.lobes.models.FastSpeech2 import DurationPredictor
>>> duration_predictor_layer = DurationPredictor(in_channels=384, out_channels=384, kernel_size=3)
>>> x = torch.randn(3, 400, 384)
>>> mask = torch.ones(3, 400, 384)
>>> y = duration_predictor_layer(x, mask)
>>> y.shape
torch.Size([3, 400, 1])
"""
def __init__(
self, in_channels, out_channels, kernel_size, dropout=0.0, n_units=1
):
super().__init__()
self.conv1 = CNN.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding="same",
)
self.conv2 = CNN.Conv1d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding="same",
)
self.linear = linear.Linear(n_neurons=n_units, input_size=out_channels)
self.ln1 = LayerNorm(out_channels)
self.ln2 = LayerNorm(out_channels)
self.relu = nn.ReLU()
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, x_mask):
"""Computes the forward pass
Arguments
---------
x: torch.Tensor
a (batch, time_steps, features) input tensor
x_mask: torch.Tensor
mask of input tensor
Returns
-------
output: torch.Tensor
the duration predictor outputs
"""
x = self.relu(self.conv1(x * x_mask))
x = self.ln1(x).to(x.dtype)
x = self.dropout1(x)
x = self.relu(self.conv2(x * x_mask))
x = self.ln2(x).to(x.dtype)
x = self.dropout2(x)
return self.linear(x * x_mask)
class SPNPredictor(nn.Module):
"""
This module is the silent phoneme predictor. It receives phoneme sequences without any silent phoneme token as
input and predicts whether a silent phoneme should be inserted after a position. This avoids an overly fast speech
pace at inference time, which otherwise occurs because the input sequence contains no silent phoneme tokens.
Arguments
---------
enc_num_layers: int
number of transformer layers (TransformerEncoderLayer) in encoder
enc_num_head: int
number of multi-head-attention (MHA) heads in encoder transformer layers
enc_d_model: int
the number of expected features in the encoder
enc_ffn_dim: int
the dimension of the feedforward network model
enc_k_dim: int
the dimension of the key
enc_v_dim: int
the dimension of the value
enc_dropout: float
Dropout for the encoder
normalize_before: bool
whether normalization should be applied before or after MHA or FFN in Transformer layers.
ffn_type: str
whether to use convolutional layers instead of feed forward network inside transformer layer
ffn_cnn_kernel_size_list: list of int
conv kernel size of 2 1d-convs if ffn_type is 1dcnn
n_char: int
the number of symbols for the token embedding
padding_idx: int
the index for padding
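A minimal sketch is shown below; the hyperparameters are illustrative, not canonical values.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import SPNPredictor
>>> model = SPNPredictor(
...     enc_num_layers=2,
...     enc_num_head=2,
...     enc_d_model=384,
...     enc_ffn_dim=1536,
...     enc_k_dim=384,
...     enc_v_dim=384,
...     enc_dropout=0.1,
...     normalize_before=False,
...     ffn_type='1dcnn',
...     ffn_cnn_kernel_size_list=[9, 1],
...     n_char=40,
...     padding_idx=0,
... )
>>> tokens = torch.tensor([[13, 12, 31, 14, 19]])  # illustrative token ids
>>> last_phonemes = torch.tensor([[0, 0, 0, 1, 1]])
>>> spn_decision = model(tokens, last_phonemes)
>>> spn_decision.shape
torch.Size([1, 5])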
"""
def __init__(
self,
enc_num_layers,
enc_num_head,
enc_d_model,
enc_ffn_dim,
enc_k_dim,
enc_v_dim,
enc_dropout,
normalize_before,
ffn_type,
ffn_cnn_kernel_size_list,
n_char,
padding_idx,
):
super().__init__()
self.enc_num_head = enc_num_head
self.padding_idx = padding_idx
self.encPreNet = EncoderPreNet(
n_char, padding_idx, out_channels=enc_d_model
)
self.sinusoidal_positional_embed_encoder = PositionalEncoding(
enc_d_model
)
self.spn_encoder = TransformerEncoder(
num_layers=enc_num_layers,
nhead=enc_num_head,
d_ffn=enc_ffn_dim,
d_model=enc_d_model,
kdim=enc_k_dim,
vdim=enc_v_dim,
dropout=enc_dropout,
activation=nn.ReLU,
normalize_before=normalize_before,
ffn_type=ffn_type,
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
)
self.spn_linear = linear.Linear(n_neurons=1, input_size=enc_d_model)
def forward(self, tokens, last_phonemes):
"""forward pass for the module
Arguments
---------
tokens: torch.Tensor
input tokens without silent phonemes
last_phonemes: torch.Tensor
indicates if a phoneme at an index is the last phoneme of a word or not
Returns
-------
spn_decision: torch.Tensor
indicates if a silent phoneme should be inserted after a phoneme
"""
token_feats = self.encPreNet(tokens)
last_phonemes = torch.unsqueeze(last_phonemes, 2).repeat(
1, 1, token_feats.shape[2]
)
token_feats = token_feats + last_phonemes
srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
srcmask_inverted = (~srcmask).unsqueeze(-1)
pos = self.sinusoidal_positional_embed_encoder(token_feats)
token_feats = torch.add(token_feats, pos) * srcmask_inverted
spn_mask = (
torch.triu(
torch.ones(
token_feats.shape[1],
token_feats.shape[1],
device=token_feats.device,
),
diagonal=1,
)
.bool()
.repeat(self.enc_num_head * token_feats.shape[0], 1, 1)
)
spn_token_feats, _ = self.spn_encoder(
token_feats, src_mask=spn_mask, src_key_padding_mask=srcmask
)
spn_decision = self.spn_linear(spn_token_feats).squeeze(-1)
return spn_decision
def infer(self, tokens, last_phonemes):
"""inference function
Arguments
---------
tokens: torch.Tensor
input tokens without silent phonemes
last_phonemes: torch.Tensor
indicates if a phoneme at an index is the last phoneme of a word or not
Returns
-------
spn_decision: torch.Tensor
indicates if a silent phoneme should be inserted after a phoneme
"""
spn_decision = self.forward(tokens, last_phonemes)
spn_decision = torch.sigmoid(spn_decision) > 0.8
return spn_decision
class FastSpeech2(nn.Module):
"""The FastSpeech2 text-to-speech model.
This class is the main entry point for the model, which is responsible
for instantiating all submodules, which, in turn, manage the individual
neural network layers
Simplified STRUCTURE: input -> token embedding -> encoder -> duration/pitch/energy predictor -> duration
upsampler -> decoder -> output
During training, teacher forcing is used (ground truth durations are used for upsampling)
Arguments
---------
enc_num_layers: int
number of transformer layers (TransformerEncoderLayer) in encoder
enc_num_head: int
number of multi-head-attention (MHA) heads in encoder transformer layers
enc_d_model: int
the number of expected features in the encoder
enc_ffn_dim: int
the dimension of the feedforward network model
enc_k_dim: int
the dimension of the key
enc_v_dim: int
the dimension of the value
enc_dropout: float
Dropout for the encoder
dec_num_layers: int
number of transformer layers (TransformerEncoderLayer) in decoder
dec_num_head: int
number of multi-head-attention (MHA) heads in decoder transformer layers
dec_d_model: int
the number of expected features in the decoder
dec_ffn_dim: int
the dimension of the feedforward network model
dec_k_dim: int
the dimension of the key
dec_v_dim: int
the dimension of the value
dec_dropout: float
dropout for the decoder
normalize_before: bool
whether normalization should be applied before or after MHA or FFN in Transformer layers.
ffn_type: str
whether to use convolutional layers instead of feed forward network inside transformer layer.
ffn_cnn_kernel_size_list: list of int
conv kernel size of 2 1d-convs if ffn_type is 1dcnn
n_char: int
the number of symbols for the token embedding
n_mels: int
number of bins in mel spectrogram
postnet_embedding_dim: int
output feature dimension for convolution layers
postnet_kernel_size: int
postnet convolution kernel size
postnet_n_convolutions: int
number of convolution layers
postnet_dropout: float
dropout probability for postnet
padding_idx: int
the index for padding
dur_pred_kernel_size: int
the convolution kernel size in duration predictor
pitch_pred_kernel_size: int
kernel size for pitch prediction.
energy_pred_kernel_size: int
kernel size for energy prediction.
variance_predictor_dropout: float
dropout probability for variance predictor (duration/pitch/energy)
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import FastSpeech2
>>> model = FastSpeech2(
... enc_num_layers=6,
... enc_num_head=2,
... enc_d_model=384,
... enc_ffn_dim=1536,
... enc_k_dim=384,
... enc_v_dim=384,
... enc_dropout=0.1,
... dec_num_layers=6,
... dec_num_head=2,
... dec_d_model=384,
... dec_ffn_dim=1536,
... dec_k_dim=384,
... dec_v_dim=384,
... dec_dropout=0.1,
... normalize_before=False,
... ffn_type='1dcnn',
... ffn_cnn_kernel_size_list=[9, 1],
... n_char=40,
... n_mels=80,
... postnet_embedding_dim=512,
... postnet_kernel_size=5,
... postnet_n_convolutions=5,
... postnet_dropout=0.5,
... padding_idx=0,
... dur_pred_kernel_size=3,
... pitch_pred_kernel_size=3,
... energy_pred_kernel_size=3,
... variance_predictor_dropout=0.5)
>>> inputs = torch.tensor([
... [13, 12, 31, 14, 19],
... [31, 16, 30, 31, 0],
... ])
>>> input_lengths = torch.tensor([5, 4])
>>> durations = torch.tensor([
... [2, 4, 1, 5, 3],
... [1, 2, 4, 3, 0],
... ])
>>> mel_post, postnet_output, predict_durations, predict_pitch, avg_pitch, predict_energy, avg_energy, mel_lens = model(inputs, durations=durations)
>>> mel_post.shape, predict_durations.shape
(torch.Size([2, 15, 80]), torch.Size([2, 5]))
>>> predict_pitch.shape, predict_energy.shape
(torch.Size([2, 5, 1]), torch.Size([2, 5, 1]))
"""
def __init__(
self,
# encoder parameters
enc_num_layers,
enc_num_head,
enc_d_model,
enc_ffn_dim,
enc_k_dim,
enc_v_dim,
enc_dropout,
# decoder parameters
dec_num_layers,
dec_num_head,
dec_d_model,
dec_ffn_dim,
dec_k_dim,
dec_v_dim,
dec_dropout,
normalize_before,
ffn_type,
ffn_cnn_kernel_size_list,
n_char,
n_mels,
postnet_embedding_dim,
postnet_kernel_size,
postnet_n_convolutions,
postnet_dropout,
padding_idx,
dur_pred_kernel_size,
pitch_pred_kernel_size,
energy_pred_kernel_size,
variance_predictor_dropout,
):
super().__init__()
self.enc_num_head = enc_num_head
self.dec_num_head = dec_num_head
self.padding_idx = padding_idx
self.sinusoidal_positional_embed_encoder = PositionalEncoding(
enc_d_model
)
self.sinusoidal_positional_embed_decoder = PositionalEncoding(
dec_d_model
)
self.encPreNet = EncoderPreNet(
n_char, padding_idx, out_channels=enc_d_model
)
self.durPred = DurationPredictor(
in_channels=enc_d_model,
out_channels=enc_d_model,
kernel_size=dur_pred_kernel_size,
dropout=variance_predictor_dropout,
)
self.pitchPred = DurationPredictor(
in_channels=enc_d_model,
out_channels=enc_d_model,
kernel_size=dur_pred_kernel_size,
dropout=variance_predictor_dropout,
)
self.energyPred = DurationPredictor(
in_channels=enc_d_model,
out_channels=enc_d_model,
kernel_size=dur_pred_kernel_size,
dropout=variance_predictor_dropout,
)
self.pitchEmbed = CNN.Conv1d(
in_channels=1,
out_channels=enc_d_model,
kernel_size=pitch_pred_kernel_size,
padding="same",
skip_transpose=True,
)
self.energyEmbed = CNN.Conv1d(
in_channels=1,
out_channels=enc_d_model,
kernel_size=energy_pred_kernel_size,
padding="same",
skip_transpose=True,
)
self.encoder = TransformerEncoder(
num_layers=enc_num_layers,
nhead=enc_num_head,
d_ffn=enc_ffn_dim,
d_model=enc_d_model,
kdim=enc_k_dim,
vdim=enc_v_dim,
dropout=enc_dropout,
activation=nn.ReLU,
normalize_before=normalize_before,
ffn_type=ffn_type,
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
)
self.decoder = TransformerEncoder(
num_layers=dec_num_layers,
nhead=dec_num_head,
d_ffn=dec_ffn_dim,
d_model=dec_d_model,
kdim=dec_k_dim,
vdim=dec_v_dim,
dropout=dec_dropout,
activation=nn.ReLU,
normalize_before=normalize_before,
ffn_type=ffn_type,
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
)
self.linear = linear.Linear(n_neurons=n_mels, input_size=dec_d_model)
self.postnet = PostNet(
n_mel_channels=n_mels,
postnet_embedding_dim=postnet_embedding_dim,
postnet_kernel_size=postnet_kernel_size,
postnet_n_convolutions=postnet_n_convolutions,
postnet_dropout=postnet_dropout,
)
def forward(
self,
tokens,
durations=None,
pitch=None,
energy=None,
pace=1.0,
pitch_rate=1.0,
energy_rate=1.0,
):
"""forward pass for training and inference
Arguments
---------
tokens: torch.Tensor
batch of input tokens
durations: torch.Tensor
batch of durations for each token. If it is None, the model will infer on predicted durations
pitch: torch.Tensor
batch of pitch for each frame. If it is None, the model will infer on predicted pitches
energy: torch.Tensor
batch of energy for each frame. If it is None, the model will infer on predicted energies
pace: float
scaling factor for durations
pitch_rate: float
scaling factor for pitches
energy_rate: float
scaling factor for energies
Returns
-------
mel_post: torch.Tensor
mel outputs from the decoder
postnet_output: torch.Tensor
mel outputs from the postnet
predict_durations: torch.Tensor
predicted durations of each token
predict_pitch: torch.Tensor
predicted pitches of each token
avg_pitch: torch.Tensor
target pitches for each token if input pitch is not None
None if input pitch is None
predict_energy: torch.Tensor
predicted energies of each token
avg_energy: torch.Tensor
target energies for each token if input energy is not None
None if input energy is None
mel_lens: torch.Tensor
predicted lengths of the mel spectrograms
"""
srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
srcmask_inverted = (~srcmask).unsqueeze(-1)
# prenet & encoder
token_feats = self.encPreNet(tokens)
pos = self.sinusoidal_positional_embed_encoder(token_feats)
token_feats = torch.add(token_feats, pos) * srcmask_inverted
attn_mask = (
srcmask.unsqueeze(-1)
.repeat(self.enc_num_head, 1, token_feats.shape[1])
.permute(0, 2, 1)
.bool()
)
token_feats, _ = self.encoder(
token_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
)
token_feats = token_feats * srcmask_inverted
# duration predictor
predict_durations = self.durPred(token_feats, srcmask_inverted).squeeze(
-1
)
if predict_durations.dim() == 1:
predict_durations = predict_durations.unsqueeze(0)
if durations is None:
dur_pred_reverse_log = torch.clamp(
torch.special.expm1(predict_durations), 0
)
# pitch predictor
avg_pitch = None
predict_pitch = self.pitchPred(token_feats, srcmask_inverted)
# use a pitch rate to adjust the pitch
predict_pitch = predict_pitch * pitch_rate
if pitch is not None:
avg_pitch = average_over_durations(pitch.unsqueeze(1), durations)
pitch = self.pitchEmbed(avg_pitch)
avg_pitch = avg_pitch.permute(0, 2, 1)
else:
pitch = self.pitchEmbed(predict_pitch.permute(0, 2, 1))
pitch = pitch.permute(0, 2, 1)
token_feats = token_feats.add(pitch)
# energy predictor
avg_energy = None
predict_energy = self.energyPred(token_feats, srcmask_inverted)
# use an energy rate to adjust the energy
predict_energy = predict_energy * energy_rate
if energy is not None:
avg_energy = average_over_durations(energy.unsqueeze(1), durations)
energy = self.energyEmbed(avg_energy)
avg_energy = avg_energy.permute(0, 2, 1)
else:
energy = self.energyEmbed(predict_energy.permute(0, 2, 1))
energy = energy.permute(0, 2, 1)
token_feats = token_feats.add(energy)
# upsamples the durations
spec_feats, mel_lens = upsample(
token_feats,
durations if durations is not None else dur_pred_reverse_log,
pace=pace,
)
srcmask = get_mask_from_lengths(torch.tensor(mel_lens))
srcmask = srcmask.to(spec_feats.device)
srcmask_inverted = (~srcmask).unsqueeze(-1)
attn_mask = (
srcmask.unsqueeze(-1)
.repeat(self.dec_num_head, 1, spec_feats.shape[1])
.permute(0, 2, 1)
.bool()
)
# decoder
pos = self.sinusoidal_positional_embed_decoder(spec_feats)
spec_feats = torch.add(spec_feats, pos) * srcmask_inverted
output_mel_feats, memory, *_ = self.decoder(
spec_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
)
# postnet
mel_post = self.linear(output_mel_feats) * srcmask_inverted
postnet_output = self.postnet(mel_post) + mel_post
return (
mel_post,
postnet_output,
predict_durations,
predict_pitch,
avg_pitch,
predict_energy,
avg_energy,
torch.tensor(mel_lens),
)
def average_over_durations(values, durs):
"""Average values over durations.
Arguments
---------
values: torch.Tensor
shape: [B, 1, T_de]
durs: torch.Tensor
shape: [B, T_en]
Returns
-------
avg: torch.Tensor
shape: [B, 1, T_en]
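A minimal sketch (values and durations are illustrative):
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import average_over_durations
>>> values = torch.rand(2, 1, 15)  # e.g. frame-level pitch, illustrative
>>> durs = torch.tensor([[2, 4, 1, 5, 3], [1, 2, 4, 3, 5]])  # each row sums to 15
>>> avg = average_over_durations(values, durs)
>>> avg.shape
torch.Size([2, 1, 5])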
"""
durs_cums_ends = torch.cumsum(durs, dim=1).long()
durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
values_nonzero_cums = torch.nn.functional.pad(
torch.cumsum(values != 0.0, dim=2), (1, 0)
)
values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))
bs, length = durs_cums_ends.size()
n_formants = values.size(1)
dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, length)
dce = durs_cums_ends[:, None, :].expand(bs, n_formants, length)
values_sums = (
torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)
).float()
values_nelems = (
torch.gather(values_nonzero_cums, 2, dce)
- torch.gather(values_nonzero_cums, 2, dcs)
).float()
avg = torch.where(
values_nelems == 0.0, values_nelems, values_sums / values_nelems
)
return avg
def upsample(feats, durs, pace=1.0, padding_value=0.0):
"""upsample encoder output according to durations
Arguments
---------
feats: torch.Tensor
batch of encoder output features to be upsampled
durs: torch.Tensor
durations used to upsample each feature
pace: float
scaling factor for durations
padding_value: float
value used to pad the upsampled sequences
Returns
-------
padded_upsampled_mels: torch.Tensor
the upsampled features, zero-padded to the longest sequence in the batch
mel_lens: list
the length of each upsampled sequence before padding
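A minimal sketch (feature dimension and durations are illustrative):
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import upsample
>>> feats = torch.rand(2, 5, 384)  # (batch, tokens, features), illustrative
>>> durs = torch.tensor([[2, 4, 1, 5, 3], [1, 2, 4, 3, 0]])
>>> upsampled, mel_lens = upsample(feats, durs)
>>> upsampled.shape, mel_lens
(torch.Size([2, 15, 384]), [15, 10])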
"""
upsampled_mels = [
torch.repeat_interleave(feats[i], (pace * durs[i]).long(), dim=0)
for i in range(len(durs))
]
mel_lens = [mel.shape[0] for mel in upsampled_mels]
padded_upsampled_mels = torch.nn.utils.rnn.pad_sequence(
upsampled_mels, batch_first=True, padding_value=padding_value
)
return padded_upsampled_mels, mel_lens
class TextMelCollate:
"""Zero-pads model inputs and targets based on number of frames per step"""
# TODO: Make this more intuitive, use the pipeline
def __call__(self, batch):
"""Collate's training batch from normalized text and mel-spectrogram
Arguments
---------
batch: list
[text_normalized, mel_normalized]
Returns
-------
text_padded: torch.Tensor
dur_padded: torch.Tensor
input_lengths: torch.Tensor
mel_padded: torch.Tensor
pitch_padded: torch.Tensor
energy_padded: torch.Tensor
output_lengths: torch.Tensor
len_x: torch.Tensor
labels: torch.Tensor
wavs: torch.Tensor
no_spn_seq_padded: torch.Tensor
spn_labels_padded: torch.Tensor
last_phonemes_padded: torch.Tensor
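A minimal sketch is shown below, assuming each pipeline item carries a "mel_text_pair" tuple of
(text, durations, mel, pitch, energy, len_x, last_phonemes, no-spn sequence, spn labels)
together with "label" and "wav" entries; all values are illustrative.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import TextMelCollate
>>> collate_fn = TextMelCollate()
>>> item = {
...     "mel_text_pair": (
...         torch.LongTensor([13, 12, 31]),  # text tokens, illustrative
...         torch.LongTensor([2, 4, 1]),     # per-token durations
...         torch.rand(80, 7),               # mel (n_mels, frames)
...         torch.rand(7),                   # per-frame pitch
...         torch.rand(7),                   # per-frame energy
...         3,                               # len_x
...         [1, 0, 1],                       # last-phoneme flags
...         torch.LongTensor([13, 31]),      # sequence without silent phonemes
...         [0, 1],                          # silent-phoneme labels
...     ),
...     "label": "example text",
...     "wav": "example.wav",
... }
>>> outputs = collate_fn([item])
>>> outputs[0].shape, outputs[3].shape
(torch.Size([1, 3]), torch.Size([1, 7, 80]))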
"""
# TODO: Remove for loops
raw_batch = list(batch)
for i in range(
len(batch)
): # the pipeline returns a dictionary with one element
batch[i] = batch[i]["mel_text_pair"]
# Right zero-pad all one-hot text sequences to max input length
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
)
max_input_len = input_lengths[0]
# Get max_no_spn_seq_len
no_spn_seq_lengths, no_spn_ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x[-2]) for x in batch]),
dim=0,
descending=True,
)
max_no_spn_seq_len = no_spn_seq_lengths[0]
text_padded = torch.LongTensor(len(batch), max_input_len)
no_spn_seq_padded = torch.LongTensor(len(batch), max_no_spn_seq_len)
last_phonemes_padded = torch.LongTensor(len(batch), max_no_spn_seq_len)
dur_padded = torch.LongTensor(len(batch), max_input_len)
spn_labels_padded = torch.FloatTensor(len(batch), max_no_spn_seq_len)
text_padded.zero_()
no_spn_seq_padded.zero_()
last_phonemes_padded.zero_()
dur_padded.zero_()
spn_labels_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]][0]
no_spn_seq = batch[ids_sorted_decreasing[i]][-2]
last_phonemes = torch.LongTensor(
batch[ids_sorted_decreasing[i]][-3]
)
dur = batch[ids_sorted_decreasing[i]][1]
spn_labels = torch.LongTensor(batch[ids_sorted_decreasing[i]][-1])
text_padded[i, : text.size(0)] = text
no_spn_seq_padded[i, : no_spn_seq.size(0)] = no_spn_seq
last_phonemes_padded[i, : last_phonemes.size(0)] = last_phonemes
dur_padded[i, : dur.size(0)] = dur
spn_labels_padded[i, : spn_labels.size(0)] = spn_labels
# Right zero-pad mel-spec
num_mels = batch[0][2].size(0)
max_target_len = max([x[2].size(1) for x in batch])
# include mel padded and gate padded
mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
mel_padded.zero_()
pitch_padded = torch.FloatTensor(len(batch), max_target_len)
pitch_padded.zero_()
energy_padded = torch.FloatTensor(len(batch), max_target_len)
energy_padded.zero_()
output_lengths = torch.LongTensor(len(batch))
labels, wavs = [], []
for i in range(len(ids_sorted_decreasing)):
idx = ids_sorted_decreasing[i]
mel = batch[idx][2]
pitch = batch[idx][3]
energy = batch[idx][4]
mel_padded[i, :, : mel.size(1)] = mel
pitch_padded[i, : pitch.size(0)] = pitch
energy_padded[i, : energy.size(0)] = energy
output_lengths[i] = mel.size(1)
labels.append(raw_batch[idx]["label"])
wavs.append(raw_batch[idx]["wav"])
# count number of items - characters in text
len_x = [x[5] for x in batch]
len_x = torch.Tensor(len_x)
mel_padded = mel_padded.permute(0, 2, 1)
return (
text_padded,
dur_padded,
input_lengths,
mel_padded,
pitch_padded,
energy_padded,
output_lengths,
len_x,
labels,
wavs,
no_spn_seq_padded,
spn_labels_padded,
last_phonemes_padded,
)
[docs]
class Loss(nn.Module):
"""Loss Computation
Arguments
---------
log_scale_durations: bool
applies logarithm to target durations
ssim_loss_weight: float
weight for ssim loss
duration_loss_weight: float
weight for the duration loss
pitch_loss_weight: float
weight for the pitch loss
energy_loss_weight: float
weight for the energy loss
mel_loss_weight: float
weight for the mel loss
postnet_mel_loss_weight: float
weight for the postnet mel loss
spn_loss_weight: float
weight for spn loss
spn_loss_max_epochs: int
max number of epochs for which the spn loss is applied; after this, its weight is set to 0
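A minimal sketch with random tensors of mutually consistent shapes; all values are illustrative.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import Loss
>>> criterion = Loss(
...     log_scale_durations=True,
...     ssim_loss_weight=1.0,
...     duration_loss_weight=1.0,
...     pitch_loss_weight=1.0,
...     energy_loss_weight=1.0,
...     mel_loss_weight=1.0,
...     postnet_mel_loss_weight=1.0,
... )
>>> mel_target = torch.rand(1, 15, 80)  # 5 tokens upsampled to 15 frames
>>> predictions = (
...     mel_target + 0.1 * torch.rand(1, 15, 80),  # mel_out
...     mel_target + 0.1 * torch.rand(1, 15, 80),  # postnet_mel_out
...     torch.rand(1, 5),                          # log_durations
...     torch.rand(1, 5, 1),                       # predicted_pitch
...     torch.rand(1, 5, 1),                       # average_pitch
...     torch.rand(1, 5, 1),                       # predicted_energy
...     torch.rand(1, 5, 1),                       # average_energy
...     torch.tensor([15]),                        # mel_lens
...     torch.rand(1, 5),                          # spn_preds
... )
>>> targets = (
...     mel_target,
...     torch.tensor([[2, 4, 1, 5, 3]]),            # target_durations
...     torch.rand(1, 15),                          # target_pitch (frame level)
...     torch.rand(1, 15),                          # target_energy (frame level)
...     torch.tensor([15]),                         # mel_length
...     torch.tensor([5]),                          # phon_len
...     torch.tensor([[0.0, 1.0, 0.0, 0.0, 1.0]]),  # spn_labels
... )
>>> loss = criterion(predictions, targets, current_epoch=1)
>>> loss["total_loss"].dim()
0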
"""
def __init__(
self,
log_scale_durations,
ssim_loss_weight,
duration_loss_weight,
pitch_loss_weight,
energy_loss_weight,
mel_loss_weight,
postnet_mel_loss_weight,
spn_loss_weight=1.0,
spn_loss_max_epochs=8,
):
super().__init__()
self.ssim_loss = SSIMLoss()
self.mel_loss = nn.MSELoss()
self.postnet_mel_loss = nn.MSELoss()
self.dur_loss = nn.MSELoss()
self.pitch_loss = nn.MSELoss()
self.energy_loss = nn.MSELoss()
self.log_scale_durations = log_scale_durations
self.ssim_loss_weight = ssim_loss_weight
self.mel_loss_weight = mel_loss_weight
self.postnet_mel_loss_weight = postnet_mel_loss_weight
self.duration_loss_weight = duration_loss_weight
self.pitch_loss_weight = pitch_loss_weight
self.energy_loss_weight = energy_loss_weight
self.spn_loss_weight = spn_loss_weight
self.spn_loss_max_epochs = spn_loss_max_epochs
def forward(self, predictions, targets, current_epoch):
"""Computes the value of the loss function and updates stats
Arguments
---------
predictions: tuple
model predictions
targets: tuple
ground truth data
current_epoch: int
The count of the current epoch.
Returns
-------
loss: torch.Tensor
the loss value
"""
(
mel_target,
target_durations,
target_pitch,
target_energy,
mel_length,
phon_len,
spn_labels,
) = targets
assert len(mel_target.shape) == 3
(
mel_out,
postnet_mel_out,
log_durations,
predicted_pitch,
average_pitch,
predicted_energy,
average_energy,
mel_lens,
spn_preds,
) = predictions
predicted_pitch = predicted_pitch.squeeze(-1)
predicted_energy = predicted_energy.squeeze(-1)
target_pitch = average_pitch.squeeze(-1)
target_energy = average_energy.squeeze(-1)
log_durations = log_durations.squeeze(-1)
if self.log_scale_durations:
log_target_durations = torch.log1p(target_durations.float())
# change this to perform batch level using padding mask
for i in range(mel_target.shape[0]):
if i == 0:
mel_loss = self.mel_loss(
mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
postnet_mel_loss = self.postnet_mel_loss(
postnet_mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
dur_loss = self.dur_loss(
log_durations[i, : phon_len[i]],
log_target_durations[i, : phon_len[i]].to(torch.float32),
)
pitch_loss = self.pitch_loss(
predicted_pitch[i, : mel_length[i]],
target_pitch[i, : mel_length[i]].to(torch.float32),
)
energy_loss = self.energy_loss(
predicted_energy[i, : mel_length[i]],
target_energy[i, : mel_length[i]].to(torch.float32),
)
else:
mel_loss = mel_loss + self.mel_loss(
mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
postnet_mel_loss = postnet_mel_loss + self.postnet_mel_loss(
postnet_mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
dur_loss = dur_loss + self.dur_loss(
log_durations[i, : phon_len[i]],
log_target_durations[i, : phon_len[i]].to(torch.float32),
)
pitch_loss = pitch_loss + self.pitch_loss(
predicted_pitch[i, : mel_length[i]],
target_pitch[i, : mel_length[i]].to(torch.float32),
)
energy_loss = energy_loss + self.energy_loss(
predicted_energy[i, : mel_length[i]],
target_energy[i, : mel_length[i]].to(torch.float32),
)
ssim_loss = self.ssim_loss(mel_out, mel_target, mel_length)
mel_loss = torch.div(mel_loss, len(mel_target))
postnet_mel_loss = torch.div(postnet_mel_loss, len(mel_target))
dur_loss = torch.div(dur_loss, len(mel_target))
pitch_loss = torch.div(pitch_loss, len(mel_target))
energy_loss = torch.div(energy_loss, len(mel_target))
spn_loss = bce_loss(spn_preds, spn_labels)
if current_epoch > self.spn_loss_max_epochs:
self.spn_loss_weight = 0
total_loss = (
ssim_loss * self.ssim_loss_weight
+ mel_loss * self.mel_loss_weight
+ postnet_mel_loss * self.postnet_mel_loss_weight
+ dur_loss * self.duration_loss_weight
+ pitch_loss * self.pitch_loss_weight
+ energy_loss * self.energy_loss_weight
+ spn_loss * self.spn_loss_weight
)
loss = {
"total_loss": total_loss,
"ssim_loss": ssim_loss * self.ssim_loss_weight,
"mel_loss": mel_loss * self.mel_loss_weight,
"postnet_mel_loss": postnet_mel_loss * self.postnet_mel_loss_weight,
"dur_loss": dur_loss * self.duration_loss_weight,
"pitch_loss": pitch_loss * self.pitch_loss_weight,
"energy_loss": energy_loss * self.energy_loss_weight,
"spn_loss": spn_loss * self.spn_loss_weight,
}
return loss
def mel_spectogram(
sample_rate,
hop_length,
win_length,
n_fft,
n_mels,
f_min,
f_max,
power,
normalized,
min_max_energy_norm,
norm,
mel_scale,
compression,
audio,
):
"""calculates MelSpectrogram for a raw audio signal
Arguments
---------
sample_rate : int
Sample rate of audio signal.
hop_length : int
Length of hop between STFT windows.
win_length : int
Window size.
n_fft : int
Size of FFT.
n_mels : int
Number of mel filterbanks.
f_min : float
Minimum frequency.
f_max : float
Maximum frequency.
power : float
Exponent for the magnitude spectrogram.
normalized : bool
Whether to normalize by magnitude after stft.
min_max_energy_norm : bool
Whether to normalize by min-max
norm : str or None
If "slaney", divide the triangular mel weights by the width of the mel band
mel_scale : str
Scale to use: "htk" or "slaney".
compression : bool
whether to do dynamic range compression
audio : torch.Tensor
input audio signal
Returns
-------
mel : torch.Tensor
rmse : torch.Tensor
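A minimal sketch; the parameter values below are illustrative and a 1-D mono waveform is assumed.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import mel_spectogram
>>> audio = torch.rand(16000)  # illustrative 1-second mono waveform
>>> mel, rmse = mel_spectogram(
...     sample_rate=22050, hop_length=256, win_length=1024, n_fft=1024,
...     n_mels=80, f_min=0.0, f_max=8000.0, power=1, normalized=False,
...     min_max_energy_norm=True, norm="slaney", mel_scale="slaney",
...     compression=True, audio=audio,
... )
>>> mel.shape[0], rmse.shape[0] == mel.shape[1]
(80, True)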
"""
from torchaudio import transforms
audio_to_mel = transforms.Spectrogram(
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
power=power,
normalized=normalized,
).to(audio.device)
mel_scale = transforms.MelScale(
sample_rate=sample_rate,
n_stft=n_fft // 2 + 1,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
norm=norm,
mel_scale=mel_scale,
).to(audio.device)
spec = audio_to_mel(audio)
mel = mel_scale(spec)
assert mel.dim() == 2
assert mel.shape[0] == n_mels
rmse = torch.norm(mel, dim=0)
if min_max_energy_norm:
rmse = (rmse - torch.min(rmse)) / (torch.max(rmse) - torch.min(rmse))
if compression:
mel = dynamic_range_compression(mel)
return mel, rmse
def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""Dynamic range compression for audio signals"""
return torch.log(torch.clamp(x, min=clip_val) * C)
class SSIMLoss(torch.nn.Module):
"""SSIM loss as (1 - SSIM)
SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
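A minimal sketch; tensor shapes and the noise level are illustrative.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import SSIMLoss
>>> loss = SSIMLoss()
>>> y = torch.rand(2, 100, 80)  # (batch, time, features), illustrative
>>> y_hat = y + 0.1 * torch.rand(2, 100, 80)
>>> lengths = torch.tensor([100, 100])
>>> value = loss(y_hat, y, lengths)
>>> 0.0 <= value.item() <= 1.0
True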
"""
def __init__(self):
super().__init__()
self.loss_func = _SSIMLoss()
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(self, sequence_length, max_len=None):
"""Create a sequence mask for filtering padding in a sequence tensor.
Arguments
---------
sequence_length: torch.Tensor
Sequence lengths.
max_len: int
Maximum sequence length. Defaults to None.
Returns
-------
mask: [B, T_max]
"""
if max_len is None:
max_len = sequence_length.data.max()
seq_range = torch.arange(
max_len, dtype=sequence_length.dtype, device=sequence_length.device
)
# B x T_max
mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
return mask
def sample_wise_min_max(self, x: torch.Tensor, mask: torch.Tensor):
"""Min-Max normalize tensor through first dimension
Arguments
---------
x: torch.Tensor
input tensor [B, D1, D2]
mask: torch.Tensor
input mask [B, D1, 1]
Returns
-------
Normalized tensor
"""
maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
minimum = torch.amin(
x.masked_fill(~mask, 1e30), dim=(1, 2), keepdim=True
)
return (x - minimum) / (maximum - minimum + 1e-8)
def forward(self, y_hat, y, length):
"""
Arguments
---------
y_hat: torch.Tensor
model prediction values [B, T, D].
y: torch.Tensor
target values [B, T, D].
length: torch.Tensor
length of each sample in a batch for masking.
Returns
-------
loss: Average loss value in range [0, 1] masked by the length.
"""
mask = self.sequence_mask(
sequence_length=length, max_len=y.size(1)
).unsqueeze(2)
y_norm = self.sample_wise_min_max(y, mask)
y_hat_norm = self.sample_wise_min_max(y_hat, mask)
ssim_loss = self.loss_func(
(y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)
)
if ssim_loss.item() > 1.0:
print(
f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0"
)
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
if ssim_loss.item() < 0.0:
print(
f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0"
)
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
return ssim_loss
# Adapted from https://github.com/photosynthesis-team/piq
class _SSIMLoss(_Loss):
"""Creates a criterion that measures the structural similarity index error between
each element in the input x and target y.
Equation link: https://en.wikipedia.org/wiki/Structural_similarity
x and y are tensors of arbitrary shapes with a total of n elements each.
The sum operation still operates over all the elements, and divides by n.
The division by n can be avoided if one sets reduction = sum.
In case of 5D input tensors, complex value is returned as a tensor of size 2.
Arguments
---------
kernel_size: int
By default, the mean and covariance of a pixel is obtained
by convolution with given filter_size.
kernel_sigma: float
Standard deviation for Gaussian kernel.
k1: float
Coefficient related to c1 (see equation in the link above).
k2: float
Coefficient related to c2 (see equation in the link above).
downsample: bool
Perform average pool before SSIM computation (Default: True).
reduction: str
Specifies the reduction type
data_range: Union[int, float]
Maximum value range of images (usually 1.0 or 255).
Example
-------
>>> loss = _SSIMLoss()
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()
"""
__constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
def __init__(
self,
kernel_size=11,
kernel_sigma=1.5,
k1=0.01,
k2=0.03,
downsample=True,
reduction="mean",
data_range=1.0,
):
super().__init__()
# Generic loss parameters.
self.reduction = reduction
# Loss-specific parameters.
self.kernel_size = kernel_size
# This check might look redundant because kernel size is checked within the ssim function anyway.
# However, this check allows us to fail fast when the loss is being initialised, before training has started.
assert (
kernel_size % 2 == 1
), f"Kernel size must be odd, got [{kernel_size}]"
self.kernel_sigma = kernel_sigma
self.k1 = k1
self.k2 = k2
self.downsample = downsample
self.data_range = data_range
def _reduce(self, x, reduction="mean"):
"""Reduce input in batch dimension if needed.
Arguments
---------
x: torch.Tensor
Tensor with shape (B, *).
reduction: str
Specifies the reduction type:
none | mean | sum (Default: mean)
Returns
-------
Reduced outputs.
"""
if reduction == "none":
return x
if reduction == "mean":
return x.mean(dim=0)
if reduction == "sum":
return x.sum(dim=0)
raise ValueError(
"Unknown reduction. Expected one of {'none', 'mean', 'sum'}"
)
def _validate_input(
self,
tensors,
dim_range=(0, -1),
data_range=(0.0, -1.0),
size_range=None,
):
"""Check if the input satisfies the requirements
Arguments
---------
tensors: torch.Tensor
torch.Tensors to check
dim_range: Tuple[int, int]
Allowed number of dimensions. (min, max)
data_range: Tuple[float, float]
Allowed range of values in tensors. (min, max)
size_range: Tuple[int, int]
Dimensions to include in size comparison. (start_dim, end_dim + 1)
Returns
-------
None
"""
if not __debug__:
return
x = tensors[0]
for t in tensors:
assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
assert (
t.device == x.device
), f"Expected tensors to be on {x.device}, got {t.device}"
if size_range is None:
assert (
t.size() == x.size()
), f"Expected tensors with same size, got {t.size()} and {x.size()}"
else:
assert (
t.size()[size_range[0] : size_range[1]]
== x.size()[size_range[0] : size_range[1]]
), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
if dim_range[0] == dim_range[1]:
assert (
t.dim() == dim_range[0]
), f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
elif dim_range[0] < dim_range[1]:
assert (
dim_range[0] <= t.dim() <= dim_range[1]
), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
if data_range[0] < data_range[1]:
assert (
data_range[0] <= t.min()
), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
assert (
t.max() <= data_range[1]
), f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
def gaussian_filter(self, kernel_size, sigma):
"""Returns 2D Gaussian kernel N(0,sigma^2)
Arguments
---------
kernel_size: int
Size of the kernel
sigma: float
Std of the distribution
Returns
-------
gaussian_kernel: torch.Tensor
[1, kernel_size, kernel_size]
"""
coords = torch.arange(kernel_size, dtype=torch.float32)
coords -= (kernel_size - 1) / 2.0
g = coords**2
g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
g /= g.sum()
return g.unsqueeze(0)
def _ssim_per_channel(self, x, y, kernel, k1=0.01, k2=0.03):
"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
Arguments
---------
x: torch.Tensor
An input tensor (N, C, H, W).
y: torch.Tensor
A target tensor (N, C, H, W).
kernel: torch.Tensor
2D Gaussian kernel.
k1: float
Algorithm parameter (see equation in the link above).
k2: float
Algorithm parameter (see equation in the link above).
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
Returns
-------
Full Value of Structural Similarity (SSIM) index.
"""
if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
raise ValueError(
f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
f"Kernel size: {kernel.size()}"
)
c1 = k1**2
c2 = k2**2
n_channels = x.size(1)
mu_x = F.conv2d(
x, weight=kernel, stride=1, padding=0, groups=n_channels
)
mu_y = F.conv2d(
y, weight=kernel, stride=1, padding=0, groups=n_channels
)
mu_xx = mu_x**2
mu_yy = mu_y**2
mu_xy = mu_x * mu_y
sigma_xx = (
F.conv2d(
x**2, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu_xx
)
sigma_yy = (
F.conv2d(
y**2, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu_yy
)
sigma_xy = (
F.conv2d(
x * y, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu_xy
)
# Contrast sensitivity (CS) with alpha = beta = gamma = 1.
cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
# Structural similarity (SSIM)
ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
ssim_val = ss.mean(dim=(-1, -2))
cs = cs.mean(dim=(-1, -2))
return ssim_val, cs
def _ssim_per_channel_complex(self, x, y, kernel, k1=0.01, k2=0.03):
"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
Arguments
---------
x: torch.Tensor
An input tensor (N, C, H, W, 2).
y: torch.Tensor
A target tensor (N, C, H, W, 2).
kernel: torch.Tensor
2-D gauss kernel.
k1: float
Algorithm parameter (see equation in the link above).
k2: float
Algorithm parameter (see equation in the link above).
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
Returns
-------
Full Value of Complex Structural Similarity (SSIM) index.
"""
n_channels = x.size(1)
if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
raise ValueError(
f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
f"Kernel size: {kernel.size()}"
)
c1 = k1**2
c2 = k2**2
x_real = x[..., 0]
x_imag = x[..., 1]
y_real = y[..., 0]
y_imag = y[..., 1]
mu1_real = F.conv2d(
x_real, weight=kernel, stride=1, padding=0, groups=n_channels
)
mu1_imag = F.conv2d(
x_imag, weight=kernel, stride=1, padding=0, groups=n_channels
)
mu2_real = F.conv2d(
y_real, weight=kernel, stride=1, padding=0, groups=n_channels
)
mu2_imag = F.conv2d(
y_imag, weight=kernel, stride=1, padding=0, groups=n_channels
)
mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2)
mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2)
mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag
mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real
compensation = 1.0
x_sq = x_real.pow(2) + x_imag.pow(2)
y_sq = y_real.pow(2) + y_imag.pow(2)
x_y_real = x_real * y_real - x_imag * y_imag
x_y_imag = x_real * y_imag + x_imag * y_real
sigma1_sq = (
F.conv2d(
x_sq, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu1_sq
)
sigma2_sq = (
F.conv2d(
y_sq, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu2_sq
)
sigma12_real = (
F.conv2d(
x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu1_mu2_real
)
sigma12_imag = (
F.conv2d(
x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels
)
- mu1_mu2_imag
)
sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1)
mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1)
# Set alpha = beta = gamma = 1.
cs_map = (sigma12 * 2 + c2 * compensation) / (
sigma1_sq.unsqueeze(-1)
+ sigma2_sq.unsqueeze(-1)
+ c2 * compensation
)
ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (
mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation
)
ssim_map = ssim_map * cs_map
ssim_val = ssim_map.mean(dim=(-2, -3))
cs = cs_map.mean(dim=(-2, -3))
return ssim_val, cs
def ssim(
self,
x,
y,
kernel_size=11,
kernel_sigma=1.5,
data_range=1.0,
reduction="mean",
full=False,
downsample=True,
k1=0.01,
k2=0.03,
):
"""Interface of Structural Similarity (SSIM) index.
Inputs are supposed to be in range [0, data_range].
To match performance with skimage and tensorflow set downsample = True.
Arguments
---------
x: torch.Tensor
An input tensor (N, C, H, W) or (N, C, H, W, 2).
y: torch.Tensor
A target tensor (N, C, H, W) or (N, C, H, W, 2).
kernel_size: int
The side-length of the sliding window used in comparison. Must be an odd value.
kernel_sigma: float
Sigma of normal distribution.
data_range: Union[int, float]
Maximum value range of images (usually 1.0 or 255).
reduction: str
Specifies the reduction type:
none | mean | sum. Default: mean
full: bool
Return cs map or not.
downsample: bool
Perform average pool before SSIM computation. Default: True
k1: float
Algorithm parameter (see equation in the link above).
k2: float
Algorithm parameter (see equation in the link above).
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
Returns
-------
Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
as a tensor of size 2.
"""
assert (
kernel_size % 2 == 1
), f"Kernel size must be odd, got [{kernel_size}]"
self._validate_input(
[x, y], dim_range=(4, 5), data_range=(0, data_range)
)
x = x / float(data_range)
y = y / float(data_range)
# Averagepool image if the size is large enough
f = max(1, round(min(x.size()[-2:]) / 256))
if (f > 1) and downsample:
x = F.avg_pool2d(x, kernel_size=f)
y = F.avg_pool2d(y, kernel_size=f)
kernel = (
self.gaussian_filter(kernel_size, kernel_sigma)
.repeat(x.size(1), 1, 1, 1)
.to(y)
)
_compute_ssim_per_channel = (
self._ssim_per_channel_complex
if x.dim() == 5
else self._ssim_per_channel
)
ssim_map, cs_map = _compute_ssim_per_channel(
x=x, y=y, kernel=kernel, k1=k1, k2=k2
)
ssim_val = ssim_map.mean(1)
cs = cs_map.mean(1)
ssim_val = self._reduce(ssim_val, reduction)
cs = self._reduce(cs, reduction)
if full:
return [ssim_val, cs]
return ssim_val
def forward(self, x, y):
"""Computation of Structural Similarity (SSIM) index as a loss function.
Arguments
---------
x: torch.Tensor
An input tensor (N, C, H, W) or (N, C, H, W, 2).
y: torch.Tensor
A target tensor (N, C, H, W) or (N, C, H, W, 2).
Returns
-------
Value of SSIM loss to be minimized, i.e 1 - ssim in [0, 1] range. In case of 5D input tensors,
complex value is returned as a tensor of size 2.
"""
score = self.ssim(
x=x,
y=y,
kernel_size=self.kernel_size,
kernel_sigma=self.kernel_sigma,
downsample=self.downsample,
data_range=self.data_range,
reduction=self.reduction,
full=False,
k1=self.k1,
k2=self.k2,
)
return torch.ones_like(score) - score
class TextMelCollateWithAlignment:
"""Zero-pads model inputs and targets based on number of frames per step
result: tuple
a tuple of tensors to be used as inputs/targets
(
phoneme_padded,
input_lengths,
mel_padded,
pitch_padded,
energy_padded,
output_lengths,
labels,
wavs
)
"""
# TODO: Make this more intuitive, use the pipeline
def __call__(self, batch):
"""Collate's training batch from normalized text and mel-spectrogram
Arguments
---------
batch: list
[text_normalized, mel_normalized]
Returns
-------
phoneme_padded: torch.Tensor
input_lengths: torch.Tensor
mel_padded: torch.Tensor
pitch_padded: torch.Tensor
energy_padded: torch.Tensor
output_lengths: torch.Tensor
labels: torch.Tensor
wavs: torch.Tensor
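A minimal sketch is shown below, assuming each pipeline item carries a "mel_text_pair" tuple of
(phonemes, mel, pitch, energy) together with "label" and "wav" entries; all values are illustrative.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import TextMelCollateWithAlignment
>>> collate_fn = TextMelCollateWithAlignment()
>>> item = {
...     "mel_text_pair": (
...         torch.LongTensor([13, 12, 31]),  # phoneme tokens, illustrative
...         torch.rand(80, 7),               # mel (n_mels, frames)
...         torch.rand(7),                   # per-frame pitch
...         torch.rand(7),                   # per-frame energy
...     ),
...     "label": "example text",
...     "wav": "example.wav",
... }
>>> outputs = collate_fn([item])
>>> outputs[0].shape, outputs[2].shape
(torch.Size([1, 3]), torch.Size([1, 7, 80]))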
"""
# TODO: Remove for loops
raw_batch = list(batch)
for i in range(
len(batch)
): # the pipeline returns a dictionary with one element
batch[i] = batch[i]["mel_text_pair"]
# Right zero-pad all one-hot text sequences to max input length
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
)
max_input_len = input_lengths[0]
phoneme_padded = torch.LongTensor(len(batch), max_input_len)
phoneme_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
phoneme = batch[ids_sorted_decreasing[i]][0]
phoneme_padded[i, : phoneme.size(0)] = phoneme
# Right zero-pad mel-spec
num_mels = batch[0][1].size(0)
max_target_len = max([x[1].size(1) for x in batch])
# include mel padded and gate padded
mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
mel_padded.zero_()
pitch_padded = torch.FloatTensor(len(batch), max_target_len)
pitch_padded.zero_()
energy_padded = torch.FloatTensor(len(batch), max_target_len)
energy_padded.zero_()
output_lengths = torch.LongTensor(len(batch))
labels, wavs = [], []
for i in range(len(ids_sorted_decreasing)):
idx = ids_sorted_decreasing[i]
mel = batch[idx][1]
pitch = batch[idx][2]
energy = batch[idx][3]
mel_padded[i, :, : mel.size(1)] = mel
pitch_padded[i, : pitch.size(0)] = pitch
energy_padded[i, : energy.size(0)] = energy
output_lengths[i] = mel.size(1)
labels.append(raw_batch[idx]["label"])
wavs.append(raw_batch[idx]["wav"])
mel_padded = mel_padded.permute(0, 2, 1)
return (
phoneme_padded,
input_lengths,
mel_padded,
pitch_padded,
energy_padded,
output_lengths,
labels,
wavs,
)
def maximum_path_numpy(value, mask):
"""
Monotonic alignment search algorithm; the numpy implementation works faster than the torch one.
Arguments
---------
value: torch.Tensor
input alignment values [b, t_x, t_y]
mask: torch.Tensor
input alignment mask [b, t_x, t_y]
Returns
-------
path: torch.Tensor
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import maximum_path_numpy
>>> alignment = torch.rand(2, 5, 100)
>>> mask = torch.ones(2, 5, 100)
>>> hard_alignments = maximum_path_numpy(alignment, mask)
"""
max_neg_val = -np.inf # Patch for Sphinx complaint
value = value * mask
device = value.device
dtype = value.dtype
value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(np.bool_)
b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64)
v = np.zeros((b, t_x), dtype=np.float32)
x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
for j in range(t_y):
v0 = np.pad(
v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val
)[:, :-1]
v1 = v
max_mask = v1 >= v0
v_max = np.where(max_mask, v1, v0)
direction[:, :, j] = max_mask
index_mask = x_range <= j
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
direction = np.where(mask, direction, 1)
path = np.zeros(value.shape, dtype=np.float32)
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
index_range = np.arange(b)
for j in reversed(range(t_y)):
path[index_range, index, j] = 1
index = index + direction[index_range, index, j] - 1
path = path * mask.astype(np.float32)
path = torch.from_numpy(path).to(device=device, dtype=dtype)
return path
class AlignmentNetwork(torch.nn.Module):
"""Learns the alignment between the input text
and the spectrogram with Gaussian Attention.
query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
key -> conv1d -> relu -> conv1d - - - - - - - - - - - -^
Arguments
---------
in_query_channels: int
Number of channels in the query network. Defaults to 80.
in_key_channels: int
Number of channels in the key network. Defaults to 512.
attn_channels: int
Number of inner channels in the attention layers. Defaults to 80.
temperature: float
Temperature for the softmax. Defaults to 0.0005.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import AlignmentNetwork
>>> aligner = AlignmentNetwork(
... in_query_channels=80,
... in_key_channels=512,
... attn_channels=80,
... temperature=0.0005,
... )
>>> phoneme_feats = torch.rand(2, 512, 20)
>>> mels = torch.rand(2, 80, 100)
>>> alignment_soft, alignment_logprob = aligner(mels, phoneme_feats, None, None)
>>> alignment_soft.shape, alignment_logprob.shape
(torch.Size([2, 1, 100, 20]), torch.Size([2, 1, 100, 20]))
"""
def __init__(
self,
in_query_channels=80,
in_key_channels=512,
attn_channels=80,
temperature=0.0005,
):
super().__init__()
self.temperature = temperature
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_layer = nn.Sequential(
CNN.Conv1d(
in_channels=in_key_channels,
out_channels=in_key_channels * 2,
kernel_size=3,
padding="same",
bias=True,
skip_transpose=True,
),
torch.nn.ReLU(),
CNN.Conv1d(
in_channels=in_key_channels * 2,
out_channels=attn_channels,
kernel_size=1,
padding="same",
bias=True,
skip_transpose=True,
),
)
self.query_layer = nn.Sequential(
CNN.Conv1d(
in_channels=in_query_channels,
out_channels=in_query_channels * 2,
kernel_size=3,
padding="same",
bias=True,
skip_transpose=True,
),
torch.nn.ReLU(),
CNN.Conv1d(
in_channels=in_query_channels * 2,
out_channels=in_query_channels,
kernel_size=1,
padding="same",
bias=True,
skip_transpose=True,
),
torch.nn.ReLU(),
CNN.Conv1d(
in_channels=in_query_channels,
out_channels=attn_channels,
kernel_size=1,
padding="same",
bias=True,
skip_transpose=True,
),
)
def forward(self, queries, keys, mask, attn_prior):
"""Forward pass of the aligner encoder.
Arguments
---------
queries: torch.Tensor
the query tensor [B, C, T_de]
keys: torch.Tensor
the key tensor [B, C_emb, T_en]
mask: torch.Tensor
the query mask [B, T_de]
attn_prior: torch.Tensor
the prior attention tensor [B, 1, T_en, T_de]
Returns
-------
attn: torch.Tensor
soft attention [B, 1, T_en, T_de]
attn_logp: torch.Tensor
log probabilities [B, 1, T_en , T_de]
"""
key_out = self.key_layer(keys)
query_out = self.query_layer(queries)
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
if attn_prior is not None:
attn_logp = self.log_softmax(attn_logp) + torch.log(
attn_prior[:, None] + 1e-8
)
if mask is not None:
attn_logp.data.masked_fill_(
~mask.bool().unsqueeze(2), -float("inf")
)
attn = self.softmax(attn_logp)
return attn, attn_logp
class FastSpeech2WithAlignment(nn.Module):
"""The FastSpeech2 text-to-speech model with internal alignment.
This class is the main entry point for the model, which is responsible
for instantiating all submodules, which, in turn, manage the individual
neural network layers. Certain parts are adapted from the following implementation:
https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/models/forward_tts.py
Simplified STRUCTURE:
input -> token embedding -> encoder -> aligner -> duration/pitch/energy -> upsampler -> decoder -> output
Arguments
---------
enc_num_layers: int
number of transformer layers (TransformerEncoderLayer) in encoder
enc_num_head: int
number of multi-head-attention (MHA) heads in encoder transformer layers
enc_d_model: int
the number of expected features in the encoder
enc_ffn_dim: int
the dimension of the feedforward network model
enc_k_dim: int
the dimension of the key
enc_v_dim: int
the dimension of the value
enc_dropout: float
Dropout for the encoder
in_query_channels: int
Number of channels in the query network.
in_key_channels: int
Number of channels in the key network.
attn_channels: int
Number of inner channels in the attention layers.
temperature: float
Temperature for the softmax.
dec_num_layers: int
number of transformer layers (TransformerEncoderLayer) in decoder
dec_num_head: int
number of multi-head-attention (MHA) heads in decoder transformer layers
dec_d_model: int
the number of expected features in the decoder
dec_ffn_dim: int
the dimension of the feedforward network model
dec_k_dim: int
the dimension of the key
dec_v_dim: int
the dimension of the value
dec_dropout: float
dropout for the decoder
normalize_before: bool
whether normalization should be applied before or after MHA or FFN in Transformer layers.
ffn_type: str
whether to use convolutional layers instead of feed forward network inside transformer layer.
ffn_cnn_kernel_size_list: list of int
conv kernel size of 2 1d-convs if ffn_type is 1dcnn
n_char: int
the number of symbols for the token embedding
n_mels: int
number of bins in mel spectrogram
postnet_embedding_dim: int
output feature dimension for convolution layers
postnet_kernel_size: int
postnet convolution kernel size
postnet_n_convolutions: int
number of convolution layers
postnet_dropout: float
dropout probability for postnet
padding_idx: int
the index for padding
dur_pred_kernel_size: int
the convolution kernel size in duration predictor
pitch_pred_kernel_size: int
kernel size for pitch prediction.
energy_pred_kernel_size: int
kernel size for energy prediction.
variance_predictor_dropout: float
dropout probability for variance predictor (duration/pitch/energy)
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import FastSpeech2WithAlignment
>>> model = FastSpeech2WithAlignment(
... enc_num_layers=6,
... enc_num_head=2,
... enc_d_model=384,
... enc_ffn_dim=1536,
... enc_k_dim=384,
... enc_v_dim=384,
... enc_dropout=0.1,
... in_query_channels=80,
... in_key_channels=384,
... attn_channels=80,
... temperature=0.0005,
... dec_num_layers=6,
... dec_num_head=2,
... dec_d_model=384,
... dec_ffn_dim=1536,
... dec_k_dim=384,
... dec_v_dim=384,
... dec_dropout=0.1,
... normalize_before=False,
... ffn_type='1dcnn',
... ffn_cnn_kernel_size_list=[9, 1],
... n_char=40,
... n_mels=80,
... postnet_embedding_dim=512,
... postnet_kernel_size=5,
... postnet_n_convolutions=5,
... postnet_dropout=0.5,
... padding_idx=0,
... dur_pred_kernel_size=3,
... pitch_pred_kernel_size=3,
... energy_pred_kernel_size=3,
... variance_predictor_dropout=0.5)
>>> inputs = torch.tensor([
... [13, 12, 31, 14, 19],
... [31, 16, 30, 31, 0],
... ])
>>> mels = torch.rand(2, 100, 80)
>>> mel_post, postnet_output, durations, predict_pitch, avg_pitch, predict_energy, avg_energy, mel_lens, alignment_durations, alignment_soft, alignment_logprob, alignment_mas = model(inputs, mels)
>>> mel_post.shape, durations.shape
(torch.Size([2, 100, 80]), torch.Size([2, 5]))
>>> predict_pitch.shape, predict_energy.shape
(torch.Size([2, 5, 1]), torch.Size([2, 5, 1]))
>>> alignment_soft.shape, alignment_mas.shape
(torch.Size([2, 100, 5]), torch.Size([2, 100, 5]))
"""
def __init__(
self,
# encoder parameters
enc_num_layers,
enc_num_head,
enc_d_model,
enc_ffn_dim,
enc_k_dim,
enc_v_dim,
enc_dropout,
# aligner parameters
in_query_channels,
in_key_channels,
attn_channels,
temperature,
# decoder parameters
dec_num_layers,
dec_num_head,
dec_d_model,
dec_ffn_dim,
dec_k_dim,
dec_v_dim,
dec_dropout,
normalize_before,
ffn_type,
ffn_cnn_kernel_size_list,
n_char,
n_mels,
postnet_embedding_dim,
postnet_kernel_size,
postnet_n_convolutions,
postnet_dropout,
padding_idx,
dur_pred_kernel_size,
pitch_pred_kernel_size,
energy_pred_kernel_size,
variance_predictor_dropout,
):
super().__init__()
self.enc_num_head = enc_num_head
self.dec_num_head = dec_num_head
self.padding_idx = padding_idx
self.sinusoidal_positional_embed_encoder = PositionalEncoding(
enc_d_model
)
self.sinusoidal_positional_embed_decoder = PositionalEncoding(
dec_d_model
)
self.encPreNet = EncoderPreNet(
n_char, padding_idx, out_channels=enc_d_model
)
self.durPred = DurationPredictor(
in_channels=enc_d_model,
out_channels=enc_d_model,
kernel_size=dur_pred_kernel_size,
dropout=variance_predictor_dropout,
)
self.pitchPred = DurationPredictor(
in_channels=enc_d_model,
out_channels=enc_d_model,
kernel_size=dur_pred_kernel_size,
dropout=variance_predictor_dropout,
)
self.energyPred = DurationPredictor(
in_channels=enc_d_model,
out_channels=enc_d_model,
kernel_size=dur_pred_kernel_size,
dropout=variance_predictor_dropout,
)
self.pitchEmbed = CNN.Conv1d(
in_channels=1,
out_channels=enc_d_model,
kernel_size=pitch_pred_kernel_size,
padding="same",
skip_transpose=True,
)
self.energyEmbed = CNN.Conv1d(
in_channels=1,
out_channels=enc_d_model,
kernel_size=energy_pred_kernel_size,
padding="same",
skip_transpose=True,
)
self.encoder = TransformerEncoder(
num_layers=enc_num_layers,
nhead=enc_num_head,
d_ffn=enc_ffn_dim,
d_model=enc_d_model,
kdim=enc_k_dim,
vdim=enc_v_dim,
dropout=enc_dropout,
activation=nn.ReLU,
normalize_before=normalize_before,
ffn_type=ffn_type,
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
)
self.decoder = TransformerEncoder(
num_layers=dec_num_layers,
nhead=dec_num_head,
d_ffn=dec_ffn_dim,
d_model=dec_d_model,
kdim=dec_k_dim,
vdim=dec_v_dim,
dropout=dec_dropout,
activation=nn.ReLU,
normalize_before=normalize_before,
ffn_type=ffn_type,
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
)
self.linear = linear.Linear(n_neurons=n_mels, input_size=dec_d_model)
self.postnet = PostNet(
n_mel_channels=n_mels,
postnet_embedding_dim=postnet_embedding_dim,
postnet_kernel_size=postnet_kernel_size,
postnet_n_convolutions=postnet_n_convolutions,
postnet_dropout=postnet_dropout,
)
self.aligner = AlignmentNetwork(
in_query_channels=in_query_channels,
in_key_channels=in_key_channels,
attn_channels=attn_channels,
temperature=temperature,
)
def _forward_aligner(self, x, y, x_mask, y_mask):
"""Aligner forward pass.
1. Compute a mask to apply to the attention map.
2. Run the alignment network.
3. Apply MAS (Monotonic alignment search) to compute the hard alignment map.
4. Compute the durations from the hard alignment map.
Arguments
---------
x: torch.Tensor
Input sequence [B, T_en, C_en].
y: torch.Tensor
Output sequence [B, T_de, C_de].
x_mask: torch.Tensor
Input sequence mask [B, 1, T_en].
y_mask: torch.Tensor
Output sequence mask [B, 1, T_de].
Returns
-------
durations: torch.Tensor
Durations from the hard alignment map [B, T_en].
alignment_soft: torch.Tensor
soft alignment potentials [B, T_en, T_de].
alignment_logprob: torch.Tensor
log scale alignment potentials [B, 1, T_de, T_en].
alignment_mas: torch.Tensor
hard alignment map [B, T_en, T_de].
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
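# 2. run the alignment network (queries: mel frames, keys: token features)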
alignment_soft, alignment_logprob = self.aligner(
y.transpose(1, 2), x.transpose(1, 2), x_mask, None
)
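# 3. monotonic alignment search (MAS) turns the soft attention into a hard alignment map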
alignment_mas = maximum_path_numpy(
alignment_soft.squeeze(1).transpose(1, 2).contiguous(),
attn_mask.squeeze(1).contiguous(),
)
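# 4. durations = number of mel frames assigned to each token by the hard alignment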
durations = torch.sum(alignment_mas, -1).int()
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
return durations, alignment_soft, alignment_logprob, alignment_mas
def forward(
self,
tokens,
mel_spectograms=None,
pitch=None,
energy=None,
pace=1.0,
pitch_rate=1.0,
energy_rate=1.0,
):
"""forward pass for training and inference
Arguments
---------
tokens: torch.Tensor
batch of input tokens
mel_spectograms: torch.Tensor
batch of mel spectrograms (used only for training)
pitch: torch.Tensor
batch of pitch values for each mel frame. If None, the model falls back to its own pitch predictions
energy: torch.Tensor
batch of energy values for each mel frame. If None, the model falls back to its own energy predictions
pace: float
scaling factor for durations
pitch_rate: float
scaling factor for pitches
energy_rate: float
scaling factor for energies
Returns
-------
mel_post: torch.Tensor
mel outputs from the decoder
postnet_output: torch.Tensor
mel outputs from the postnet
predict_durations: torch.Tensor
predicted durations of each token
predict_pitch: torch.Tensor
predicted pitches of each token
avg_pitch: torch.Tensor
target pitches for each token if input pitch is not None
None if input pitch is None
predict_energy: torch.Tensor
predicted energies of each token
avg_energy: torch.Tensor
target energies for each token if input energy is not None
None if input energy is None
mel_lens: torch.Tensor
predicted lengths of the mel spectrograms
alignment_durations: torch.Tensor
durations from the hard alignment map
alignment_soft: torch.Tensor
soft alignment potentials
alignment_logprob: torch.Tensor
log scale alignment potentials
alignment_mas: torch.Tensor
hard alignment map
"""
srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
srcmask_inverted = (~srcmask).unsqueeze(-1)
# encoder
token_feats = self.encPreNet(tokens)
pos = self.sinusoidal_positional_embed_encoder(token_feats)
token_feats = torch.add(token_feats, pos) * srcmask_inverted
attn_mask = (
srcmask.unsqueeze(-1)
.repeat(self.enc_num_head, 1, token_feats.shape[1])
.permute(0, 2, 1)
.bool()
)
token_feats, _ = self.encoder(
token_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
)
token_feats = token_feats * srcmask_inverted
# aligner
alignment_durations = None
alignment_soft = None
alignment_logprob = None
alignment_mas = None
if mel_spectograms is not None:
y_mask = get_key_padding_mask(
mel_spectograms, pad_idx=self.padding_idx
)
y_mask_inverted = (~y_mask).unsqueeze(-1)
(
alignment_durations,
alignment_soft,
alignment_logprob,
alignment_mas,
) = self._forward_aligner(
token_feats,
mel_spectograms,
srcmask_inverted.transpose(1, 2),
y_mask_inverted.transpose(1, 2),
)
alignment_soft = alignment_soft.transpose(1, 2)
alignment_mas = alignment_mas.transpose(1, 2)
# duration predictor
predict_durations = self.durPred(
token_feats, srcmask_inverted
).squeeze()
if predict_durations.dim() == 1:
predict_durations = predict_durations.unsqueeze(0)
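# map the predicted log-durations back to the linear scale
# (expm1 inverts the log1p applied to the targets) and clamp at zero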
predict_durations_reverse_log = torch.clamp(
torch.special.expm1(predict_durations), 0
)
# pitch predictor
avg_pitch = None
predict_pitch = self.pitchPred(token_feats, srcmask_inverted)
# use a pitch rate to adjust the pitch
predict_pitch = predict_pitch * pitch_rate
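# training: embed the ground-truth pitch averaged over token durations;
# inference: embed the predicted pitch instead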
if pitch is not None:
avg_pitch = average_over_durations(
pitch.unsqueeze(1), alignment_durations
)
pitch = self.pitchEmbed(avg_pitch)
avg_pitch = avg_pitch.permute(0, 2, 1)
else:
pitch = self.pitchEmbed(predict_pitch.permute(0, 2, 1))
pitch = pitch.permute(0, 2, 1)
token_feats = token_feats.add(pitch)
# energy predictor
avg_energy = None
predict_energy = self.energyPred(token_feats, srcmask_inverted)
# use an energy rate to adjust the energy
predict_energy = predict_energy * energy_rate
if energy is not None:
avg_energy = average_over_durations(
energy.unsqueeze(1), alignment_durations
)
energy = self.energyEmbed(avg_energy)
avg_energy = avg_energy.permute(0, 2, 1)
else:
energy = self.energyEmbed(predict_energy.permute(0, 2, 1))
energy = energy.permute(0, 2, 1)
token_feats = token_feats.add(energy)
# upsampling
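# expand each token's features to span its duration: aligner durations
# at training time, predicted durations at inference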
spec_feats, mel_lens = upsample(
token_feats,
(
alignment_durations
if alignment_durations is not None
else predict_durations_reverse_log
),
pace=pace,
)
srcmask = get_mask_from_lengths(torch.tensor(mel_lens))
srcmask = srcmask.to(spec_feats.device)
srcmask_inverted = (~srcmask).unsqueeze(-1)
attn_mask = (
srcmask.unsqueeze(-1)
.repeat(self.dec_num_head, 1, spec_feats.shape[1])
.permute(0, 2, 1)
.bool()
)
# decoder
pos = self.sinusoidal_positional_embed_decoder(spec_feats)
spec_feats = torch.add(spec_feats, pos) * srcmask_inverted
output_mel_feats, memory, *_ = self.decoder(
spec_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
)
# postnet
mel_post = self.linear(output_mel_feats) * srcmask_inverted
postnet_output = self.postnet(mel_post) + mel_post
return (
mel_post,
postnet_output,
predict_durations,
predict_pitch,
avg_pitch,
predict_energy,
avg_energy,
torch.tensor(mel_lens),
alignment_durations,
alignment_soft,
alignment_logprob,
alignment_mas,
)
class LossWithAlignment(nn.Module):
"""Loss computation including internal aligner
Arguments
---------
log_scale_durations: bool
applies logarithm to target durations
ssim_loss_weight: float
weight for the ssim loss
duration_loss_weight: float
weight for the duration loss
pitch_loss_weight: float
weight for the pitch loss
energy_loss_weight: float
weight for the energy loss
mel_loss_weight: float
weight for the mel loss
postnet_mel_loss_weight: float
weight for the postnet mel loss
aligner_loss_weight: float
weight for the alignment loss
binary_alignment_loss_weight: float
weight for the binary alignment loss
binary_alignment_loss_warmup_epochs: int
number of epochs over which the impact of the binary loss is gradually increased
binary_alignment_loss_max_epochs: int
epoch after which the binary loss is no longer applied (its weight is set to 0)
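Example
-------
A minimal sketch with random tensors; shapes are assumed to mirror the
FastSpeech2WithAlignment doctest above (batch 2, 100 mel frames, 5 tokens).
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import LossWithAlignment
>>> loss_fn = LossWithAlignment(
...     log_scale_durations=True,
...     ssim_loss_weight=1.0,
...     duration_loss_weight=1.0,
...     pitch_loss_weight=1.0,
...     energy_loss_weight=1.0,
...     mel_loss_weight=1.0,
...     postnet_mel_loss_weight=1.0,
...     aligner_loss_weight=1.0,
...     binary_alignment_loss_weight=0.1,
...     binary_alignment_loss_warmup_epochs=1,
...     binary_alignment_loss_max_epochs=300,
... )
>>> predictions = (
...     torch.rand(2, 100, 80),            # mel_out
...     torch.rand(2, 100, 80),            # postnet_mel_out
...     torch.rand(2, 5),                  # log_durations
...     torch.rand(2, 5, 1),               # predicted_pitch
...     torch.rand(2, 5, 1),               # average_pitch
...     torch.rand(2, 5, 1),               # predicted_energy
...     torch.rand(2, 5, 1),               # average_energy
...     torch.tensor([100, 100]),          # mel_lens
...     torch.randint(1, 40, (2, 5)),      # alignment_durations
...     torch.rand(2, 100, 5),             # alignment_soft
...     torch.rand(2, 1, 100, 5),          # alignment_logprob
...     torch.randint(0, 2, (2, 100, 5)),  # alignment_hard
... )
>>> targets = (
...     torch.rand(2, 100, 80),            # mel_target
...     torch.rand(2, 5),                  # target_pitch (recomputed from average_pitch)
...     torch.rand(2, 5),                  # target_energy (recomputed from average_energy)
...     torch.tensor([100, 100]),          # mel_length
...     torch.tensor([5, 5]),              # phon_len
... )
>>> loss = loss_fn(predictions, targets, current_epoch=1)
>>> 'total_loss' in loss
True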
"""
def __init__(
self,
log_scale_durations,
ssim_loss_weight,
duration_loss_weight,
pitch_loss_weight,
energy_loss_weight,
mel_loss_weight,
postnet_mel_loss_weight,
aligner_loss_weight,
binary_alignment_loss_weight,
binary_alignment_loss_warmup_epochs,
binary_alignment_loss_max_epochs,
):
super().__init__()
self.ssim_loss = SSIMLoss()
self.mel_loss = nn.MSELoss()
self.postnet_mel_loss = nn.MSELoss()
self.dur_loss = nn.MSELoss()
self.pitch_loss = nn.MSELoss()
self.energy_loss = nn.MSELoss()
self.aligner_loss = ForwardSumLoss()
self.binary_alignment_loss = BinaryAlignmentLoss()
self.log_scale_durations = log_scale_durations
self.ssim_loss_weight = ssim_loss_weight
self.mel_loss_weight = mel_loss_weight
self.postnet_mel_loss_weight = postnet_mel_loss_weight
self.duration_loss_weight = duration_loss_weight
self.pitch_loss_weight = pitch_loss_weight
self.energy_loss_weight = energy_loss_weight
self.aligner_loss_weight = aligner_loss_weight
self.binary_alignment_loss_weight = binary_alignment_loss_weight
self.binary_alignment_loss_warmup_epochs = (
binary_alignment_loss_warmup_epochs
)
self.binary_alignment_loss_max_epochs = binary_alignment_loss_max_epochs
def forward(self, predictions, targets, current_epoch):
"""Computes the value of the loss function and updates stats
Arguments
---------
predictions: tuple
model predictions
targets: tuple
ground truth data
current_epoch: int
the current epoch; used to determine the warmup and cutoff of the binary alignment loss
Returns
-------
loss: torch.Tensor
the loss value
"""
(
mel_target,
target_pitch,
target_energy,
mel_length,
phon_len,
) = targets
assert len(mel_target.shape) == 3
(
mel_out,
postnet_mel_out,
log_durations,
predicted_pitch,
average_pitch,
predicted_energy,
average_energy,
mel_lens,
alignment_durations,
alignment_soft,
alignment_logprob,
alignment_hard,
) = predictions
predicted_pitch = predicted_pitch.squeeze(-1)
predicted_energy = predicted_energy.squeeze(-1)
target_pitch = average_pitch.squeeze(-1)
target_energy = average_energy.squeeze(-1)
log_durations = log_durations.squeeze(-1)
if self.log_scale_durations:
log_target_durations = torch.log1p(alignment_durations.float())
else:
# assumed fallback: use raw durations when log scaling is disabled
# (avoids a NameError in the loss computation below)
log_target_durations = alignment_durations.float()
# TODO: change this to perform batch-level computation using the padding mask
for i in range(mel_target.shape[0]):
if i == 0:
mel_loss = self.mel_loss(
mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
postnet_mel_loss = self.postnet_mel_loss(
postnet_mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
dur_loss = self.dur_loss(
log_durations[i, : phon_len[i]],
log_target_durations[i, : phon_len[i]].to(torch.float32),
)
pitch_loss = self.pitch_loss(
predicted_pitch[i, : mel_length[i]],
target_pitch[i, : mel_length[i]].to(torch.float32),
)
energy_loss = self.energy_loss(
predicted_energy[i, : mel_length[i]],
target_energy[i, : mel_length[i]].to(torch.float32),
)
else:
mel_loss = mel_loss + self.mel_loss(
mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
postnet_mel_loss = postnet_mel_loss + self.postnet_mel_loss(
postnet_mel_out[i, : mel_length[i], :],
mel_target[i, : mel_length[i], :],
)
dur_loss = dur_loss + self.dur_loss(
log_durations[i, : phon_len[i]],
log_target_durations[i, : phon_len[i]].to(torch.float32),
)
pitch_loss = pitch_loss + self.pitch_loss(
predicted_pitch[i, : mel_length[i]],
target_pitch[i, : mel_length[i]].to(torch.float32),
)
energy_loss = energy_loss + self.energy_loss(
predicted_energy[i, : mel_length[i]],
target_energy[i, : mel_length[i]].to(torch.float32),
)
loss = {}
ssim_loss = self.ssim_loss(mel_out, mel_target, mel_length)
loss["ssim_loss"] = ssim_loss * self.ssim_loss_weight
mel_loss = torch.div(mel_loss, len(mel_target))
loss["mel_loss"] = mel_loss * self.mel_loss_weight
postnet_mel_loss = torch.div(postnet_mel_loss, len(mel_target))
loss["postnet_mel_loss"] = (
postnet_mel_loss * self.postnet_mel_loss_weight
)
dur_loss = torch.div(dur_loss, len(mel_target))
loss["dur_loss"] = dur_loss * self.duration_loss_weight
pitch_loss = torch.div(pitch_loss, len(mel_target))
loss["pitch_loss"] = pitch_loss * self.pitch_loss_weight
energy_loss = torch.div(energy_loss, len(mel_target))
loss["energy_loss"] = energy_loss * self.energy_loss_weight
if alignment_logprob is not None:
aligner_loss = self.aligner_loss(
alignment_logprob, phon_len, mel_length
)
loss["aligner_loss"] = aligner_loss * self.aligner_loss_weight
if alignment_soft is not None and alignment_hard is not None:
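# linearly warm up the binary loss contribution, then disable it after max_epochs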
if current_epoch > self.binary_alignment_loss_max_epochs:
binary_loss_warmup_weight = 0
else:
binary_loss_warmup_weight = min(
current_epoch
/ self.binary_alignment_loss_warmup_epochs,
1.0,
)
binary_alignment_loss = self.binary_alignment_loss(
alignment_hard, alignment_soft
)
loss["binary_alignment_loss"] = (
binary_alignment_loss
* self.binary_alignment_loss_weight
* binary_loss_warmup_weight
)
total_loss = sum(loss.values())
loss["total_loss"] = total_loss
return loss
class ForwardSumLoss(nn.Module):
"""CTC alignment loss
Arguments
---------
blank_logprob: pad value
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import ForwardSumLoss
>>> loss_func = ForwardSumLoss()
>>> attn_logprob = torch.rand(2, 1, 100, 5)
>>> key_lens = torch.tensor([5, 5])
>>> query_lens = torch.tensor([100, 100])
>>> loss = loss_func(attn_logprob, key_lens, query_lens)
"""
def __init__(self, blank_logprob=-1):
super().__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
self.blank_logprob = blank_logprob
def forward(self, attn_logprob, key_lens, query_lens):
"""
Arguments
---------
attn_logprob: torch.Tensor
log scale alignment potentials [B, 1, query_lens, key_lens]
key_lens: torch.Tensor
phoneme (token) lengths
query_lens: torch.Tensor
mel (frame) lengths
Returns
-------
total_loss: torch.Tensor
the CTC loss averaged over the batch
"""
attn_logprob_padded = torch.nn.functional.pad(
input=attn_logprob, pad=(1, 0), value=self.blank_logprob
)
total_loss = 0.0
for bid in range(attn_logprob.shape[0]):
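# the CTC target is the monotonic token sequence 1..K (0 is the blank class)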
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
: query_lens[bid], :, : key_lens[bid] + 1
]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
loss = self.ctc_loss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
total_loss = total_loss + loss
total_loss = total_loss / attn_logprob.shape[0]
return total_loss
class BinaryAlignmentLoss(nn.Module):
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
Example
-------
>>> import torch
>>> from speechbrain.lobes.models.FastSpeech2 import BinaryAlignmentLoss
>>> loss_func = BinaryAlignmentLoss()
>>> alignment_hard = torch.randint(0, 2, (2, 100, 5))
>>> alignment_soft = torch.rand(2, 100, 5)
>>> loss = loss_func(alignment_hard, alignment_soft)
"""
def __init__(self):
super().__init__()
def forward(self, alignment_hard, alignment_soft):
"""
alignment_hard: torch.Tensor
hard alignment map [B, mel_lens, phoneme_lens]
alignment_soft: torch.Tensor
soft alignment potentials [B, mel_lens, phoneme_lens]
"""
log_sum = torch.log(
torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)
).sum()
return -log_sum / alignment_hard.sum()