speechbrain.lobes.models.segan_model module

This file contains two PyTorch modules which together make up the SEGAN model architecture (based on the paper: Pascual et al. https://arxiv.org/pdf/1703.09452.pdf). Changing the initialization parameters modifies the model described in the class project, for example by turning the generator into a VAE or by removing the latent variable concatenation.

Loss functions for training SEGAN are also defined in this file.
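For reference, the SEGAN paper (Pascual et al.) trains with least-squares GAN objectives plus an L1 term on the generator. Here x is the clean signal, x̃ the noisy signal, z the latent variable and λ the weight of the L1 term. Reading d1_loss and d2_loss as the two discriminator terms and g3_loss as the generator objective is an interpretation of their one-line descriptions, not something this page states. Judging by its parameters, g3_loss also accepts a KL term for the VAE variant, which the paper does not use.

```latex
\begin{aligned}
\min_D \; V(D) &= \tfrac{1}{2}\,\mathbb{E}_{\mathbf{x},\tilde{\mathbf{x}} \sim p_{\text{data}}}\big[(D(\mathbf{x}, \tilde{\mathbf{x}}) - 1)^2\big]
  + \tfrac{1}{2}\,\mathbb{E}_{\mathbf{z} \sim p_z,\, \tilde{\mathbf{x}} \sim p_{\text{data}}}\big[D(G(\mathbf{z}, \tilde{\mathbf{x}}), \tilde{\mathbf{x}})^2\big] \\
\min_G \; V(G) &= \tfrac{1}{2}\,\mathbb{E}_{\mathbf{z} \sim p_z,\, \tilde{\mathbf{x}} \sim p_{\text{data}}}\big[(D(G(\mathbf{z}, \tilde{\mathbf{x}}), \tilde{\mathbf{x}}) - 1)^2\big]
  + \lambda \,\lVert G(\mathbf{z}, \tilde{\mathbf{x}}) - \mathbf{x} \rVert_1
\end{aligned}
```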

Authors
  • Francis Carter 2021

Summary

Classes:

Discriminator

CNN discriminator of SEGAN

Generator

CNN Autoencoder model to clean speech signals.

Functions:

d1_loss

Calculates the loss of the discriminator when the inputs are clean

d2_loss

Calculates the loss of the discriminator when the inputs are not clean

g3_loss

Calculates the loss of the generator given the discriminator outputs

Reference

class speechbrain.lobes.models.segan_model.Generator(kernel_size, latent_vae, z_prob)

Bases: torch.nn.modules.module.Module

CNN Autoencoder model to clean speech signals.

Parameters
  • kernel_size (int) – Size of the convolutional kernel.

  • latent_vae (bool) – Whether to convert the autoencoder into a VAE.

  • z_prob (bool) – Whether to remove the latent variable concatenation. Only applicable if latent_vae is False.
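The latent_vae and z_prob options both act on the generator's bottleneck. In the paper, a noise map z with the same shape as the encoder output is concatenated to it along the channel axis. A VAE bottleneck instead samples z with the reparameterization trick and adds a KL penalty, which is what g3_loss's klLossCoeff, z_mean and z_logvar arguments suggest. The plain-Python sketch below shows both operations, with hypothetical helper names and nested lists standing in for tensors. It is not the SpeechBrain implementation, and how the real module produces its mean and log-variance is not documented here.

```python
import math
import random
from typing import List

# Hypothetical [channels][time] bottleneck encoding; nested lists stand in for tensors.
FeatureMap = List[List[float]]


def concat_latent(encoding: FeatureMap, rng: random.Random) -> FeatureMap:
    """Append a standard-normal noise map z of the same shape along the channel axis."""
    z = [[rng.gauss(0.0, 1.0) for _ in channel] for channel in encoding]
    return [list(channel) for channel in encoding] + z


def reparameterize(mean: List[float], logvar: List[float], eps: List[float]) -> List[float]:
    """VAE reparameterization trick: z = mean + exp(logvar / 2) * eps, with eps drawn from N(0, 1)."""
    return [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mean, logvar, eps)]


def kl_to_standard_normal(mean: List[float], logvar: List[float]) -> float:
    """KL(N(mean, exp(logvar)) || N(0, 1)) for a diagonal Gaussian, summed over dimensions."""
    return sum(0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(mean, logvar))


# Usage: a 2-channel, 3-step encoding becomes 4 channels after concatenation.
rng = random.Random(0)
encoding = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
augmented = concat_latent(encoding, rng)
z = reparameterize([0.0, 1.0], [0.0, 0.0], [rng.gauss(0.0, 1.0) for _ in range(2)])
```

Concatenation doubles the channel count that the decoder must accept, which is one reason a flag that removes it also has to change the model's layer sizes.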

forward(x)

Forward pass through the autoencoder.

training: bool

class speechbrain.lobes.models.segan_model.Discriminator(kernel_size)

Bases: torch.nn.modules.module.Module

CNN discriminator of SEGAN

Parameters
  • kernel_size (int) – Size of the convolutional kernel.
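In the paper, SEGAN's discriminator is conditional: it sees a candidate waveform together with the noisy waveform it should correspond to, stacked as two input channels. This page does not state whether Discriminator.forward expects its input already stacked this way. The sketch below, with hypothetical names and plain lists for waveforms, only illustrates how the real and fake pairs are built.

```python
from typing import List, Tuple

Signal = List[float]


def discriminator_pair(candidate: Signal, noisy: Signal) -> List[Signal]:
    """Stack a candidate waveform and its noisy reference as two channels."""
    if len(candidate) != len(noisy):
        raise ValueError("candidate and noisy signals must have the same length")
    return [list(candidate), list(noisy)]


def real_and_fake_pairs(
    clean: Signal, noisy: Signal, enhanced: Signal
) -> Tuple[List[Signal], List[Signal]]:
    """Build the pair D should score as real (clean) and the pair it should score as fake (enhanced)."""
    real = discriminator_pair(clean, noisy)
    fake = discriminator_pair(enhanced, noisy)
    return real, fake


real, fake = real_and_fake_pairs([0.0, 1.0], [0.2, 0.9], [0.1, 0.95])
```

Both pairs share the same noisy channel, so the discriminator must judge the candidate relative to its input rather than in isolation.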

forward(x)

Forward pass through the discriminator.

training: bool

speechbrain.lobes.models.segan_model.d1_loss(d_outputs, reduction='mean')

Calculates the loss of the discriminator when the inputs are clean

speechbrain.lobes.models.segan_model.d2_loss(d_outputs, reduction='mean')

Calculates the loss of the discriminator when the inputs are not clean
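These two descriptions fit the least-squares discriminator terms of the SEGAN paper: scores on clean (real) pairs are pushed toward 1 and scores on generator outputs toward 0. A plain-Python sketch of that objective follows. The function names are hypothetical, and the reduction='none' option is an assumption added for illustration. Nothing here confirms that d1_loss and d2_loss use exactly these formulas or which other reduction values they accept.

```python
from typing import List, Union


def _reduce(values: List[float], reduction: str) -> Union[float, List[float]]:
    """Average the per-output losses, or return them unchanged (assumed 'none' mode)."""
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "none":
        return values
    raise ValueError(f"unsupported reduction: {reduction!r}")


def lsgan_real_loss(d_outputs: List[float], reduction: str = "mean"):
    """Least-squares loss pushing D's scores on clean (real) inputs toward 1."""
    return _reduce([0.5 * (d - 1.0) ** 2 for d in d_outputs], reduction)


def lsgan_fake_loss(d_outputs: List[float], reduction: str = "mean"):
    """Least-squares loss pushing D's scores on enhanced (fake) inputs toward 0."""
    return _reduce([0.5 * d ** 2 for d in d_outputs], reduction)


# Usage: the discriminator's total loss is the sum of the two terms.
scores_real = [1.0, 0.0]
scores_fake = [1.0, 0.0]
d_loss = lsgan_real_loss(scores_real) + lsgan_fake_loss(scores_fake)  # 0.25 + 0.25
```

Unlike a cross-entropy GAN loss, the squared error keeps penalizing scores that are on the correct side of the decision boundary but far from the target, which is the motivation the paper gives for the least-squares formulation.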

speechbrain.lobes.models.segan_model.g3_loss(d_outputs, predictions, targets, length, l1LossCoeff, klLossCoeff, z_mean=None, z_logvar=None, reduction='mean')

Calculates the loss of the generator given the discriminator outputs
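Judging by its parameters, g3_loss appears to combine an adversarial term with an L1 reconstruction term weighted by l1LossCoeff. When z_mean and z_logvar are given (the VAE variant), it also appears to add a KL term weighted by klLossCoeff. The sketch below rests on several assumptions: the least-squares adversarial term from the paper, mean-reduced adversarial and L1 terms, and a KL divergence to a standard normal summed over latent dimensions. It also ignores the length argument, because this page does not describe how padded batches are handled. The name generator_loss is hypothetical.

```python
import math
from typing import List, Optional


def generator_loss(
    d_outputs: List[float],
    predictions: List[float],
    targets: List[float],
    l1_coeff: float,
    kl_coeff: float,
    z_mean: Optional[List[float]] = None,
    z_logvar: Optional[List[float]] = None,
) -> float:
    """Least-squares adversarial loss + weighted L1 + optional weighted KL (mean reduction)."""
    # Adversarial term: the generator wants D to score its outputs as real (1).
    adversarial = sum(0.5 * (d - 1.0) ** 2 for d in d_outputs) / len(d_outputs)
    # Reconstruction term: mean absolute error between enhanced and clean samples.
    l1 = sum(abs(p - t) for p, t in zip(predictions, targets)) / len(predictions)
    # KL term, only for the VAE variant: KL(N(mean, exp(logvar)) || N(0, 1)).
    kl = 0.0
    if z_mean is not None and z_logvar is not None:
        kl = sum(0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(z_mean, z_logvar))
    return adversarial + l1_coeff * l1 + kl_coeff * kl


loss = generator_loss([1.0], [0.5, 0.5], [0.0, 1.0], l1_coeff=100.0, kl_coeff=1.0)  # → 50.0
```

With a large L1 weight, the reconstruction term dominates early in training, and the adversarial term mainly refines the output once the enhanced signal is already close to the clean target.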