# Source code for speechbrain.lobes.models.segan_model

"""
This file contains two PyTorch modules which together make up the SEGAN model architecture
(based on the paper: Pascual et al. https://arxiv.org/pdf/1703.09452.pdf).
Changing the initialization parameters modifies the model as described in the class project,
e.g. turning the generator into a VAE or removing the latent variable concatenation.

Loss functions for training SEGAN are also defined in this file.

Authors
 * Francis Carter 2021
"""

from math import floor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data


class Generator(torch.nn.Module):
    """CNN autoencoder that maps noisy speech to clean speech (SEGAN generator).

    The encoder is a stack of stride-2 1-D convolutions. The decoder mirrors it
    with stride-2 transposed convolutions and concatenates the output of the
    mirrored encoder layer (skip connection) after every decoder layer except
    the last one.

    Arguments
    ---------
    kernel_size : int
        Size of the convolutional kernel. The transposed convolutions use
        ``kernel_size + 1`` so that, for odd kernel sizes, each decoder layer
        exactly doubles the length that the matching encoder layer halved.
    latent_vae : bool
        Whether or not to convert the autoencoder to a VAE. When True, the last
        encoder layer outputs twice as many channels, which are split into a
        mean and a log-variance, and the bottleneck is sampled from them.
    z_prob : bool
        Only applicable if ``latent_vae`` is False. When True, Gaussian noise
        ``z`` is concatenated to the bottleneck; when False, zeros are
        concatenated instead, which removes the latent variable while keeping
        the channel count of the decoder input unchanged.
    """

    def __init__(self, kernel_size, latent_vae, z_prob):
        super().__init__()
        self.EncodeLayers = torch.nn.ModuleList()
        self.DecodeLayers = torch.nn.ModuleList()
        # Record the kernel size the layers are actually built with
        # (this used to be hard-coded to 5 regardless of the argument).
        self.kernel_size = kernel_size
        self.latent_vae = latent_vae
        self.z_prob = z_prob

        EncoderChannels = [1, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024]
        num_layers = len(EncoderChannels) - 1
        padding = floor(kernel_size / 2)  # "same"-style padding

        # Encoder: every layer has stride 2.
        for i in range(num_layers):
            outs = EncoderChannels[i + 1]
            if self.latent_vae and i == num_layers - 1:
                # The VAE bottleneck predicts both a mean and a log-variance.
                outs *= 2
            self.EncodeLayers.append(
                nn.Conv1d(
                    in_channels=EncoderChannels[i],
                    out_channels=outs,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=padding,
                )
            )

        # Decoder: mirror image of the encoder.
        for i in range(num_layers):
            ins = EncoderChannels[-(i + 1)]
            if not (self.latent_vae and i == 0):
                # The input is the previous output concatenated with either z
                # (first layer) or an encoder skip (later layers): 2x channels.
                ins *= 2
            self.DecodeLayers.append(
                nn.ConvTranspose1d(
                    in_channels=ins,
                    out_channels=EncoderChannels[-(i + 2)],
                    # adding one to kernel size makes the dimensions match
                    kernel_size=kernel_size + 1,
                    stride=2,
                    padding=padding,
                )
            )

    def forward(self, x):
        """Forward pass through the autoencoder.

        Arguments
        ---------
        x : torch.Tensor
            Input batch. It is permuted with ``(0, 2, 1)`` before the first
            ``Conv1d`` (``in_channels=1``), so presumably it is shaped
            ``(batch, time, 1)``. The time length presumably must be a
            multiple of 2**11 for the skip connections to line up -- confirm
            against callers.

        Returns
        -------
        torch.Tensor or tuple
            The output signal in the same layout as ``x``. When ``latent_vae``
            is True, ``(output, z_mean, z_logvar)`` is returned instead.
        """
        # encode
        skips = []
        x = x.permute(0, 2, 1)
        last_encoder = len(self.EncodeLayers) - 1
        for i, layer in enumerate(self.EncodeLayers):
            x = layer(x)
            # Skips store the pre-activation encoder outputs.
            skips.append(x.clone())
            if i != last_encoder:
                x = F.leaky_relu(x, negative_slope=0.3)

        # fuse z
        if self.latent_vae:
            z_mean, z_logvar = x.chunk(2, dim=1)
            # Sample from N(z_mean, exp(z_logvar)) via reparameterization.
            x = z_mean + torch.exp(z_logvar / 2.0) * torch.randn_like(z_logvar)
        elif self.z_prob:
            z = torch.normal(torch.zeros_like(x), torch.ones_like(x))
            x = torch.cat((x, z), 1)
        else:
            z = torch.zeros_like(x)
            x = torch.cat((x, z), 1)

        # decode
        last_decoder = len(self.DecodeLayers) - 1
        for i, layer in enumerate(self.DecodeLayers):
            x = layer(x)
            if i != last_decoder:
                # Concatenate the mirrored encoder output, then activate the
                # concatenated tensor (skip half included).
                x = torch.cat((x, skips[-(i + 2)]), 1)
                x = F.leaky_relu(x, negative_slope=0.3)

        x = x.permute(0, 2, 1)
        if self.latent_vae:
            return x, z_mean, z_logvar
        return x
class Discriminator(torch.nn.Module):
    """CNN discriminator of SEGAN.

    Eleven stride-2 convolutions, each followed by batch normalization and a
    leaky ReLU, downsample the input. A 1x1 convolution then reduces the
    channels to one, and a linear layer maps the remaining 8 time steps to a
    single value. No sigmoid is applied, so the output is a logit.

    Arguments
    ---------
    kernel_size : int
        Size of the convolutional kernel of the strided layers.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.Layers = torch.nn.ModuleList()
        self.Norms = torch.nn.ModuleList()

        widths = [2, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024, 1]
        transitions = list(zip(widths[:-1], widths[1:]))
        pad = floor(kernel_size / 2)  # "same"-style padding

        # Strided feature extractor; Norms[k] normalizes the output of Layers[k].
        for c_in, c_out in transitions[:-1]:
            self.Layers.append(
                nn.Conv1d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=pad,
                )
            )
            self.Norms.append(nn.BatchNorm1d(num_features=c_out))

        # Output head: 1x1 convolution down to one channel ...
        head_in, head_out = transitions[-1]
        self.Layers.append(
            nn.Conv1d(
                in_channels=head_in,
                out_channels=head_out,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )
        # ... then a linear layer over the (length-8) time axis.
        self.Layers.append(nn.Linear(in_features=8, out_features=1))

    def forward(self, x):
        """Forward pass through the discriminator.

        ``x`` is permuted with ``(0, 2, 1)`` before the first ``Conv1d``
        (``in_channels=2``), so its last dimension must be 2. Its time length
        must reduce to 8 after the strided layers (e.g. 16384 for odd kernel
        sizes). Returns logits, permuted back with ``(0, 2, 1)``.
        """
        h = x.permute(0, 2, 1)
        # zip stops at len(self.Norms): only the strided layers are paired.
        for conv, norm in zip(self.Layers, self.Norms):
            h = F.leaky_relu(norm(conv(h)), negative_slope=0.3)
        head_conv, head_linear = self.Layers[-2], self.Layers[-1]
        h = head_linear(head_conv(h))
        return h.permute(0, 2, 1)  # in logit format
def d1_loss(d_outputs, reduction="mean"):
    """Calculates the loss of the discriminator when the inputs are clean.

    Least-squares GAN target of 1: ``0.5 * (d_outputs - 1) ** 2``.

    Arguments
    ---------
    d_outputs : torch.Tensor
        Discriminator outputs; the first dimension is treated as the batch.
    reduction : str
        ``"mean"`` returns a scalar; ``"batch"`` returns one loss per item.

    Returns
    -------
    torch.Tensor

    Raises
    ------
    ValueError
        If ``reduction`` is neither ``"mean"`` nor ``"batch"`` (previously
        this silently returned ``None``).
    """
    output = 0.5 * ((d_outputs - 1) ** 2)
    if reduction == "mean":
        return output.mean()
    if reduction == "batch":
        # reshape (not view) also works for non-contiguous tensors.
        return output.reshape(output.size(0), -1).mean(1)
    raise ValueError(
        f"reduction must be 'mean' or 'batch', got {reduction!r}"
    )
def d2_loss(d_outputs, reduction="mean"):
    """Calculates the loss of the discriminator when the inputs are not clean.

    Least-squares GAN target of 0: ``0.5 * d_outputs ** 2``.

    Arguments
    ---------
    d_outputs : torch.Tensor
        Discriminator outputs; the first dimension is treated as the batch.
    reduction : str
        ``"mean"`` returns a scalar; ``"batch"`` returns one loss per item.

    Returns
    -------
    torch.Tensor

    Raises
    ------
    ValueError
        If ``reduction`` is neither ``"mean"`` nor ``"batch"`` (previously
        this silently returned ``None``).
    """
    output = 0.5 * ((d_outputs) ** 2)
    if reduction == "mean":
        return output.mean()
    if reduction == "batch":
        # reshape (not view) also works for non-contiguous tensors.
        return output.reshape(output.size(0), -1).mean(1)
    raise ValueError(
        f"reduction must be 'mean' or 'batch', got {reduction!r}"
    )
def g3_loss(
    d_outputs,
    predictions,
    targets,
    length,
    l1LossCoeff,
    klLossCoeff,
    z_mean=None,
    z_logvar=None,
    reduction="mean",
):
    """Calculates the loss of the generator given the discriminator outputs.

    The loss is the least-squares adversarial term ``0.5 * (d_outputs - 1)**2``
    plus ``l1LossCoeff`` times the L1 distance between ``predictions`` and
    ``targets``. When ``z_mean`` is given (VAE training), it also adds
    ``klLossCoeff`` times KL(N(z_mean, exp(z_logvar)) || N(0, 1)), summed over
    all non-batch latent dimensions.

    Arguments
    ---------
    d_outputs : torch.Tensor
        Discriminator outputs for the generated signals (first dim is batch).
    predictions : torch.Tensor
        Generator outputs.
    targets : torch.Tensor
        Reference signals, same shape as ``predictions``.
    length : object
        Not used by this function; kept for interface compatibility.
    l1LossCoeff : float
        Weight of the L1 term.
    klLossCoeff : float
        Weight of the KL term (which is 0 when ``z_mean`` is None).
    z_mean, z_logvar : torch.Tensor, optional
        Latent mean and log-variance; first dim is batch.
    reduction : str
        ``"mean"`` returns a scalar. ``"batch"`` returns one loss per item,
        where each item now gets its own KL value (previously every item got
        the batch-mean KL; the mean over the batch is unchanged).

    Returns
    -------
    torch.Tensor

    Raises
    ------
    ValueError
        If ``reduction`` is neither ``"mean"`` nor ``"batch"`` (previously
        this silently returned ``None``).
    """
    if reduction not in ("mean", "batch"):
        raise ValueError(
            f"reduction must be 'mean' or 'batch', got {reduction!r}"
        )

    discrimloss = 0.5 * ((d_outputs - 1) ** 2)
    l1norm = F.l1_loss(predictions, targets, reduction="none")

    kl_batch = 0
    kl_mean = 0
    if z_mean is not None:  # the model is being trained as a VAE
        zero = torch.zeros_like(z_mean)
        distq = torch.distributions.normal.Normal(
            z_mean, torch.exp(z_logvar) ** (1 / 2)
        )
        distp = torch.distributions.normal.Normal(
            zero, torch.exp(zero) ** (1 / 2)
        )
        kl = torch.distributions.kl.kl_divergence(distq, distp)
        # Sum over every non-batch dimension -> one KL value per item.
        kl_batch = kl.reshape(kl.size(0), -1).sum(1)
        kl_mean = kl_batch.mean()

    if reduction == "mean":
        return (
            discrimloss.mean()
            + l1LossCoeff * l1norm.mean()
            + klLossCoeff * kl_mean
        )

    # reshape (not view) also works for non-contiguous tensors.
    dloss = discrimloss.reshape(discrimloss.size(0), -1).mean(1)
    lloss = l1norm.reshape(l1norm.size(0), -1).mean(1)
    return dloss + l1LossCoeff * lloss + klLossCoeff * kl_batch