Source code for speechbrain.lobes.models.HifiGAN

"""
Neural network modules for the HiFi-GAN: Generative Adversarial Networks for
Efficient and High Fidelity Speech Synthesis

For more details: https://arxiv.org/pdf/2010.05646.pdf

Authors
 * Jarod Duret 2021
 * Yingzhi WANG 2022
"""

# Adapted from https://github.com/jik876/hifi-gan/ and https://github.com/coqui-ai/TTS/
# MIT License

# Copyright (c) 2020 Jungil Kong

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.nn.functional as F
import torch.nn as nn
import speechbrain as sb
from speechbrain.nnet.CNN import Conv1d, ConvTranspose1d, Conv2d
from torchaudio import transforms

LRELU_SLOPE = 0.1


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Dynamic range compression for audio signals"""
    return torch.log(torch.clamp(x, min=clip_val) * C)


def mel_spectogram(
    sample_rate,
    hop_length,
    win_length,
    n_fft,
    n_mels,
    f_min,
    f_max,
    power,
    normalized,
    norm,
    mel_scale,
    compression,
    audio,
):
    """Calculates the MelSpectrogram for a raw audio signal

    Arguments
    ---------
    sample_rate : int
        Sample rate of audio signal.
    hop_length : int
        Length of hop between STFT windows.
    win_length : int
        Window size.
    n_fft : int
        Size of FFT.
    n_mels : int
        Number of mel filterbanks.
    f_min : float
        Minimum frequency.
    f_max : float
        Maximum frequency.
    power : float
        Exponent for the magnitude spectrogram.
    normalized : bool
        Whether to normalize by magnitude after stft.
    norm : str or None
        If "slaney", divide the triangular mel weights by the width of the mel band.
    mel_scale : str
        Scale to use: "htk" or "slaney".
    compression : bool
        Whether to do dynamic range compression.
    audio : torch.Tensor
        Input audio signal.
    """
    audio_to_mel = transforms.MelSpectrogram(
        sample_rate=sample_rate,
        hop_length=hop_length,
        win_length=win_length,
        n_fft=n_fft,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=power,
        normalized=normalized,
        norm=norm,
        mel_scale=mel_scale,
    ).to(audio.device)
    mel = audio_to_mel(audio)
    if compression:
        mel = dynamic_range_compression(mel)
    return mel


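# Illustrative usage sketch for ``mel_spectogram``; the parameter values below
# are assumptions (a common configuration for 22.05 kHz speech), not fixed by
# this module.
def _example_mel_spectogram():
    """A minimal sketch, assuming one second of random audio at 22.05 kHz."""
    audio = torch.rand([1, 22050])
    mel = mel_spectogram(
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        n_fft=1024,
        n_mels=80,
        f_min=0.0,
        f_max=8000.0,
        power=1.0,
        normalized=False,
        norm="slaney",
        mel_scale="slaney",
        compression=True,
        audio=audio,
    )
    # (batch, n_mels, frames) -> torch.Size([1, 80, 87])
    return mel

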
def process_duration(code, code_feat):
    """Process a given batch of code to extract consecutive unique elements
    and their associated features.

    Arguments
    ---------
    code : torch.Tensor (batch, time)
        Tensor of code indices.
    code_feat : torch.Tensor (batch, time, channel)
        Tensor of code features.

    Returns
    -------
    uniq_code_feat_filtered : torch.Tensor (batch, time)
        Features of consecutive unique codes.
    mask : torch.Tensor (batch, time)
        Padding mask for the unique codes.
    uniq_code_count : torch.Tensor (n)
        Count of unique codes.

    Example
    -------
    >>> code = torch.IntTensor([[40, 18, 18, 10]])
    >>> code_feat = torch.rand([1, 4, 128])
    >>> out_tensor, mask, uniq_code = process_duration(code, code_feat)
    >>> out_tensor.shape
    torch.Size([1, 1, 128])
    >>> mask.shape
    torch.Size([1, 1])
    >>> uniq_code.shape
    torch.Size([1])
    """
    uniq_code_count = []
    uniq_code_feat = []
    for i in range(code.size(0)):
        _, count = torch.unique_consecutive(code[i, :], return_counts=True)
        if len(count) > 2:
            # remove first and last code as segment sampling may cause
            # incomplete segment length
            uniq_code_count.append(count[1:-1])
            uniq_code_idx = count.cumsum(dim=0)[:-2]
        else:
            uniq_code_count.append(count)
            uniq_code_idx = count.cumsum(dim=0) - 1
        uniq_code_feat.append(
            code_feat[i, uniq_code_idx, :].view(-1, code_feat.size(2))
        )
    uniq_code_count = torch.cat(uniq_code_count)

    # collate
    max_len = max(feat.size(0) for feat in uniq_code_feat)
    uniq_code_feat_filtered = uniq_code_feat[0].new_zeros(
        (len(uniq_code_feat), max_len, uniq_code_feat[0].size(1))
    )
    mask = torch.arange(max_len).repeat(len(uniq_code_feat), 1)
    for i, v in enumerate(uniq_code_feat):
        uniq_code_feat_filtered[i, : v.size(0)] = v
        mask[i, :] = mask[i, :] < v.size(0)
    return uniq_code_feat_filtered, mask.bool(), uniq_code_count.float()


##################################
# Generator
##################################


class ResBlock1(torch.nn.Module):
    """Residual Block Type 1, which has 3 convolutional layers in each
    convolution block.

    Arguments
    ---------
    channels : int
        number of hidden channels for the convolutional layers.
    kernel_size : int
        size of the convolution filter in each layer.
    dilation : tuple
        dilation value for each conv layer in a block.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.convs1 = nn.ModuleList(
            [
                Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=d,
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                )
                for d in dilation
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=1,
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                )
                for _ in dilation
            ]
        )

    def forward(self, x):
        """Returns the output of ResBlock1

        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            input tensor.
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """This function removes weight normalization during inference."""
        for l in self.convs1:
            l.remove_weight_norm()
        for l in self.convs2:
            l.remove_weight_norm()


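# Illustrative usage sketch; the channel and time sizes below are assumptions.
def _example_resblock1():
    """A minimal sketch: 'same' padding keeps (batch, channel, time) intact."""
    block = ResBlock1(channels=64, kernel_size=3, dilation=(1, 3, 5))
    x = torch.rand([4, 64, 100])
    y = block(x)
    # residual blocks preserve the input shape -> torch.Size([4, 64, 100])
    return y

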
class ResBlock2(torch.nn.Module):
    """Residual Block Type 2, which has 2 convolutional layers in each
    convolution block.

    Arguments
    ---------
    channels : int
        number of hidden channels for the convolutional layers.
    kernel_size : int
        size of the convolution filter in each layer.
    dilation : tuple
        dilation value for each conv layer in a block.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=d,
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                )
                for d in dilation
            ]
        )

    def forward(self, x):
        """Returns the output of ResBlock2

        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            input tensor.
        """
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """This function removes weight normalization during inference."""
        for l in self.convs:
            l.remove_weight_norm()


class HifiganGenerator(torch.nn.Module):
    """HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

    Arguments
    ---------
    in_channels : int
        number of input tensor channels.
    out_channels : int
        number of output tensor channels.
    resblock_type : str
        type of the `ResBlock`. '1' or '2'.
    resblock_dilation_sizes : List[List[int]]
        list of dilation values in each layer of a `ResBlock`.
    resblock_kernel_sizes : List[int]
        list of kernel sizes for each `ResBlock`.
    upsample_kernel_sizes : List[int]
        list of kernel sizes for each transposed convolution.
    upsample_initial_channel : int
        number of channels for the first upsampling layer. This is divided
        by 2 for each consecutive upsampling layer.
    upsample_factors : List[int]
        upsampling factors (stride) for each upsampling layer.
    inference_padding : int
        constant padding applied to the input at inference time. Defaults to 5.
    cond_channels : int
        if > 0, a 1x1 conditioning convolution is applied to the global
        conditioning input `g`.
    conv_post_bias : bool
        whether the post convolution layer uses a bias term.

    Example
    -------
    >>> inp_tensor = torch.rand([4, 80, 33])
    >>> hifigan_generator = HifiganGenerator(
    ...     in_channels=80,
    ...     out_channels=1,
    ...     resblock_type="1",
    ...     resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    ...     resblock_kernel_sizes=[3, 7, 11],
    ...     upsample_kernel_sizes=[16, 16, 4, 4],
    ...     upsample_initial_channel=512,
    ...     upsample_factors=[8, 8, 2, 2],
    ... )
    >>> out_tensor = hifigan_generator(inp_tensor)
    >>> out_tensor.shape
    torch.Size([4, 1, 8448])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        resblock_type,
        resblock_dilation_sizes,
        resblock_kernel_sizes,
        upsample_kernel_sizes,
        upsample_initial_channel,
        upsample_factors,
        inference_padding=5,
        cond_channels=0,
        conv_post_bias=True,
    ):
        super().__init__()
        self.inference_padding = inference_padding
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_factors)
        # initial upsampling layers
        self.conv_pre = Conv1d(
            in_channels=in_channels,
            out_channels=upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding="same",
            skip_transpose=True,
            weight_norm=True,
        )
        resblock = ResBlock1 if resblock_type == "1" else ResBlock2
        # upsampling layers
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(
            zip(upsample_factors, upsample_kernel_sizes)
        ):
            self.ups.append(
                ConvTranspose1d(
                    in_channels=upsample_initial_channel // (2 ** i),
                    out_channels=upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=k,
                    stride=u,
                    padding=(k - u) // 2,
                    skip_transpose=True,
                    weight_norm=True,
                )
            )
        # MRF blocks
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))
        # post convolution layer
        self.conv_post = Conv1d(
            in_channels=ch,
            out_channels=1,
            kernel_size=7,
            stride=1,
            padding="same",
            skip_transpose=True,
            bias=conv_post_bias,
            weight_norm=True,
        )
        if cond_channels > 0:
            self.cond_layer = Conv1d(
                in_channels=cond_channels,
                out_channels=upsample_initial_channel,
                kernel_size=1,
            )

    def forward(self, x, g=None):
        """
        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            feature input tensor.
        g : torch.Tensor (batch, 1, time)
            global conditioning input tensor.
        """
        o = self.conv_pre(x)
        if hasattr(self, "cond_layer"):
            o = o + self.cond_layer(g)
        for i in range(self.num_upsamples):
            o = F.leaky_relu(o, LRELU_SLOPE)
            o = self.ups[i](o)
            z_sum = None
            for j in range(self.num_kernels):
                if z_sum is None:
                    z_sum = self.resblocks[i * self.num_kernels + j](o)
                else:
                    z_sum += self.resblocks[i * self.num_kernels + j](o)
            o = z_sum / self.num_kernels
        o = F.leaky_relu(o)
        o = self.conv_post(o)
        o = torch.tanh(o)
        return o

    def remove_weight_norm(self):
        """This function removes weight normalization during inference."""
        for l in self.ups:
            l.remove_weight_norm()
        for l in self.resblocks:
            l.remove_weight_norm()
        self.conv_pre.remove_weight_norm()
        self.conv_post.remove_weight_norm()

    @torch.no_grad()
    def inference(self, c, padding=True):
        """The inference function performs a padding and runs the forward
        method.

        Arguments
        ---------
        c : torch.Tensor (batch, channel, time)
            feature input tensor.
        padding : bool
            if True, replicate-pad the input with `inference_padding` frames
            on each side before synthesis.
        """
        if padding:
            c = torch.nn.functional.pad(
                c, (self.inference_padding, self.inference_padding), "replicate"
            )
        return self.forward(c)


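# Illustrative inference sketch; the configuration mirrors the class example
# above and is an assumption, not the only valid setup.
def _example_generator_inference():
    """A minimal sketch: replicate padding adds 2 * inference_padding frames,
    each expanded by the total upsampling factor (8 * 8 * 2 * 2 = 256 here)."""
    generator = HifiganGenerator(
        in_channels=80,
        out_channels=1,
        resblock_type="1",
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_kernel_sizes=[3, 7, 11],
        upsample_kernel_sizes=[16, 16, 4, 4],
        upsample_initial_channel=512,
        upsample_factors=[8, 8, 2, 2],
    )
    mel = torch.rand([1, 80, 33])
    wav = generator.inference(mel)
    # (33 + 2 * 5) frames * 256 samples -> torch.Size([1, 1, 11008])
    return wav

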
class VariancePredictor(nn.Module):
    """Variance predictor inspired from FastSpeech2

    Arguments
    ---------
    encoder_embed_dim : int
        number of input tensor channels.
    var_pred_hidden_dim : int
        size of hidden channels for the convolutional layers.
    var_pred_kernel_size : int
        size of the convolution filter in each layer.
    var_pred_dropout : float
        dropout probability of each layer.

    Example
    -------
    >>> inp_tensor = torch.rand([4, 80, 128])
    >>> duration_predictor = VariancePredictor(
    ...     encoder_embed_dim=128,
    ...     var_pred_hidden_dim=128,
    ...     var_pred_kernel_size=3,
    ...     var_pred_dropout=0.5,
    ... )
    >>> out_tensor = duration_predictor(inp_tensor)
    >>> out_tensor.shape
    torch.Size([4, 80])
    """

    def __init__(
        self,
        encoder_embed_dim,
        var_pred_hidden_dim,
        var_pred_kernel_size,
        var_pred_dropout,
    ):
        super().__init__()
        self.conv1 = nn.Sequential(
            Conv1d(
                in_channels=encoder_embed_dim,
                out_channels=var_pred_hidden_dim,
                kernel_size=var_pred_kernel_size,
                padding="same",
                skip_transpose=True,
                weight_norm=True,
            ),
            nn.ReLU(),
        )
        self.dropout = var_pred_dropout
        self.conv2 = nn.Sequential(
            Conv1d(
                in_channels=var_pred_hidden_dim,
                out_channels=var_pred_hidden_dim,
                kernel_size=var_pred_kernel_size,
                padding="same",
                skip_transpose=True,
                weight_norm=True,
            ),
            nn.ReLU(),
        )
        self.proj = nn.Linear(var_pred_hidden_dim, 1)

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor (batch, channel, time)
            feature input tensor.
        """
        x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.proj(x).squeeze(dim=2)


class UnitHifiganGenerator(HifiganGenerator):
    """Unit HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

    Arguments
    ---------
    in_channels : int
        number of input tensor channels.
    out_channels : int
        number of output tensor channels.
    resblock_type : str
        type of the `ResBlock`. '1' or '2'.
    resblock_dilation_sizes : List[List[int]]
        list of dilation values in each layer of a `ResBlock`.
    resblock_kernel_sizes : List[int]
        list of kernel sizes for each `ResBlock`.
    upsample_kernel_sizes : List[int]
        list of kernel sizes for each transposed convolution.
    upsample_initial_channel : int
        number of channels for the first upsampling layer. This is divided
        by 2 for each consecutive upsampling layer.
    upsample_factors : List[int]
        upsampling factors (stride) for each upsampling layer.
    inference_padding : int
        constant padding applied to the input at inference time. Defaults to 5.
    num_embeddings : int
        size of the dictionary of embeddings.
    embedding_dim : int
        size of each embedding vector.
    duration_predictor : bool
        enable duration predictor module.
    var_pred_hidden_dim : int
        size of hidden channels for the convolutional layers of the duration
        predictor.
    var_pred_kernel_size : int
        size of the convolution filter in each layer of the duration predictor.
    var_pred_dropout : float
        dropout probability of each layer in the duration predictor.

    Example
    -------
    >>> inp_tensor = torch.randint(0, 100, (4, 10))
    >>> unit_hifigan_generator = UnitHifiganGenerator(
    ...     in_channels=128,
    ...     out_channels=1,
    ...     resblock_type="1",
    ...     resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    ...     resblock_kernel_sizes=[3, 7, 11],
    ...     upsample_kernel_sizes=[11, 8, 8, 4, 4],
    ...     upsample_initial_channel=512,
    ...     upsample_factors=[5, 4, 4, 2, 2],
    ...     num_embeddings=100,
    ...     embedding_dim=128,
    ...     duration_predictor=True,
    ...     var_pred_hidden_dim=128,
    ...     var_pred_kernel_size=3,
    ...     var_pred_dropout=0.5,
    ... )
    >>> out_tensor, _ = unit_hifigan_generator(inp_tensor)
    >>> out_tensor.shape
    torch.Size([4, 1, 3200])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        resblock_type,
        resblock_dilation_sizes,
        resblock_kernel_sizes,
        upsample_kernel_sizes,
        upsample_initial_channel,
        upsample_factors,
        inference_padding=5,
        cond_channels=0,
        conv_post_bias=True,
        num_embeddings=100,
        embedding_dim=128,
        duration_predictor=False,
        var_pred_hidden_dim=128,
        var_pred_kernel_size=3,
        var_pred_dropout=0.5,
    ):
        super().__init__(
            in_channels,
            out_channels,
            resblock_type,
            resblock_dilation_sizes,
            resblock_kernel_sizes,
            upsample_kernel_sizes,
            upsample_initial_channel,
            upsample_factors,
            inference_padding,
            cond_channels,
            conv_post_bias,
        )
        self.unit_embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
        self.duration_predictor = duration_predictor
        if duration_predictor:
            self.var_predictor = VariancePredictor(
                embedding_dim,
                var_pred_hidden_dim,
                var_pred_kernel_size,
                var_pred_dropout,
            )

    def forward(self, x, g=None):
        """
        Arguments
        ---------
        x : torch.Tensor (batch, time)
            feature input tensor.
        g : torch.Tensor (batch, 1, time)
            global conditioning input tensor.
        """
        u = self.unit_embedding(x).transpose(1, 2)
        log_dur = None
        log_dur_pred = None
        if self.duration_predictor:
            uniq_code_feat, uniq_code_mask, dur = process_duration(
                x, u.transpose(1, 2)
            )
            log_dur_pred = self.var_predictor(uniq_code_feat)
            log_dur_pred = log_dur_pred[uniq_code_mask]
            log_dur = torch.log(dur + 1)
        return super().forward(u), (log_dur_pred, log_dur)

    @torch.no_grad()
    def inference(self, x):
        """The inference function performs duration prediction and runs the
        forward method.

        Arguments
        ---------
        x : torch.Tensor (batch, time)
            feature input tensor.
        """
        x = self.unit_embedding(x).transpose(1, 2)
        if self.duration_predictor:
            assert (
                x.size(0) == 1
            ), "only support single sample batch in inference"
            log_dur_pred = self.var_predictor(x.transpose(1, 2))
            dur_out = torch.clamp(
                torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
            )
            # B x C x T
            x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)
        return super().forward(x)


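# Illustrative inference sketch; the configuration mirrors the class example
# above. With the duration predictor enabled, only single-sample batches are
# supported and the output length depends on the predicted durations.
def _example_unit_generator_inference():
    """A minimal sketch: each unit embedding is repeated by its predicted
    duration, then upsampled by 5 * 4 * 4 * 2 * 2 = 320 samples per frame."""
    generator = UnitHifiganGenerator(
        in_channels=128,
        out_channels=1,
        resblock_type="1",
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_kernel_sizes=[3, 7, 11],
        upsample_kernel_sizes=[11, 8, 8, 4, 4],
        upsample_initial_channel=512,
        upsample_factors=[5, 4, 4, 2, 2],
        num_embeddings=100,
        embedding_dim=128,
        duration_predictor=True,
    )
    units = torch.randint(0, 100, (1, 10))
    wav = generator.inference(units)
    # (1, 1, 320 * total_predicted_duration)
    return wav

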
##################################
# DISCRIMINATOR
##################################


class DiscriminatorP(torch.nn.Module):
    """HiFiGAN Periodic Discriminator

    Takes every Pth value from the input waveform and applies a stack of
    convolutions.

    Note: if period is 2
    waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat

    Arguments
    ---------
    period : int
        period P at which the waveform is folded before convolution.
    kernel_size : int
        size of the convolution filters along the folded time axis.
    stride : int
        stride of the convolutions along the folded time axis.
    """

    def __init__(self, period, kernel_size=5, stride=3):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                Conv2d(
                    in_channels=1,
                    out_channels=32,
                    kernel_size=(kernel_size, 1),
                    stride=(stride, 1),
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                ),
                Conv2d(
                    in_channels=32,
                    out_channels=128,
                    kernel_size=(kernel_size, 1),
                    stride=(stride, 1),
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                ),
                Conv2d(
                    in_channels=128,
                    out_channels=512,
                    kernel_size=(kernel_size, 1),
                    stride=(stride, 1),
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                ),
                Conv2d(
                    in_channels=512,
                    out_channels=1024,
                    kernel_size=(kernel_size, 1),
                    stride=(stride, 1),
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                ),
                Conv2d(
                    in_channels=1024,
                    out_channels=1024,
                    kernel_size=(kernel_size, 1),
                    stride=1,
                    padding="same",
                    skip_transpose=True,
                    weight_norm=True,
                ),
            ]
        )
        self.conv_post = Conv2d(
            in_channels=1024,
            out_channels=1,
            kernel_size=(3, 1),
            stride=1,
            padding="same",
            skip_transpose=True,
            weight_norm=True,
        )

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor (batch, 1, time)
            input waveform.
        """
        feat = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat


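# Illustrative sketch of the 1d-to-2d fold used above (pure tensor ops on a
# toy waveform): with period 2, the columns of the folded view hold every 2nd
# sample, which the (kernel_size, 1) kernels then process independently.
def _example_period_fold():
    """A minimal sketch of DiscriminatorP's reshaping for period = 2."""
    x = torch.arange(1, 7, dtype=torch.float).view(1, 1, 6)  # [1, 2, ..., 6]
    folded = x.view(1, 1, 6 // 2, 2)  # rows: [1, 2], [3, 4], [5, 6]
    # column 0 is [1, 3, 5], column 1 is [2, 4, 6]
    return folded

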
class MultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Period Discriminator (MPD)

    Wrapper for the `DiscriminatorP` to apply it in different periods.
    Periods are suggested to be prime numbers to reduce the overlap between
    each discriminator.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ]
        )

    def forward(self, x):
        """Returns Multi-Period Discriminator scores and features

        Arguments
        ---------
        x : torch.Tensor (batch, 1, time)
            input waveform.
        """
        scores = []
        feats = []
        for d in self.discriminators:
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats


class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator.
    It is similar to `MelganDiscriminator` but with a specific architecture
    explained in the paper. SpeechBrain CNN wrappers are not used here
    because spectral_norm is not often used.

    Arguments
    ---------
    use_spectral_norm : bool
        if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = (
            nn.utils.spectral_norm
            if use_spectral_norm
            else nn.utils.weight_norm
        )
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor (batch, 1, time)
            input waveform.
        """
        feat = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat


class MultiScaleDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Scale Discriminator.
    Similar to MultiScaleMelganDiscriminator but specially tailored for
    HiFiGAN as in the paper.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList(
            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor (batch, 1, time)
            input waveform.
        """
        scores = []
        feats = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                x = self.meanpools[i - 1](x)
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats


class HifiganDiscriminator(nn.Module):
    """HiFiGAN discriminator wrapping MPD and MSD.

    Example
    -------
    >>> inp_tensor = torch.rand([4, 1, 8192])
    >>> hifigan_discriminator = HifiganDiscriminator()
    >>> scores, feats = hifigan_discriminator(inp_tensor)
    >>> len(scores)
    8
    >>> len(feats)
    8
    """

    def __init__(self):
        super().__init__()
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

    def forward(self, x):
        """Returns the scores and the lists of features from each layer of
        each discriminator (MPD followed by MSD).

        Arguments
        ---------
        x : torch.Tensor
            input waveform.
        """
        scores, feats = self.mpd(x)
        scores_, feats_ = self.msd(x)
        return scores + scores_, feats + feats_


#################################
# GENERATOR LOSSES
#################################


def stft(x, n_fft, hop_length, win_length, window_fn="hann_window"):
    """Computes the Fourier transform of short overlapping windows of the
    input. Note that `window_fn` is kept for API compatibility but is not
    currently applied."""
    o = torch.stft(
        x.squeeze(1), n_fft, hop_length, win_length, return_complex=True
    )
    o = torch.view_as_real(o)
    M = o[:, :, :, 0]
    P = o[:, :, :, 1]
    S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
    return S


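# Illustrative usage sketch; the FFT configuration below is an assumption.
def _example_stft():
    """A minimal sketch: the magnitude spectrogram has n_fft // 2 + 1 bins."""
    wav = torch.rand([4, 1, 4096])  # (batch, 1, time) waveforms
    S = stft(wav, n_fft=1024, hop_length=256, win_length=1024)
    # (batch, n_fft // 2 + 1, frames) -> torch.Size([4, 513, 17])
    return S

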
class STFTLoss(nn.Module):
    """STFT loss. Generated and real input waveforms are converted to
    spectrograms and compared with L1 and spectral convergence losses.
    It is from the ParallelWaveGAN paper
    https://arxiv.org/pdf/1910.11480.pdf

    Arguments
    ---------
    n_fft : int
        size of Fourier transform.
    hop_length : int
        the distance between neighboring sliding window frames.
    win_length : int
        the size of window frame and STFT filter.
    """

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def forward(self, y_hat, y):
        """Returns magnitude loss and spectral convergence loss

        Arguments
        ---------
        y_hat : torch.tensor
            generated waveform tensor
        y : torch.tensor
            real waveform tensor
        """
        y_hat_M = stft(y_hat, self.n_fft, self.hop_length, self.win_length)
        y_M = stft(y, self.n_fft, self.hop_length, self.win_length)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        # spectral convergence loss
        loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
        return loss_mag, loss_sc


class MultiScaleSTFTLoss(torch.nn.Module):
    """Multi-scale STFT loss. Generated and real input waveforms are converted
    to spectrograms and compared with L1 and spectral convergence losses.
    It is from the ParallelWaveGAN paper
    https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(
        self,
        n_ffts=(1024, 2048, 512),
        hop_lengths=(120, 240, 50),
        win_lengths=(600, 1200, 240),
    ):
        super().__init__()
        self.loss_funcs = torch.nn.ModuleList()
        for n_fft, hop_length, win_length in zip(
            n_ffts, hop_lengths, win_lengths
        ):
            self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))

    def forward(self, y_hat, y):
        """Returns multi-scale magnitude loss and spectral convergence loss

        Arguments
        ---------
        y_hat : torch.tensor
            generated waveform tensor
        y : torch.tensor
            real waveform tensor
        """
        N = len(self.loss_funcs)
        loss_sc = 0
        loss_mag = 0
        for f in self.loss_funcs:
            lm, lsc = f(y_hat, y)
            loss_mag += lm
            loss_sc += lsc
        loss_sc /= N
        loss_mag /= N
        return loss_mag, loss_sc


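# Illustrative usage sketch; the waveform shapes below are assumptions.
def _example_multi_scale_stft_loss():
    """A minimal sketch: the loss averages the magnitude and spectral
    convergence terms over the three STFT resolutions."""
    loss_fn = MultiScaleSTFTLoss()
    y_hat = torch.rand([2, 4800])  # generated waveforms (batch, time)
    y = torch.rand([2, 4800])  # real waveforms (batch, time)
    loss_mag, loss_sc = loss_fn(y_hat, y)
    return loss_mag, loss_sc  # two scalar tensors

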
class L1SpecLoss(nn.Module):
    """L1 Loss over Spectrograms as described in the HiFiGAN paper
    https://arxiv.org/pdf/2010.05646.pdf

    Note : L1 loss helps learning details compared with L2 loss.

    Arguments
    ---------
    sample_rate : int
        Sample rate of audio signal.
    hop_length : int
        Length of hop between STFT windows.
    win_length : int
        Window size.
    n_mel_channels : int
        Number of mel filterbanks.
    n_fft : int
        Size of FFT.
    n_stft : int
        Number of STFT bins. Kept for API compatibility; it is recomputed
        internally as n_fft // 2 + 1.
    mel_fmin : float
        Minimum frequency.
    mel_fmax : float
        Maximum frequency.
    mel_normalized : bool
        Whether to normalize by magnitude after stft.
    power : float
        Exponent for the magnitude spectrogram.
    norm : str or None
        If "slaney", divide the triangular mel weights by the width of the mel band.
    mel_scale : str
        Scale to use: "htk" or "slaney".
    dynamic_range_compression : bool
        whether to do dynamic range compression.
    """

    def __init__(
        self,
        sample_rate=22050,
        hop_length=256,
        win_length=24,
        n_mel_channels=80,
        n_fft=1024,
        n_stft=1024 // 2 + 1,
        mel_fmin=0.0,
        mel_fmax=8000.0,
        mel_normalized=False,
        power=1.0,
        norm="slaney",
        mel_scale="slaney",
        dynamic_range_compression=True,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.n_fft = n_fft
        self.n_stft = n_fft // 2 + 1
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.mel_normalized = mel_normalized
        self.power = power
        self.norm = norm
        self.mel_scale = mel_scale
        self.dynamic_range_compression = dynamic_range_compression

    def forward(self, y_hat, y):
        """Returns L1 Loss over Spectrograms

        Arguments
        ---------
        y_hat : torch.tensor
            generated waveform tensor
        y : torch.tensor
            real waveform tensor
        """
        y_hat_M = mel_spectogram(
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.n_fft,
            self.n_mel_channels,
            self.mel_fmin,
            self.mel_fmax,
            self.power,
            self.mel_normalized,
            self.norm,
            self.mel_scale,
            self.dynamic_range_compression,
            y_hat,
        )
        y_M = mel_spectogram(
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.n_fft,
            self.n_mel_channels,
            self.mel_fmin,
            self.mel_fmax,
            self.power,
            self.mel_normalized,
            self.norm,
            self.mel_scale,
            self.dynamic_range_compression,
            y,
        )
        # magnitude loss
        loss_mag = F.l1_loss(y_M, y_hat_M)
        return loss_mag


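# Illustrative usage sketch; the sample rate and STFT values are assumptions.
def _example_l1_spec_loss():
    """A minimal sketch: L1 distance between the mel spectrograms of the
    generated and real waveforms."""
    loss_fn = L1SpecLoss(sample_rate=22050, hop_length=256, win_length=1024)
    y_hat = torch.rand([2, 1, 8192])  # generated waveforms
    y = torch.rand([2, 1, 8192])  # real waveforms
    return loss_fn(y_hat, y)  # scalar loss

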
class MSEGLoss(nn.Module):
    """Mean Squared Generator Loss

    The generator is trained to fake the discriminator by updating the sample
    quality to be classified to a value almost equal to 1.
    """

    def forward(self, score_fake):
        """Returns Generator GAN loss

        Arguments
        ---------
        score_fake : list
            discriminator scores of generated waveforms D(G(s))
        """
        loss_fake = F.mse_loss(
            score_fake, score_fake.new_ones(score_fake.shape)
        )
        return loss_fake


class MelganFeatureLoss(nn.Module):
    """Calculates the feature matching loss, which is a learned similarity
    metric measured by the difference in features of the discriminator
    between a ground truth sample and a generated sample (Larsen et al.,
    2016, Kumar et al., 2019).
    """

    def __init__(self):
        super().__init__()
        self.loss_func = nn.L1Loss()

    def forward(self, fake_feats, real_feats):
        """Returns feature matching loss

        Arguments
        ---------
        fake_feats : list
            discriminator features of generated waveforms
        real_feats : list
            discriminator features of groundtruth waveforms
        """
        loss_feats = 0
        num_feats = 0
        for idx, _ in enumerate(fake_feats):
            for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
                loss_feats += self.loss_func(fake_feat, real_feat)
                num_feats += 1
        loss_feats = loss_feats / num_feats
        return loss_feats


##################################
# DISCRIMINATOR LOSSES
##################################


class MSEDLoss(nn.Module):
    """Mean Squared Discriminator Loss

    The discriminator is trained to classify ground truth samples to 1,
    and the samples synthesized from the generator to 0.
    """

    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss()

    def forward(self, score_fake, score_real):
        """Returns Discriminator GAN losses

        Arguments
        ---------
        score_fake : list
            discriminator scores of generated waveforms
        score_real : list
            discriminator scores of groundtruth waveforms
        """
        loss_real = self.loss_func(
            score_real, score_real.new_ones(score_real.shape)
        )
        loss_fake = self.loss_func(
            score_fake, score_fake.new_zeros(score_fake.shape)
        )
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake


#####################################
# LOSS WRAPPERS
#####################################


def _apply_G_adv_loss(scores_fake, loss_func):
    """Compute Generator adversarial loss function and normalize values

    Arguments
    ---------
    scores_fake : list
        discriminator scores of generated waveforms
    loss_func : object
        object of target generator loss
    """
    adv_loss = 0
    if isinstance(scores_fake, list):
        for score_fake in scores_fake:
            fake_loss = loss_func(score_fake)
            adv_loss += fake_loss
        # adv_loss /= len(scores_fake)
    else:
        fake_loss = loss_func(scores_fake)
        adv_loss = fake_loss
    return adv_loss


def _apply_D_loss(scores_fake, scores_real, loss_func):
    """Compute Discriminator losses and normalize loss values

    Arguments
    ---------
    scores_fake : list
        discriminator scores of generated waveforms
    scores_real : list
        discriminator scores of groundtruth waveforms
    loss_func : object
        object of target discriminator loss
    """
    loss = 0
    real_loss = 0
    fake_loss = 0
    if isinstance(scores_fake, list):
        # multi-scale loss
        for score_fake, score_real in zip(scores_fake, scores_real):
            # accumulate per-scale values under distinct names so the
            # running sums are not overwritten on each iteration
            total_loss, real_loss_, fake_loss_ = loss_func(
                score_fake=score_fake, score_real=score_real
            )
            loss += total_loss
            real_loss += real_loss_
            fake_loss += fake_loss_
        # normalize loss values with number of scales (discriminators)
        # loss /= len(scores_fake)
        # real_loss /= len(scores_real)
        # fake_loss /= len(scores_fake)
    else:
        # single scale loss
        total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
        loss = total_loss
    return loss, real_loss, fake_loss


##################################
# MODEL LOSSES
##################################


class GeneratorLoss(nn.Module):
    """Creates a summary of generator losses and applies weights for
    different losses

    Arguments
    ---------
    stft_loss : object
        object of stft loss
    stft_loss_weight : float
        weight of STFT loss
    mseg_loss : object
        object of mseg loss
    mseg_loss_weight : float
        weight of mseg loss
    feat_match_loss : object
        object of feature match loss
    feat_match_loss_weight : float
        weight of feature match loss
    l1_spec_loss : object
        object of L1 spectrogram loss
    l1_spec_loss_weight : float
        weight of L1 spectrogram loss
    mseg_dur_loss : object
        object of duration loss (used by unit-based vocoders)
    mseg_dur_loss_weight : float
        weight of duration loss
    """

    def __init__(
        self,
        stft_loss=None,
        stft_loss_weight=0,
        mseg_loss=None,
        mseg_loss_weight=0,
        feat_match_loss=None,
        feat_match_loss_weight=0,
        l1_spec_loss=None,
        l1_spec_loss_weight=0,
        mseg_dur_loss=None,
        mseg_dur_loss_weight=0,
    ):
        super().__init__()
        self.stft_loss = stft_loss
        self.stft_loss_weight = stft_loss_weight
        self.mseg_loss = mseg_loss
        self.mseg_loss_weight = mseg_loss_weight
        self.feat_match_loss = feat_match_loss
        self.feat_match_loss_weight = feat_match_loss_weight
        self.l1_spec_loss = l1_spec_loss
        self.l1_spec_loss_weight = l1_spec_loss_weight
        self.mseg_dur_loss = mseg_dur_loss
        self.mseg_dur_loss_weight = mseg_dur_loss_weight

    def forward(
        self,
        stage,
        y_hat=None,
        y=None,
        scores_fake=None,
        feats_fake=None,
        feats_real=None,
        log_dur_pred=None,
        log_dur=None,
    ):
        """Returns a dictionary of generator losses and applies weights

        Arguments
        ---------
        stage : speechbrain.Stage
            current stage; the duration loss is only applied during
            sb.Stage.TRAIN.
        y_hat : torch.tensor
            generated waveform tensor
        y : torch.tensor
            real waveform tensor
        scores_fake : list
            discriminator scores of generated waveforms
        feats_fake : list
            discriminator features of generated waveforms
        feats_real : list
            discriminator features of groundtruth waveforms
        log_dur_pred : torch.tensor
            predicted log durations
        log_dur : torch.tensor
            groundtruth log durations
        """
        gen_loss = 0
        adv_loss = 0
        dur_loss = 0
        loss = {}

        # STFT Loss
        if self.stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(
                y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1)
            )
            loss["G_stft_loss_mg"] = stft_loss_mg
            loss["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (
                stft_loss_mg + stft_loss_sc
            )

        # L1 Spec loss
        if self.l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            loss["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # multiscale MSE adversarial loss
        if self.mseg_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mseg_loss)
            loss["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mseg_loss_weight * mse_fake_loss

        # Feature Matching Loss
        if self.feat_match_loss and feats_fake is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            loss["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

        # Duration loss
        if self.mseg_dur_loss and stage == sb.Stage.TRAIN:
            dur_loss = F.mse_loss(log_dur_pred, log_dur, reduction="mean")
            loss["G_dur_loss"] = dur_loss
            dur_loss *= self.mseg_dur_loss_weight

        loss["G_loss"] = gen_loss + adv_loss + dur_loss
        loss["G_gen_loss"] = gen_loss
        loss["G_adv_loss"] = adv_loss
        return loss


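# Illustrative wiring sketch for the generator loss summary; the loss weights
# below are assumptions for demonstration, not recommended training values.
def _example_generator_loss():
    """A minimal sketch: discriminator scores/features feed the summary."""
    loss_fn = GeneratorLoss(
        stft_loss=MultiScaleSTFTLoss(),
        stft_loss_weight=0.5,
        mseg_loss=MSEGLoss(),
        mseg_loss_weight=1.0,
        feat_match_loss=MelganFeatureLoss(),
        feat_match_loss_weight=10.0,
        l1_spec_loss=L1SpecLoss(),
        l1_spec_loss_weight=45.0,
    )
    discriminator = HifiganDiscriminator()
    y = torch.rand([1, 1, 8192])
    y_hat = torch.rand([1, 1, 8192])
    scores_fake, feats_fake = discriminator(y_hat)
    _, feats_real = discriminator(y)
    loss = loss_fn(
        stage=sb.Stage.TRAIN,  # duration loss also needs mseg_dur_loss set
        y_hat=y_hat,
        y=y,
        scores_fake=scores_fake,
        feats_fake=feats_fake,
        feats_real=feats_real,
    )
    return loss["G_loss"]

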
class DiscriminatorLoss(nn.Module):
    """Creates a summary of discriminator losses

    Arguments
    ---------
    msed_loss : object
        object of MSE discriminator loss
    """

    def __init__(self, msed_loss=None):
        super().__init__()
        self.msed_loss = msed_loss

    def forward(self, scores_fake, scores_real):
        """Returns a dictionary of discriminator losses

        Arguments
        ---------
        scores_fake : list
            discriminator scores of generated waveforms
        scores_real : list
            discriminator scores of groundtruth waveforms
        """
        disc_loss = 0
        loss = {}
        if self.msed_loss:
            mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
                scores_fake=scores_fake,
                scores_real=scores_real,
                loss_func=self.msed_loss,
            )
            loss["D_mse_gan_loss"] = mse_D_loss
            loss["D_mse_gan_real_loss"] = mse_D_real_loss
            loss["D_mse_gan_fake_loss"] = mse_D_fake_loss
            disc_loss += mse_D_loss
        loss["D_loss"] = disc_loss
        return loss


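# Illustrative wiring sketch for the discriminator loss; waveform shapes are
# assumptions.
def _example_discriminator_loss():
    """A minimal sketch: real scores are pushed toward 1, fake toward 0."""
    loss_fn = DiscriminatorLoss(msed_loss=MSEDLoss())
    discriminator = HifiganDiscriminator()
    y = torch.rand([1, 1, 8192])
    y_hat = torch.rand([1, 1, 8192])
    scores_fake, _ = discriminator(y_hat.detach())
    scores_real, _ = discriminator(y)
    loss = loss_fn(scores_fake, scores_real)
    return loss["D_loss"]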