Source code for speechbrain.lobes.models.FastSpeech2

"""
Neural network modules for the FastSpeech 2: Fast and High-Quality End-to-End Text to Speech
synthesis model
Authors
* Sathvik Udupa 2022
* Pradnya Kandarkar 2023
* Yingzhi Wang 2023
"""

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from speechbrain.nnet import CNN, linear
from speechbrain.nnet.embedding import Embedding
from speechbrain.lobes.models.transformer.Transformer import (
    TransformerEncoder,
    get_key_padding_mask,
)
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.nnet.losses import bce_loss


class PositionalEmbedding(nn.Module):
    """Computation of the positional embeddings.

    Arguments
    ---------
    embed_dim: int
        dimensionality of the embeddings.
    """

    def __init__(self, embed_dim):
        super(PositionalEmbedding, self).__init__()
        self.demb = embed_dim
        inv_freq = 1 / (
            10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim)
        )
        self.register_buffer("inv_freq", inv_freq)
    def forward(self, seq_len, mask, dtype):
        """Computes the forward pass

        Arguments
        ---------
        seq_len: int
            length of the sequence
        mask: torch.Tensor
            mask applied to the positional embeddings
        dtype: str
            dtype of the embeddings

        Returns
        -------
        pos_emb: torch.Tensor
            the tensor with positional embeddings
        """
        pos_seq = torch.arange(seq_len, device=mask.device).to(dtype)

        sinusoid_inp = torch.matmul(
            torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)
        )
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        return pos_emb[None, :, :] * mask[:, :, None]
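# A minimal usage sketch for PositionalEmbedding (not part of the original
# module; shapes and values are illustrative assumptions):
# >>> pos_layer = PositionalEmbedding(embed_dim=384)
# >>> mask = torch.ones(2, 10)  # (batch, seq_len) mask without padding
# >>> pos = pos_layer(seq_len=10, mask=mask, dtype=torch.float32)
# >>> pos.shape
# torch.Size([2, 10, 384])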
class EncoderPreNet(nn.Module):
    """Embedding layer for tokens

    Arguments
    ---------
    n_vocab: int
        size of the dictionary of embeddings
    blank_id: int
        padding index
    out_channels: int
        the size of each embedding vector

    Example
    -------
    >>> from speechbrain.nnet.embedding import Embedding
    >>> from speechbrain.lobes.models.FastSpeech2 import EncoderPreNet
    >>> encoder_prenet_layer = EncoderPreNet(n_vocab=40, blank_id=0, out_channels=384)
    >>> x = torch.rand(3, 5)
    >>> y = encoder_prenet_layer(x)
    >>> y.shape
    torch.Size([3, 5, 384])
    """

    def __init__(self, n_vocab, blank_id, out_channels=512):
        super().__init__()
        self.token_embedding = Embedding(
            num_embeddings=n_vocab,
            embedding_dim=out_channels,
            blank_id=blank_id,
        )
    def forward(self, x):
        """Computes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            a (batch, tokens) input tensor

        Returns
        -------
        output: torch.Tensor
            the embedding layer output
        """
        self.token_embedding = self.token_embedding.to(x.device)
        x = self.token_embedding(x)
        return x
class PostNet(nn.Module):
    """FastSpeech2 Conv Postnet

    Arguments
    ---------
    n_mel_channels: int
        input feature dimension for convolution layers
    postnet_embedding_dim: int
        output feature dimension for convolution layers
    postnet_kernel_size: int
        postnet convolution kernel size
    postnet_n_convolutions: int
        number of convolution layers
    postnet_dropout: float
        dropout probability for postnet
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
        postnet_dropout=0.5,
    ):
        super(PostNet, self).__init__()

        self.conv_pre = CNN.Conv1d(
            in_channels=n_mel_channels,
            out_channels=postnet_embedding_dim,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.convs_intermedite = nn.ModuleList()
        for i in range(1, postnet_n_convolutions - 1):
            self.convs_intermedite.append(
                CNN.Conv1d(
                    in_channels=postnet_embedding_dim,
                    out_channels=postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    padding="same",
                ),
            )

        self.conv_post = CNN.Conv1d(
            in_channels=postnet_embedding_dim,
            out_channels=n_mel_channels,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.tanh = nn.Tanh()
        self.ln1 = nn.LayerNorm(postnet_embedding_dim)
        self.ln2 = nn.LayerNorm(postnet_embedding_dim)
        self.ln3 = nn.LayerNorm(n_mel_channels)
        self.dropout1 = nn.Dropout(postnet_dropout)
        self.dropout2 = nn.Dropout(postnet_dropout)
        self.dropout3 = nn.Dropout(postnet_dropout)
    def forward(self, x):
        """Computes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            a (batch, time_steps, features) input tensor

        Returns
        -------
        output: torch.Tensor
            the predicted spectrogram
        """
        x = self.conv_pre(x)
        x = self.ln1(x).to(x.dtype)
        x = self.tanh(x)
        x = self.dropout1(x)

        for i in range(len(self.convs_intermedite)):
            x = self.convs_intermedite[i](x)
        x = self.ln2(x).to(x.dtype)
        x = self.tanh(x)
        x = self.dropout2(x)

        x = self.conv_post(x)
        x = self.ln3(x).to(x.dtype)
        x = self.dropout3(x)

        return x
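# A usage sketch for PostNet (illustrative, not from the original module):
# the postnet refines a coarse mel prediction, so its output is normally
# added back to its input as a residual, as done in FastSpeech2.forward below.
# >>> postnet = PostNet()
# >>> mel = torch.rand(2, 120, 80)   # (batch, time_steps, n_mel_channels)
# >>> refined = postnet(mel) + mel   # residual connection
# >>> refined.shape
# torch.Size([2, 120, 80])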
class DurationPredictor(nn.Module):
    """Duration predictor layer

    Arguments
    ---------
    in_channels: int
        input feature dimension for convolution layers
    out_channels: int
        output feature dimension for convolution layers
    kernel_size: int
        duration predictor convolution kernel size
    dropout: float
        dropout probability, 0 by default
    n_units: int
        number of output units of the final linear layer, 1 by default

    Example
    -------
    >>> from speechbrain.lobes.models.FastSpeech2 import DurationPredictor
    >>> duration_predictor_layer = DurationPredictor(in_channels=384, out_channels=384, kernel_size=3)
    >>> x = torch.randn(3, 400, 384)
    >>> mask = torch.ones(3, 400, 384)
    >>> y = duration_predictor_layer(x, mask)
    >>> y.shape
    torch.Size([3, 400, 1])
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, dropout=0.0, n_units=1
    ):
        super().__init__()
        self.conv1 = CNN.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding="same",
        )
        self.conv2 = CNN.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding="same",
        )
        self.linear = linear.Linear(n_neurons=n_units, input_size=out_channels)
        self.ln1 = LayerNorm(out_channels)
        self.ln2 = LayerNorm(out_channels)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    def forward(self, x, x_mask):
        """Computes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            a (batch, time_steps, features) input tensor
        x_mask: torch.Tensor
            mask of the input tensor

        Returns
        -------
        output: torch.Tensor
            the duration predictor outputs
        """
        x = self.relu(self.conv1(x * x_mask))
        x = self.ln1(x).to(x.dtype)
        x = self.dropout1(x)

        x = self.relu(self.conv2(x * x_mask))
        x = self.ln2(x).to(x.dtype)
        x = self.dropout2(x)

        return self.linear(x * x_mask)
class SPNPredictor(nn.Module):
    """This module is the silent phoneme predictor. It receives phoneme
    sequences without any silent phoneme token as input and predicts whether a
    silent phoneme should be inserted after a position. This is to avoid the
    issue of overly fast speech at inference time, which arises when the input
    sequence contains no silent phoneme tokens.

    Arguments
    ---------
    enc_num_layers: int
        number of transformer layers (TransformerEncoderLayer) in encoder
    enc_num_head: int
        number of multi-head-attention (MHA) heads in encoder transformer layers
    enc_d_model: int
        the number of expected features in the encoder
    enc_ffn_dim: int
        the dimension of the feedforward network model
    enc_k_dim: int
        the dimension of the key
    enc_v_dim: int
        the dimension of the value
    enc_dropout: float
        Dropout for the encoder
    normalize_before: bool
        whether normalization should be applied before or after MHA or FFN in Transformer layers.
    ffn_type: str
        whether to use convolutional layers instead of feed forward network inside transformer layer
    ffn_cnn_kernel_size_list: list of int
        conv kernel size of 2 1d-convs if ffn_type is 1dcnn
    n_char: int
        the number of symbols for the token embedding
    padding_idx: int
        the index for padding
    """

    def __init__(
        self,
        enc_num_layers,
        enc_num_head,
        enc_d_model,
        enc_ffn_dim,
        enc_k_dim,
        enc_v_dim,
        enc_dropout,
        normalize_before,
        ffn_type,
        ffn_cnn_kernel_size_list,
        n_char,
        padding_idx,
    ):
        super().__init__()
        self.enc_num_head = enc_num_head
        self.padding_idx = padding_idx
        self.encPreNet = EncoderPreNet(
            n_char, padding_idx, out_channels=enc_d_model
        )

        self.sinusoidal_positional_embed_encoder = PositionalEmbedding(
            enc_d_model
        )

        self.spn_encoder = TransformerEncoder(
            num_layers=enc_num_layers,
            nhead=enc_num_head,
            d_ffn=enc_ffn_dim,
            d_model=enc_d_model,
            kdim=enc_k_dim,
            vdim=enc_v_dim,
            dropout=enc_dropout,
            activation=nn.ReLU,
            normalize_before=normalize_before,
            ffn_type=ffn_type,
            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
        )

        self.spn_linear = linear.Linear(n_neurons=1, input_size=enc_d_model)
    def forward(self, tokens, last_phonemes):
        """forward pass for the module

        Arguments
        ---------
        tokens: torch.Tensor
            input tokens without silent phonemes
        last_phonemes: torch.Tensor
            indicates if a phoneme at an index is the last phoneme of a word or not

        Returns
        -------
        spn_decision: torch.Tensor
            indicates if a silent phoneme should be inserted after a phoneme
        """
        token_feats = self.encPreNet(tokens)

        last_phonemes = torch.unsqueeze(last_phonemes, 2).repeat(
            1, 1, token_feats.shape[2]
        )
        token_feats = token_feats + last_phonemes

        srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
        srcmask_inverted = (~srcmask).unsqueeze(-1)

        pos = self.sinusoidal_positional_embed_encoder(
            token_feats.shape[1], srcmask, token_feats.dtype
        )
        token_feats = torch.add(token_feats, pos) * srcmask_inverted

        spn_mask = (
            torch.triu(
                torch.ones(
                    token_feats.shape[1],
                    token_feats.shape[1],
                    device=token_feats.device,
                ),
                diagonal=1,
            )
            .bool()
            .repeat(self.enc_num_head * token_feats.shape[0], 1, 1)
        )

        spn_token_feats, _ = self.spn_encoder(
            token_feats, src_mask=spn_mask, src_key_padding_mask=srcmask
        )

        spn_decision = self.spn_linear(spn_token_feats).squeeze()
        return spn_decision
    def infer(self, tokens, last_phonemes):
        """inference function

        Arguments
        ---------
        tokens: torch.Tensor
            input tokens without silent phonemes
        last_phonemes: torch.Tensor
            indicates if a phoneme at an index is the last phoneme of a word or not

        Returns
        -------
        spn_decision: torch.Tensor
            indicates if a silent phoneme should be inserted after a phoneme
        """
        spn_decision = self.forward(tokens, last_phonemes)
        spn_decision = torch.sigmoid(spn_decision) > 0.8
        return spn_decision
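# A usage sketch for SPNPredictor (the hyperparameter values below are
# illustrative assumptions, not mandated by the original module):
# >>> spn_predictor = SPNPredictor(
# ...     enc_num_layers=4, enc_num_head=2, enc_d_model=384, enc_ffn_dim=1024,
# ...     enc_k_dim=384, enc_v_dim=384, enc_dropout=0.1, normalize_before=False,
# ...     ffn_type='1dcnn', ffn_cnn_kernel_size_list=[9, 1], n_char=40,
# ...     padding_idx=0)
# >>> tokens = torch.tensor([[13, 12, 31, 14, 19]])
# >>> last_phonemes = torch.tensor([[0, 1, 0, 0, 1]])
# >>> spn_predictor.infer(tokens, last_phonemes).shape  # one boolean per token
# torch.Size([5])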
class FastSpeech2(nn.Module):
    """The FastSpeech2 text-to-speech model.

    This class is the main entry point for the model, which is responsible
    for instantiating all submodules, which, in turn, manage the individual
    neural network layers

    Simplified STRUCTURE:
    input -> token embedding -> encoder -> duration predictor
    -> duration upsampler -> decoder -> output

    During training, teacher forcing is used (ground truth durations are
    used for upsampling)

    Arguments
    ---------
    # encoder parameters
    enc_num_layers: int
        number of transformer layers (TransformerEncoderLayer) in encoder
    enc_num_head: int
        number of multi-head-attention (MHA) heads in encoder transformer layers
    enc_d_model: int
        the number of expected features in the encoder
    enc_ffn_dim: int
        the dimension of the feedforward network model
    enc_k_dim: int
        the dimension of the key
    enc_v_dim: int
        the dimension of the value
    enc_dropout: float
        Dropout for the encoder
    normalize_before: bool
        whether normalization should be applied before or after MHA or FFN in Transformer layers.
    ffn_type: str
        whether to use convolutional layers instead of feed forward network inside transformer layer
    ffn_cnn_kernel_size_list: list of int
        conv kernel size of 2 1d-convs if ffn_type is 1dcnn

    # decoder parameters
    dec_num_layers: int
        number of transformer layers (TransformerEncoderLayer) in decoder
    dec_num_head: int
        number of multi-head-attention (MHA) heads in decoder transformer layers
    dec_d_model: int
        the number of expected features in the decoder
    dec_ffn_dim: int
        the dimension of the feedforward network model
    dec_k_dim: int
        the dimension of the key
    dec_v_dim: int
        the dimension of the value
    dec_dropout: float
        dropout for the decoder
    normalize_before: bool
        whether normalization should be applied before or after MHA or FFN in Transformer layers.
    ffn_type: str
        whether to use convolutional layers instead of feed forward network inside transformer layer.
    ffn_cnn_kernel_size_list: list of int
        conv kernel size of 2 1d-convs if ffn_type is 1dcnn
    n_char: int
        the number of symbols for the token embedding
    n_mels: int
        number of bins in mel spectrogram
    postnet_embedding_dim: int
        output feature dimension for convolution layers
    postnet_kernel_size: int
        postnet convolution kernel size
    postnet_n_convolutions: int
        number of convolution layers
    postnet_dropout: float
        dropout probability for postnet
    padding_idx: int
        the index for padding
    dur_pred_kernel_size: int
        the convolution kernel size in duration predictor
    pitch_pred_kernel_size: int
        kernel size for pitch prediction.
    energy_pred_kernel_size: int
        kernel size for energy prediction.
    variance_predictor_dropout: float
        dropout probability for variance predictor (duration/pitch/energy)

    Example
    -------
    >>> import torch
    >>> from speechbrain.lobes.models.FastSpeech2 import FastSpeech2
    >>> model = FastSpeech2(
    ...     enc_num_layers=6,
    ...     enc_num_head=2,
    ...     enc_d_model=384,
    ...     enc_ffn_dim=1536,
    ...     enc_k_dim=384,
    ...     enc_v_dim=384,
    ...     enc_dropout=0.1,
    ...     dec_num_layers=6,
    ...     dec_num_head=2,
    ...     dec_d_model=384,
    ...     dec_ffn_dim=1536,
    ...     dec_k_dim=384,
    ...     dec_v_dim=384,
    ...     dec_dropout=0.1,
    ...     normalize_before=False,
    ...     ffn_type='1dcnn',
    ...     ffn_cnn_kernel_size_list=[9, 1],
    ...     n_char=40,
    ...     n_mels=80,
    ...     postnet_embedding_dim=512,
    ...     postnet_kernel_size=5,
    ...     postnet_n_convolutions=5,
    ...     postnet_dropout=0.5,
    ...     padding_idx=0,
    ...     dur_pred_kernel_size=3,
    ...     pitch_pred_kernel_size=3,
    ...     energy_pred_kernel_size=3,
    ...     variance_predictor_dropout=0.5)
    >>> inputs = torch.tensor([
    ...     [13, 12, 31, 14, 19],
    ...     [31, 16, 30, 31, 0],
    ... ])
    >>> input_lengths = torch.tensor([5, 4])
    >>> durations = torch.tensor([
    ...     [2, 4, 1, 5, 3],
    ...     [1, 2, 4, 3, 0],
    ... ])
    >>> mel_post, postnet_output, predict_durations, predict_pitch, avg_pitch, predict_energy, avg_energy, mel_lens = model(inputs, durations=durations)
    >>> mel_post.shape, predict_durations.shape
    (torch.Size([2, 15, 80]), torch.Size([2, 5]))
    >>> predict_pitch.shape, predict_energy.shape
    (torch.Size([2, 5, 1]), torch.Size([2, 5, 1]))
    """

    def __init__(
        self,
        enc_num_layers,
        enc_num_head,
        enc_d_model,
        enc_ffn_dim,
        enc_k_dim,
        enc_v_dim,
        enc_dropout,
        dec_num_layers,
        dec_num_head,
        dec_d_model,
        dec_ffn_dim,
        dec_k_dim,
        dec_v_dim,
        dec_dropout,
        normalize_before,
        ffn_type,
        ffn_cnn_kernel_size_list,
        n_char,
        n_mels,
        postnet_embedding_dim,
        postnet_kernel_size,
        postnet_n_convolutions,
        postnet_dropout,
        padding_idx,
        dur_pred_kernel_size,
        pitch_pred_kernel_size,
        energy_pred_kernel_size,
        variance_predictor_dropout,
    ):
        super().__init__()
        self.enc_num_head = enc_num_head
        self.dec_num_head = dec_num_head
        self.padding_idx = padding_idx
        self.sinusoidal_positional_embed_encoder = PositionalEmbedding(
            enc_d_model
        )
        self.sinusoidal_positional_embed_decoder = PositionalEmbedding(
            dec_d_model
        )

        self.encPreNet = EncoderPreNet(
            n_char, padding_idx, out_channels=enc_d_model
        )
        self.durPred = DurationPredictor(
            in_channels=enc_d_model,
            out_channels=enc_d_model,
            kernel_size=dur_pred_kernel_size,
            dropout=variance_predictor_dropout,
        )
        self.pitchPred = DurationPredictor(
            in_channels=enc_d_model,
            out_channels=enc_d_model,
            kernel_size=dur_pred_kernel_size,
            dropout=variance_predictor_dropout,
        )
        self.energyPred = DurationPredictor(
            in_channels=enc_d_model,
            out_channels=enc_d_model,
            kernel_size=dur_pred_kernel_size,
            dropout=variance_predictor_dropout,
        )
        self.pitchEmbed = CNN.Conv1d(
            in_channels=1,
            out_channels=enc_d_model,
            kernel_size=pitch_pred_kernel_size,
            padding="same",
            skip_transpose=True,
        )

        self.energyEmbed = CNN.Conv1d(
            in_channels=1,
            out_channels=enc_d_model,
            kernel_size=energy_pred_kernel_size,
            padding="same",
            skip_transpose=True,
        )
        self.encoder = TransformerEncoder(
            num_layers=enc_num_layers,
            nhead=enc_num_head,
            d_ffn=enc_ffn_dim,
            d_model=enc_d_model,
            kdim=enc_k_dim,
            vdim=enc_v_dim,
            dropout=enc_dropout,
            activation=nn.ReLU,
            normalize_before=normalize_before,
            ffn_type=ffn_type,
            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
        )

        self.decoder = TransformerEncoder(
            num_layers=dec_num_layers,
            nhead=dec_num_head,
            d_ffn=dec_ffn_dim,
            d_model=dec_d_model,
            kdim=dec_k_dim,
            vdim=dec_v_dim,
            dropout=dec_dropout,
            activation=nn.ReLU,
            normalize_before=normalize_before,
            ffn_type=ffn_type,
            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
        )

        self.linear = linear.Linear(n_neurons=n_mels, input_size=dec_d_model)
        self.postnet = PostNet(
            n_mel_channels=n_mels,
            postnet_embedding_dim=postnet_embedding_dim,
            postnet_kernel_size=postnet_kernel_size,
            postnet_n_convolutions=postnet_n_convolutions,
            postnet_dropout=postnet_dropout,
        )
    def forward(
        self,
        tokens,
        durations=None,
        pitch=None,
        energy=None,
        pace=1.0,
        pitch_rate=1.0,
        energy_rate=1.0,
    ):
        """forward pass for training and inference

        Arguments
        ---------
        tokens: torch.Tensor
            batch of input tokens
        durations: torch.Tensor
            batch of durations for each token. If it is None, the model will
            infer on predicted durations
        pitch: torch.Tensor
            batch of pitch for each frame. If it is None, the model will
            infer on predicted pitches
        energy: torch.Tensor
            batch of energy for each frame. If it is None, the model will
            infer on predicted energies
        pace: float
            scaling factor for durations
        pitch_rate: float
            scaling factor for pitches
        energy_rate: float
            scaling factor for energies

        Returns
        -------
        mel_post: torch.Tensor
            mel outputs from the decoder
        postnet_output: torch.Tensor
            mel outputs from the postnet
        predict_durations: torch.Tensor
            predicted durations of each token
        predict_pitch: torch.Tensor
            predicted pitches of each token
        avg_pitch: torch.Tensor
            target pitches for each token if input pitch is not None,
            None if input pitch is None
        predict_energy: torch.Tensor
            predicted energies of each token
        avg_energy: torch.Tensor
            target energies for each token if input energy is not None,
            None if input energy is None
        mel_length: torch.Tensor
            predicted lengths of mel spectrograms
        """
        srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
        srcmask_inverted = (~srcmask).unsqueeze(-1)

        # prenet & encoder
        token_feats = self.encPreNet(tokens)
        pos = self.sinusoidal_positional_embed_encoder(
            token_feats.shape[1], srcmask, token_feats.dtype
        )
        token_feats = torch.add(token_feats, pos) * srcmask_inverted
        attn_mask = (
            srcmask.unsqueeze(-1)
            .repeat(self.enc_num_head, 1, token_feats.shape[1])
            .permute(0, 2, 1)
            .bool()
        )
        token_feats, _ = self.encoder(
            token_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
        )
        token_feats = token_feats * srcmask_inverted

        # duration predictor
        predict_durations = self.durPred(
            token_feats, srcmask_inverted
        ).squeeze()

        if predict_durations.dim() == 1:
            predict_durations = predict_durations.unsqueeze(0)
        if durations is None:
            dur_pred_reverse_log = torch.clamp(
                torch.exp(predict_durations) - 1, 0
            )

        # pitch predictor
        avg_pitch = None
        predict_pitch = self.pitchPred(token_feats, srcmask_inverted)
        # use a pitch rate to adjust the pitch
        predict_pitch = predict_pitch * pitch_rate
        if pitch is not None:
            avg_pitch = average_over_durations(pitch.unsqueeze(1), durations)
            pitch = self.pitchEmbed(avg_pitch)
            avg_pitch = avg_pitch.permute(0, 2, 1)
        else:
            pitch = self.pitchEmbed(predict_pitch.permute(0, 2, 1))
        pitch = pitch.permute(0, 2, 1)
        token_feats = token_feats.add(pitch)

        # energy predictor
        avg_energy = None
        predict_energy = self.energyPred(token_feats, srcmask_inverted)
        # use an energy rate to adjust the energy
        predict_energy = predict_energy * energy_rate
        if energy is not None:
            avg_energy = average_over_durations(energy.unsqueeze(1), durations)
            energy = self.energyEmbed(avg_energy)
            avg_energy = avg_energy.permute(0, 2, 1)
        else:
            energy = self.energyEmbed(predict_energy.permute(0, 2, 1))
        energy = energy.permute(0, 2, 1)
        token_feats = token_feats.add(energy)

        # upsamples the durations
        spec_feats, mel_lens = upsample(
            token_feats,
            durations if durations is not None else dur_pred_reverse_log,
            pace=pace,
        )
        srcmask = get_key_padding_mask(spec_feats, pad_idx=self.padding_idx)
        srcmask_inverted = (~srcmask).unsqueeze(-1)
        attn_mask = (
            srcmask.unsqueeze(-1)
            .repeat(self.dec_num_head, 1, spec_feats.shape[1])
            .permute(0, 2, 1)
            .bool()
        )

        # decoder
        pos = self.sinusoidal_positional_embed_decoder(
            spec_feats.shape[1], srcmask, spec_feats.dtype
        )
        spec_feats = torch.add(spec_feats, pos) * srcmask_inverted

        output_mel_feats, memory, *_ = self.decoder(
            spec_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
        )

        # postnet
        mel_post = self.linear(output_mel_feats) * srcmask_inverted
        postnet_output = self.postnet(mel_post) + mel_post
        return (
            mel_post,
            postnet_output,
            predict_durations,
            predict_pitch,
            avg_pitch,
            predict_energy,
            avg_energy,
            torch.tensor(mel_lens),
        )
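# An inference-mode sketch (assumes the ``model`` and ``inputs`` built in the
# docstring example above; without ground-truth durations the model upsamples
# with its own predicted durations, and ``pace`` rescales them):
# >>> with torch.no_grad():
# ...     outputs = model(inputs, durations=None, pace=1.2)
# >>> mel_post, postnet_output = outputs[0], outputs[1]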
def average_over_durations(values, durs):
    """Average values over durations.

    Arguments
    ---------
    values: torch.Tensor
        shape: [B, 1, T_de]
    durs: torch.Tensor
        shape: [B, T_en]

    Returns
    -------
    avg: torch.Tensor
        shape: [B, 1, T_en]
    """
    durs_cums_ends = torch.cumsum(durs, dim=1).long()
    durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
    values_nonzero_cums = torch.nn.functional.pad(
        torch.cumsum(values != 0.0, dim=2), (1, 0)
    )
    values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))

    bs, length = durs_cums_ends.size()
    n_formants = values.size(1)
    dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, length)
    dce = durs_cums_ends[:, None, :].expand(bs, n_formants, length)

    values_sums = (
        torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)
    ).float()
    values_nelems = (
        torch.gather(values_nonzero_cums, 2, dce)
        - torch.gather(values_nonzero_cums, 2, dcs)
    ).float()

    avg = torch.where(
        values_nelems == 0.0, values_nelems, values_sums / values_nelems
    )
    return avg
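# A worked sketch of average_over_durations (values chosen for illustration):
# each token's pitch/energy target becomes the mean of its frame-level
# values, with zero frames excluded from the average.
# >>> values = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # [B=1, 1, T_de=4]
# >>> durs = torch.tensor([[2, 2]])                    # [B=1, T_en=2]
# >>> average_over_durations(values, durs)
# tensor([[[1.5000, 3.5000]]])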
def upsample(feats, durs, pace=1.0, padding_value=0.0):
    """Upsamples the encoder output according to the durations

    Arguments
    ---------
    feats: torch.Tensor
        batch of encoder outputs, one feature vector per token
    durs: torch.Tensor
        durations to be used to upsample
    pace: float
        scaling factor for durations
    padding_value: int
        padding index

    Returns
    -------
    padded_upsampled_mels: torch.Tensor
        the upsampled features, padded to the longest sequence in the batch
    mel_lens: list of int
        the unpadded length of each upsampled sequence
    """
    upsampled_mels = [
        torch.repeat_interleave(feats[i], (pace * durs[i]).long(), dim=0)
        for i in range(len(durs))
    ]
    mel_lens = [mel.shape[0] for mel in upsampled_mels]

    padded_upsampled_mels = torch.nn.utils.rnn.pad_sequence(
        upsampled_mels, batch_first=True, padding_value=padding_value,
    )
    return padded_upsampled_mels, mel_lens
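# A small worked sketch of upsample (shapes chosen for illustration): each
# token's feature vector is repeated duration-many times.
# >>> feats = torch.arange(6.0).reshape(1, 3, 2)  # (batch, tokens, features)
# >>> durs = torch.tensor([[1, 2, 0]])
# >>> mels, lens = upsample(feats, durs)
# >>> mels.shape, lens   # token 1 repeated twice, token 2 dropped
# (torch.Size([1, 3, 2]), [3])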
class TextMelCollate:
    """Zero-pads model inputs and targets based on number of frames per step

    result: tuple
        a tuple of tensors to be used as inputs/targets
        (
            text_padded,
            dur_padded,
            input_lengths,
            mel_padded,
            pitch_padded,
            energy_padded,
            output_lengths,
            len_x,
            labels,
            wavs,
            no_spn_seq_padded,
            spn_labels_padded,
            last_phonemes_padded,
        )
    """

    # TODO: Make this more intuitive, use the pipeline
    def __call__(self, batch):
        """Collates a training batch from normalized text and mel-spectrograms

        Arguments
        ---------
        batch: list
            [text_normalized, mel_normalized]
        """
        # TODO: Remove for loops
        raw_batch = list(batch)
        for i in range(len(batch)):
            # the pipeline returns a dictionary with one element
            batch[i] = batch[i]["mel_text_pair"]

        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
        )
        max_input_len = input_lengths[0]

        # Get the maximum length of the sequences without silent phonemes
        no_spn_seq_lengths, no_spn_ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[-2]) for x in batch]),
            dim=0,
            descending=True,
        )
        max_no_spn_seq_len = no_spn_seq_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        no_spn_seq_padded = torch.LongTensor(len(batch), max_no_spn_seq_len)
        last_phonemes_padded = torch.LongTensor(len(batch), max_no_spn_seq_len)
        dur_padded = torch.LongTensor(len(batch), max_input_len)
        spn_labels_padded = torch.FloatTensor(len(batch), max_no_spn_seq_len)

        text_padded.zero_()
        no_spn_seq_padded.zero_()
        last_phonemes_padded.zero_()
        dur_padded.zero_()
        spn_labels_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            no_spn_seq = batch[ids_sorted_decreasing[i]][-2]
            last_phonemes = torch.LongTensor(
                batch[ids_sorted_decreasing[i]][-3]
            )
            dur = batch[ids_sorted_decreasing[i]][1]
            spn_labels = torch.LongTensor(batch[ids_sorted_decreasing[i]][-1])
            text_padded[i, : text.size(0)] = text
            no_spn_seq_padded[i, : no_spn_seq.size(0)] = no_spn_seq
            last_phonemes_padded[i, : last_phonemes.size(0)] = last_phonemes
            dur_padded[i, : dur.size(0)] = dur
            spn_labels_padded[i, : spn_labels.size(0)] = spn_labels

        # Right zero-pad mel-spec
        num_mels = batch[0][2].size(0)
        max_target_len = max([x[2].size(1) for x in batch])

        # include mel padded and gate padded
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        pitch_padded = torch.FloatTensor(len(batch), max_target_len)
        pitch_padded.zero_()
        energy_padded = torch.FloatTensor(len(batch), max_target_len)
        energy_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        labels, wavs = [], []
        for i in range(len(ids_sorted_decreasing)):
            idx = ids_sorted_decreasing[i]
            mel = batch[idx][2]
            pitch = batch[idx][3]
            energy = batch[idx][4]
            mel_padded[i, :, : mel.size(1)] = mel
            pitch_padded[i, : pitch.size(0)] = pitch
            energy_padded[i, : energy.size(0)] = energy
            output_lengths[i] = mel.size(1)
            labels.append(raw_batch[idx]["label"])
            wavs.append(raw_batch[idx]["wav"])

        # count number of items - characters in text
        len_x = [x[5] for x in batch]
        len_x = torch.Tensor(len_x)
        mel_padded = mel_padded.permute(0, 2, 1)
        return (
            text_padded,
            dur_padded,
            input_lengths,
            mel_padded,
            pitch_padded,
            energy_padded,
            output_lengths,
            len_x,
            labels,
            wavs,
            no_spn_seq_padded,
            spn_labels_padded,
            last_phonemes_padded,
        )
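# A sketch of the batch item structure TextMelCollate expects, inferred from
# the indexing above (the field names are descriptive, not from the original):
# >>> item = {
# ...     "mel_text_pair": (
# ...         text,           # [0]  LongTensor of token ids
# ...         dur,            # [1]  LongTensor of per-token durations
# ...         mel,            # [2]  FloatTensor (n_mels, n_frames)
# ...         pitch,          # [3]  FloatTensor (n_frames,)
# ...         energy,         # [4]  FloatTensor (n_frames,)
# ...         len_text,       # [5]  number of items/characters in the text
# ...         last_phonemes,  # [-3] word-final phoneme flags
# ...         no_spn_seq,     # [-2] token ids without silent phonemes
# ...         spn_labels,     # [-1] silent-phoneme insertion labels
# ...     ),
# ...     "label": "raw text",
# ...     "wav": "path/to/audio.wav",
# ... }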
class Loss(nn.Module):
    """Loss Computation

    Arguments
    ---------
    log_scale_durations: bool
        applies logarithm to target durations
    ssim_loss_weight: float
        weight for the SSIM loss
    duration_loss_weight: float
        weight for the duration loss
    pitch_loss_weight: float
        weight for the pitch loss
    energy_loss_weight: float
        weight for the energy loss
    mel_loss_weight: float
        weight for the mel loss
    postnet_mel_loss_weight: float
        weight for the postnet mel loss
    spn_loss_weight: float
        weight for the silent phoneme predictor loss
    spn_loss_max_epochs: int
        epoch count after which the silent phoneme predictor loss is
        no longer applied
    """

    def __init__(
        self,
        log_scale_durations,
        ssim_loss_weight,
        duration_loss_weight,
        pitch_loss_weight,
        energy_loss_weight,
        mel_loss_weight,
        postnet_mel_loss_weight,
        spn_loss_weight=1.0,
        spn_loss_max_epochs=8,
    ):
        super().__init__()

        self.ssim_loss = SSIMLoss()
        self.mel_loss = nn.MSELoss()
        self.postnet_mel_loss = nn.MSELoss()
        self.dur_loss = nn.MSELoss()
        self.pitch_loss = nn.MSELoss()
        self.energy_loss = nn.MSELoss()
        self.log_scale_durations = log_scale_durations
        self.ssim_loss_weight = ssim_loss_weight
        self.mel_loss_weight = mel_loss_weight
        self.postnet_mel_loss_weight = postnet_mel_loss_weight
        self.duration_loss_weight = duration_loss_weight
        self.pitch_loss_weight = pitch_loss_weight
        self.energy_loss_weight = energy_loss_weight
        self.spn_loss_weight = spn_loss_weight
        self.spn_loss_max_epochs = spn_loss_max_epochs
    def forward(self, predictions, targets, current_epoch):
        """Computes the value of the loss function and updates stats

        Arguments
        ---------
        predictions: tuple
            model predictions
        targets: tuple
            ground truth data
        current_epoch: int
            the current training epoch, used to disable the silent phoneme
            predictor loss after ``spn_loss_max_epochs``

        Returns
        -------
        loss: dict
            the total loss and its individual weighted components
        """
        (
            mel_target,
            target_durations,
            target_pitch,
            target_energy,
            mel_length,
            phon_len,
            spn_labels,
        ) = targets
        assert len(mel_target.shape) == 3
        (
            mel_out,
            postnet_mel_out,
            log_durations,
            predicted_pitch,
            average_pitch,
            predicted_energy,
            average_energy,
            mel_lens,
            spn_preds,
        ) = predictions

        predicted_pitch = predicted_pitch.squeeze()
        predicted_energy = predicted_energy.squeeze()

        target_pitch = average_pitch.squeeze()
        target_energy = average_energy.squeeze()

        log_durations = log_durations.squeeze()
        if self.log_scale_durations:
            log_target_durations = torch.log(target_durations.float() + 1)
        # change this to perform batch level using padding mask
        for i in range(mel_target.shape[0]):
            if i == 0:
                mel_loss = self.mel_loss(
                    mel_out[i, : mel_length[i], :],
                    mel_target[i, : mel_length[i], :],
                )
                postnet_mel_loss = self.postnet_mel_loss(
                    postnet_mel_out[i, : mel_length[i], :],
                    mel_target[i, : mel_length[i], :],
                )
                dur_loss = self.dur_loss(
                    log_durations[i, : phon_len[i]],
                    log_target_durations[i, : phon_len[i]].to(torch.float32),
                )
                pitch_loss = self.pitch_loss(
                    predicted_pitch[i, : mel_length[i]],
                    target_pitch[i, : mel_length[i]].to(torch.float32),
                )
                energy_loss = self.energy_loss(
                    predicted_energy[i, : mel_length[i]],
                    target_energy[i, : mel_length[i]].to(torch.float32),
                )
            else:
                mel_loss = mel_loss + self.mel_loss(
                    mel_out[i, : mel_length[i], :],
                    mel_target[i, : mel_length[i], :],
                )
                postnet_mel_loss = postnet_mel_loss + self.postnet_mel_loss(
                    postnet_mel_out[i, : mel_length[i], :],
                    mel_target[i, : mel_length[i], :],
                )
                dur_loss = dur_loss + self.dur_loss(
                    log_durations[i, : phon_len[i]],
                    log_target_durations[i, : phon_len[i]].to(torch.float32),
                )
                pitch_loss = pitch_loss + self.pitch_loss(
                    predicted_pitch[i, : mel_length[i]],
                    target_pitch[i, : mel_length[i]].to(torch.float32),
                )
                energy_loss = energy_loss + self.energy_loss(
                    predicted_energy[i, : mel_length[i]],
                    target_energy[i, : mel_length[i]].to(torch.float32),
                )
        ssim_loss = self.ssim_loss(mel_out, mel_target, mel_length)
        mel_loss = torch.div(mel_loss, len(mel_target))
        postnet_mel_loss = torch.div(postnet_mel_loss, len(mel_target))
        dur_loss = torch.div(dur_loss, len(mel_target))
        pitch_loss = torch.div(pitch_loss, len(mel_target))
        energy_loss = torch.div(energy_loss, len(mel_target))

        spn_loss = bce_loss(spn_preds, spn_labels)
        if current_epoch > self.spn_loss_max_epochs:
            self.spn_loss_weight = 0

        total_loss = (
            ssim_loss * self.ssim_loss_weight
            + mel_loss * self.mel_loss_weight
            + postnet_mel_loss * self.postnet_mel_loss_weight
            + dur_loss * self.duration_loss_weight
            + pitch_loss * self.pitch_loss_weight
            + energy_loss * self.energy_loss_weight
            + spn_loss * self.spn_loss_weight
        )

        loss = {
            "total_loss": total_loss,
            "ssim_loss": ssim_loss * self.ssim_loss_weight,
            "mel_loss": mel_loss * self.mel_loss_weight,
            "postnet_mel_loss": postnet_mel_loss
            * self.postnet_mel_loss_weight,
            "dur_loss": dur_loss * self.duration_loss_weight,
            "pitch_loss": pitch_loss * self.pitch_loss_weight,
            "energy_loss": energy_loss * self.energy_loss_weight,
            "spn_loss": spn_loss * self.spn_loss_weight,
        }
        return loss
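# A sketch of the tuple contracts Loss.forward expects, derived from the
# unpacking above (the predictions tuple is the FastSpeech2 output with the
# silent phoneme predictions appended; weights are illustrative):
# >>> criterion = Loss(
# ...     log_scale_durations=True, ssim_loss_weight=1.0,
# ...     duration_loss_weight=1.0, pitch_loss_weight=1.0,
# ...     energy_loss_weight=1.0, mel_loss_weight=1.0,
# ...     postnet_mel_loss_weight=1.0)
# >>> # targets = (mel_target, target_durations, target_pitch, target_energy,
# >>> #            mel_length, phon_len, spn_labels)
# >>> # predictions = (*model(...), spn_preds)
# >>> # loss_dict = criterion(predictions, targets, current_epoch=1)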
def mel_spectogram(
    sample_rate,
    hop_length,
    win_length,
    n_fft,
    n_mels,
    f_min,
    f_max,
    power,
    normalized,
    min_max_energy_norm,
    norm,
    mel_scale,
    compression,
    audio,
):
    """calculates MelSpectrogram for a raw audio signal

    Arguments
    ---------
    sample_rate : int
        Sample rate of audio signal.
    hop_length : int
        Length of hop between STFT windows.
    win_length : int
        Window size.
    n_fft : int
        Size of FFT.
    n_mels : int
        Number of mel filterbanks.
    f_min : float
        Minimum frequency.
    f_max : float
        Maximum frequency.
    power : float
        Exponent for the magnitude spectrogram.
    normalized : bool
        Whether to normalize by magnitude after stft.
    min_max_energy_norm : bool
        Whether to min-max normalize the frame-level energy.
    norm : str or None
        If "slaney", divide the triangular mel weights by the width of the mel band
    mel_scale : str
        Scale to use: "htk" or "slaney".
    compression : bool
        whether to do dynamic range compression
    audio : torch.Tensor
        input audio signal

    Returns
    -------
    mel : torch.Tensor
        the mel spectrogram of the input audio
    rmse : torch.Tensor
        the frame-level energy
    """
    from torchaudio import transforms

    audio_to_mel = transforms.Spectrogram(
        hop_length=hop_length,
        win_length=win_length,
        n_fft=n_fft,
        power=power,
        normalized=normalized,
    ).to(audio.device)

    mel_scale = transforms.MelScale(
        sample_rate=sample_rate,
        n_stft=n_fft // 2 + 1,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        norm=norm,
        mel_scale=mel_scale,
    ).to(audio.device)

    spec = audio_to_mel(audio)
    mel = mel_scale(spec)

    assert mel.dim() == 2
    assert mel.shape[0] == n_mels

    rmse = torch.norm(mel, dim=0)
    if min_max_energy_norm:
        rmse = (rmse - torch.min(rmse)) / (torch.max(rmse) - torch.min(rmse))

    if compression:
        mel = dynamic_range_compression(mel)

    return mel, rmse
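# A usage sketch for mel_spectogram; the parameter values below are typical
# LJSpeech-style settings chosen for illustration, not mandated by this module:
# >>> audio = torch.rand(22050)  # one second of (fake) audio at 22.05 kHz
# >>> mel, energy = mel_spectogram(
# ...     sample_rate=22050, hop_length=256, win_length=1024, n_fft=1024,
# ...     n_mels=80, f_min=0.0, f_max=8000.0, power=1, normalized=False,
# ...     min_max_energy_norm=True, norm="slaney", mel_scale="slaney",
# ...     compression=True, audio=audio)
# >>> mel.shape[0], energy.dim()
# (80, 1)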
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Dynamic range compression for audio signals"""
    return torch.log(torch.clamp(x, min=clip_val) * C)
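# A small doctest-style sketch: values at or below ``clip_val`` are clamped
# before the log, so the output is bounded below by log(clip_val * C).
# >>> dynamic_range_compression(torch.tensor([1.0, 0.0]))
# tensor([  0.0000, -11.5129])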
class SSIMLoss(torch.nn.Module):
    """SSIM loss as (1 - SSIM)
    SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
    """

    def __init__(self):
        super().__init__()
        self.loss_func = _SSIMLoss()

    # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
    def sequence_mask(self, sequence_length, max_len=None):
        """Create a sequence mask for filtering padding in a sequence tensor.

        Arguments
        ---------
        sequence_length: torch.Tensor
            Sequence lengths.
        max_len: int
            Maximum sequence length. Defaults to None.

        Returns
        -------
        mask: [B, T_max]
        """
        if max_len is None:
            max_len = sequence_length.data.max()
        seq_range = torch.arange(
            max_len, dtype=sequence_length.dtype, device=sequence_length.device
        )
        # B x T_max
        mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
        return mask
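# A small doctest-style sketch of sequence_mask (True marks valid positions):
# >>> SSIMLoss().sequence_mask(torch.tensor([3, 5]))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])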
    def sample_wise_min_max(self, x: torch.Tensor, mask: torch.Tensor):
        """Min-Max normalize tensor through first dimension

        Arguments
        ---------
        x: torch.Tensor
            input tensor [B, D1, D2]
        mask: torch.Tensor
            input mask [B, D1, 1]
        """
        maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
        minimum = torch.amin(
            x.masked_fill(~mask, 1e30), dim=(1, 2), keepdim=True
        )
        return (x - minimum) / (maximum - minimum + 1e-8)
    def forward(self, y_hat, y, length):
        """
        Arguments
        ---------
        y_hat: torch.Tensor
            model prediction values [B, T, D].
        y: torch.Tensor
            target values [B, T, D].
        length: torch.Tensor
            length of each sample in a batch for masking.

        Returns
        -------
        loss: Average loss value in range [0, 1] masked by the length.
        """
        mask = self.sequence_mask(
            sequence_length=length, max_len=y.size(1)
        ).unsqueeze(2)

        y_norm = self.sample_wise_min_max(y, mask)
        y_hat_norm = self.sample_wise_min_max(y_hat, mask)
        ssim_loss = self.loss_func(
            (y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)
        )

        if ssim_loss.item() > 1.0:
            print(
                f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it to 1.0"
            )
            ssim_loss = torch.tensor(1.0, device=ssim_loss.device)

        if ssim_loss.item() < 0.0:
            print(
                f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it to 0.0"
            )
            ssim_loss = torch.tensor(0.0, device=ssim_loss.device)

        return ssim_loss
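# A usage sketch for SSIMLoss on mel-like tensors (shapes illustrative):
# >>> ssim = SSIMLoss()
# >>> y = torch.rand(2, 64, 80)        # targets     (batch, frames, mels)
# >>> y_hat = torch.rand(2, 64, 80)    # predictions
# >>> length = torch.tensor([64, 50])  # unpadded frame counts
# >>> loss = ssim(y_hat, y, length)    # scalar in [0, 1]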
# Adapted from https://github.com/photosynthesis-team/piq
class _SSIMLoss(_Loss):
    """Creates a criterion that measures the structural similarity index error
    between each element in the input x and target y.

    Equation link: https://en.wikipedia.org/wiki/Structural_similarity

    x and y are tensors of arbitrary shapes with a total of n elements each.
    The sum operation still operates over all the elements, and divides by n.
    The division by n can be avoided if one sets reduction = sum.
    In case of 5D input tensors, complex value is returned as a tensor of size 2.

    Arguments
    ---------
    kernel_size: int
        By default, the mean and covariance of a pixel is obtained by
        convolution with given filter_size.
    kernel_sigma: float
        Standard deviation for Gaussian kernel.
    k1: float
        Coefficient related to c1 (see equation in the link above).
    k2: float
        Coefficient related to c2 (see equation in the link above).
    downsample: bool
        Perform average pool before SSIM computation (Default: True).
    reduction: str
        Specifies the reduction type
    data_range: Union[int, float]
        Maximum value range of images (usually 1.0 or 255).

    Example
    -------
    >>> loss = _SSIMLoss()
    >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
    >>> y = torch.rand(3, 3, 256, 256)
    >>> output = loss(x, y)
    >>> output.backward()
    """

    __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]

    def __init__(
        self,
        kernel_size=11,
        kernel_sigma=1.5,
        k1=0.01,
        k2=0.03,
        downsample=True,
        reduction="mean",
        data_range=1.0,
    ):
        super().__init__()

        # Generic loss parameters.
        self.reduction = reduction

        # Loss-specific parameters.
        self.kernel_size = kernel_size

        # This check might look redundant because kernel size is checked within the ssim function anyway.
        # However, this check allows to fail fast when the loss is being initialised and training has not been started.
        assert (
            kernel_size % 2 == 1
        ), f"Kernel size must be odd, got [{kernel_size}]"
        self.kernel_sigma = kernel_sigma
        self.k1 = k1
        self.k2 = k2
        self.downsample = downsample
        self.data_range = data_range

    def _reduce(self, x, reduction="mean"):
        """Reduce input in batch dimension if needed.

        Arguments
        ---------
        x: torch.Tensor
            Tensor with shape (B, *).
        reduction: str
            Specifies the reduction type: none | mean | sum (Default: mean)
        """
        if reduction == "none":
            return x
        if reduction == "mean":
            return x.mean(dim=0)
        if reduction == "sum":
            return x.sum(dim=0)
        raise ValueError(
            "Unknown reduction. Expected one of {'none', 'mean', 'sum'}"
        )

    def _validate_input(
        self,
        tensors,
        dim_range=(0, -1),
        data_range=(0.0, -1.0),
        size_range=None,
    ):
        """Check if the input satisfies the requirements

        Arguments
        ---------
        tensors: torch.Tensor
            Tensors to check
        dim_range: Tuple[int, int]
            Allowed number of dimensions. (min, max)
        data_range: Tuple[float, float]
            Allowed range of values in tensors. (min, max)
        size_range: Tuple[int, int]
            Dimensions to include in size comparison,
            (start_dim, end_dim + 1)
        """
        if not __debug__:
            return

        x = tensors[0]

        for t in tensors:
            assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
            assert (
                t.device == x.device
            ), f"Expected tensors to be on {x.device}, got {t.device}"

            if size_range is None:
                assert (
                    t.size() == x.size()
                ), f"Expected tensors with same size, got {t.size()} and {x.size()}"
            else:
                assert (
                    t.size()[size_range[0] : size_range[1]]
                    == x.size()[size_range[0] : size_range[1]]
                ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"

            if dim_range[0] == dim_range[1]:
                assert (
                    t.dim() == dim_range[0]
                ), f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
            elif dim_range[0] < dim_range[1]:
                assert (
                    dim_range[0] <= t.dim() <= dim_range[1]
                ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"

            if data_range[0] < data_range[1]:
                assert (
                    data_range[0] <= t.min()
                ), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
                assert (
                    t.max() <= data_range[1]
                ), f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"

    def gaussian_filter(self, kernel_size, sigma):
        """Returns 2D Gaussian kernel N(0, sigma^2)

        Arguments
        ---------
        kernel_size: int
            Size of the kernel
        sigma: float
            Std of the distribution

        Returns
        -------
        gaussian_kernel: torch.Tensor
            [1, kernel_size, kernel_size]
        """
        coords = torch.arange(kernel_size, dtype=torch.float32)
        coords -= (kernel_size - 1) / 2.0

        g = coords ** 2
        g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp()

        g /= g.sum()
        return g.unsqueeze(0)

    def _ssim_per_channel(self, x, y, kernel, k1=0.01, k2=0.03):
        """Calculate Structural Similarity (SSIM) index for X and Y per channel.

        Arguments
        ---------
        x: torch.Tensor
            An input tensor (N, C, H, W).
        y: torch.Tensor
            A target tensor (N, C, H, W).
        kernel: torch.Tensor
            2D Gaussian kernel.
        k1: float
            Algorithm parameter (see equation in the link above).
        k2: float
            Algorithm parameter (see equation in the link above).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN result.

        Returns
        -------
        Full Value of Structural Similarity (SSIM) index.
        """
        if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
            raise ValueError(
                f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
                f"Kernel size: {kernel.size()}"
            )

        c1 = k1 ** 2
        c2 = k2 ** 2
        n_channels = x.size(1)
        mu_x = F.conv2d(
            x, weight=kernel, stride=1, padding=0, groups=n_channels
        )
        mu_y = F.conv2d(
            y, weight=kernel, stride=1, padding=0, groups=n_channels
        )

        mu_xx = mu_x ** 2
        mu_yy = mu_y ** 2
        mu_xy = mu_x * mu_y

        sigma_xx = (
            F.conv2d(
                x ** 2, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu_xx
        )
        sigma_yy = (
            F.conv2d(
                y ** 2, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu_yy
        )
        sigma_xy = (
            F.conv2d(
                x * y, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu_xy
        )

        # Contrast sensitivity (CS) with alpha = beta = gamma = 1.
        cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)

        # Structural similarity (SSIM)
        ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs

        ssim_val = ss.mean(dim=(-1, -2))
        cs = cs.mean(dim=(-1, -2))
        return ssim_val, cs

    def _ssim_per_channel_complex(self, x, y, kernel, k1=0.01, k2=0.03):
        """Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.

        Arguments
        ---------
        x: torch.Tensor
            An input tensor (N, C, H, W, 2).
        y: torch.Tensor
            A target tensor (N, C, H, W, 2).
        kernel: torch.Tensor
            2-D Gaussian kernel.
        k1: float
            Algorithm parameter (see equation in the link above).
        k2: float
            Algorithm parameter (see equation in the link above).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN result.

        Returns
        -------
        Full Value of Complex Structural Similarity (SSIM) index.
        """
        n_channels = x.size(1)
        if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
            raise ValueError(
                f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
                f"Kernel size: {kernel.size()}"
            )

        c1 = k1 ** 2
        c2 = k2 ** 2

        x_real = x[..., 0]
        x_imag = x[..., 1]
        y_real = y[..., 0]
        y_imag = y[..., 1]

        mu1_real = F.conv2d(
            x_real, weight=kernel, stride=1, padding=0, groups=n_channels
        )
        mu1_imag = F.conv2d(
            x_imag, weight=kernel, stride=1, padding=0, groups=n_channels
        )
        mu2_real = F.conv2d(
            y_real, weight=kernel, stride=1, padding=0, groups=n_channels
        )
        mu2_imag = F.conv2d(
            y_imag, weight=kernel, stride=1, padding=0, groups=n_channels
        )

        mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2)
        mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2)
        mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag
        mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real

        compensation = 1.0

        x_sq = x_real.pow(2) + x_imag.pow(2)
        y_sq = y_real.pow(2) + y_imag.pow(2)
        x_y_real = x_real * y_real - x_imag * y_imag
        x_y_imag = x_real * y_imag + x_imag * y_real

        sigma1_sq = (
            F.conv2d(
                x_sq, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu1_sq
        )
        sigma2_sq = (
            F.conv2d(
                y_sq, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu2_sq
        )
        sigma12_real = (
            F.conv2d(
                x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu1_mu2_real
        )
        sigma12_imag = (
            F.conv2d(
                x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels
            )
            - mu1_mu2_imag
        )
        sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1)
        mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1)

        # Set alpha = beta = gamma = 1.
        cs_map = (sigma12 * 2 + c2 * compensation) / (
            sigma1_sq.unsqueeze(-1)
            + sigma2_sq.unsqueeze(-1)
            + c2 * compensation
        )
        ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (
            mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation
        )
        ssim_map = ssim_map * cs_map

        ssim_val = ssim_map.mean(dim=(-2, -3))
        cs = cs_map.mean(dim=(-2, -3))
        return ssim_val, cs

    def ssim(
        self,
        x,
        y,
        kernel_size=11,
        kernel_sigma=1.5,
        data_range=1.0,
        reduction="mean",
        full=False,
        downsample=True,
        k1=0.01,
        k2=0.03,
    ):
        """Interface of Structural Similarity (SSIM) index.
        Inputs are supposed to be in range [0, data_range].
        To match performance with skimage and tensorflow, set downsample = True.

        Arguments
        ---------
        x: torch.Tensor
            An input tensor (N, C, H, W) or (N, C, H, W, 2).
        y: torch.Tensor
            A target tensor (N, C, H, W) or (N, C, H, W, 2).
        kernel_size: int
            The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: float
            Sigma of normal distribution.
        data_range: Union[int, float]
            Maximum value range of images (usually 1.0 or 255).
        reduction: str
            Specifies the reduction type: none | mean | sum. Default: mean
        full: bool
            Return cs map or not.
        downsample: bool
            Perform average pool before SSIM computation. Default: True
        k1: float
            Algorithm parameter (see equation in the link above).
        k2: float
            Algorithm parameter (see equation in the link above).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN result.

        Returns
        -------
        Value of Structural Similarity (SSIM) index. In case of 5D input
        tensors, complex value is returned as a tensor of size 2.
""" assert ( kernel_size % 2 == 1 ), f"Kernel size must be odd, got [{kernel_size}]" self._validate_input( [x, y], dim_range=(4, 5), data_range=(0, data_range) ) x = x / float(data_range) y = y / float(data_range) # Averagepool image if the size is large enough f = max(1, round(min(x.size()[-2:]) / 256)) if (f > 1) and downsample: x = F.avg_pool2d(x, kernel_size=f) y = F.avg_pool2d(y, kernel_size=f) kernel = ( self.gaussian_filter(kernel_size, kernel_sigma) .repeat(x.size(1), 1, 1, 1) .to(y) ) _compute_ssim_per_channel = ( self._ssim_per_channel_complex if x.dim() == 5 else self._ssim_per_channel ) ssim_map, cs_map = _compute_ssim_per_channel( x=x, y=y, kernel=kernel, k1=k1, k2=k2 ) ssim_val = ssim_map.mean(1) cs = cs_map.mean(1) ssim_val = self._reduce(ssim_val, reduction) cs = self._reduce(cs, reduction) if full: return [ssim_val, cs] return ssim_val def forward(self, x, y): """Computation of Structural Similarity (SSIM) index as a loss function. Arguments --------- x: torch.Tensor An input tensor (N, C, H, W) or (N, C, H, W, 2). y: torch.Tensor A target tensor (N, C, H, W) or (N, C, H, W, 2). Returns --------- Value of SSIM loss to be minimized, i.e 1 - ssim in [0, 1] range. In case of 5D input tensors, complex value is returned as a tensor of size 2. """ score = self.ssim( x=x, y=y, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, downsample=self.downsample, data_range=self.data_range, reduction=self.reduction, full=False, k1=self.k1, k2=self.k2, ) return torch.ones_like(score) - score