speechbrain.lobes.models.MetricGAN module

Generator and discriminator used in MetricGAN

Authors: Szu-Wei Fu, 2020

Summary

Classes:

EnhancementGenerator

Simple LSTM for enhancement with custom initialization.

Learnable_sigmoid

MetricDiscriminator

Metric estimator for enhancement training.

Functions:

shifted_sigmoid

xavier_init_layer

Create a layer with spectral norm, Xavier uniform initialization, and zero bias.

Reference

speechbrain.lobes.models.MetricGAN.xavier_init_layer(in_size, out_size=None, spec_norm=True, layer_type=torch.nn.Linear, **kwargs)

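Create a layer of type layer_type with Xavier uniform weight initialization and zero bias, wrapped in spectral normalization when spec_norm is True (the default). If out_size is None, it defaults to in_size; any extra **kwargs are passed to the layer_type constructor.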
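The behavior described above can be illustrated with a short PyTorch sketch. It is a hypothetical re-implementation (the name xavier_init_layer_sketch is not part of SpeechBrain) that follows the documented signature; details such as applying the initialization before wrapping the layer in spectral normalization are assumptions.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


def xavier_init_layer_sketch(
    in_size, out_size=None, spec_norm=True, layer_type=nn.Linear, **kwargs
):
    """Hypothetical sketch: Xavier uniform weights, zero bias, optional spectral norm."""
    # A missing out_size gives a "square" layer (out_size == in_size).
    if out_size is None:
        out_size = in_size

    # Extra kwargs (e.g. kernel_size for nn.Conv2d) go to the layer constructor.
    layer = layer_type(in_size, out_size, **kwargs)

    # Xavier (Glorot) uniform weights and an all-zero bias.
    nn.init.xavier_uniform_(layer.weight, gain=1.0)
    nn.init.zeros_(layer.bias)

    # Optionally wrap the layer in spectral normalization.
    if spec_norm:
        layer = spectral_norm(layer)
    return layer


# Usage: a spectrally normalized linear layer and a plain 2-D conv layer.
dense = xavier_init_layer_sketch(50, 10)
conv = xavier_init_layer_sketch(
    2, 15, spec_norm=False, layer_type=nn.Conv2d, kernel_size=(5, 5)
)
```

Passing layer-specific keyword arguments straight through is what lets one helper build both linear and convolutional layers.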

speechbrain.lobes.models.MetricGAN.shifted_sigmoid(x)
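The reference gives no formula for shifted_sigmoid. As an illustration only, the scalar sketch below assumes the form 1.2 / (1 + exp(-x / 1.6)), that is, a sigmoid rescaled to the range (0, 1.2) and stretched along x. The constants and the parameter names scale and temperature are assumptions, so check the module source before relying on them.

```python
import math


def _stable_sigmoid(z: float) -> float:
    # Branch on the sign of z so math.exp never overflows.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def shifted_sigmoid_sketch(x: float, scale: float = 1.2, temperature: float = 1.6) -> float:
    """Hypothetical scalar version of scale / (1 + exp(-x / temperature))."""
    # Output lies in (0, scale) and equals scale / 2 at x = 0.
    return scale * _stable_sigmoid(x / temperature)
```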
class speechbrain.lobes.models.MetricGAN.Learnable_sigmoid(in_features=257)

Bases: torch.nn.modules.module.Module

forward(x)
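Learnable_sigmoid is listed without a description. Its in_features=257 default matches the number of frequency bins in a 512-point STFT, which suggests one parameter per frequency bin. The sketch below is a hypothetical module along those lines: a sigmoid with a trainable per-feature slope and an assumed fixed output scale beta = 1.2. The class name, the beta argument, and the exact formula are assumptions, not the library's definition.

```python
import torch
from torch import nn


class LearnableSigmoidSketch(nn.Module):
    """Hypothetical sketch: beta * sigmoid(slope * x) with one slope per feature."""

    def __init__(self, in_features: int = 257, beta: float = 1.2):
        super().__init__()
        # One trainable slope per feature (e.g. per frequency bin), starting at 1.
        self.slope = nn.Parameter(torch.ones(in_features))
        # Assumed fixed output scale: outputs lie in (0, beta).
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # slope broadcasts over the last dimension of x.
        return self.beta * torch.sigmoid(self.slope * x)


act = LearnableSigmoidSketch(in_features=4)
y = act(torch.zeros(2, 3, 4))  # every output is beta * 0.5 here
```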
class speechbrain.lobes.models.MetricGAN.EnhancementGenerator(input_size=257, hidden_size=200, num_layers=2, dropout=0)

Bases: torch.nn.modules.module.Module

Simple LSTM for enhancement with custom initialization.

Parameters
  • input_size (int) – Size of the input tensor’s last dimension.

  • hidden_size (int) – Number of neurons to use in the LSTM layers.

  • num_layers (int) – Number of layers to use in the LSTM.

  • dropout (float) – Fraction of neurons to drop during training.

blstm

Bidirectional LSTM. Recurrent (hidden-to-hidden) weights use orthogonal initialization, input-to-hidden weights use Xavier uniform initialization, and all biases are set to 0.
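The initialization scheme described for blstm can be reproduced on a plain torch.nn.LSTM by matching PyTorch's parameter names (weight_ih_*, weight_hh_*, bias_*). This is a minimal sketch: the helper name init_lstm_sketch is hypothetical, the LSTM is torch.nn.LSTM rather than whatever wrapper EnhancementGenerator uses internally, and bidirectional=True is an assumption based on the attribute name. The sizes mirror the documented defaults.

```python
import torch
from torch import nn


def init_lstm_sketch(lstm: nn.LSTM) -> None:
    """Orthogonal recurrent weights, Xavier uniform input weights, zero biases."""
    for name, param in lstm.named_parameters():
        if "bias" in name:
            nn.init.zeros_(param)  # bias_ih_l*, bias_hh_l* (and *_reverse)
        elif "weight_ih" in name:
            nn.init.xavier_uniform_(param)  # input-to-hidden weights
        elif "weight_hh" in name:
            nn.init.orthogonal_(param)  # hidden-to-hidden (recurrent) weights


# Usage with the documented defaults; bidirectional=True is an assumption.
blstm = nn.LSTM(input_size=257, hidden_size=200, num_layers=2,
                bidirectional=True, batch_first=True)
init_lstm_sketch(blstm)
out, _ = blstm(torch.randn(4, 50, 257))  # (batch, time, freq)
print(out.shape)  # torch.Size([4, 50, 400]): both directions concatenated
```

Orthogonal recurrent weights are a common choice because they preserve the norm of the hidden state across time steps at initialization.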

forward(x, lengths)
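The reference does not say what lengths holds. Assuming it follows the relative-length convention used elsewhere in SpeechBrain (each sequence's true length divided by the padded time dimension), the sketch below shows one way to run a bidirectional LSTM over such a padded batch with PyTorch packing. The helper name run_lstm_relative_lengths is hypothetical, and this is not necessarily how EnhancementGenerator handles lengths internally.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def run_lstm_relative_lengths(lstm: nn.LSTM, x: torch.Tensor,
                              rel_lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run a batch_first LSTM over a padded batch.

    Assumes rel_lengths[i] = (true length of sequence i) / x.size(1).
    """
    # Relative -> absolute frame counts; packing expects int64 lengths on CPU.
    abs_lengths = torch.round(rel_lengths * x.size(1)).long().cpu()
    packed = pack_padded_sequence(x, abs_lengths, batch_first=True,
                                  enforce_sorted=False)
    out, _ = lstm(packed)
    # Re-pad to the input's time length; padded frames come back as zeros.
    out, _ = pad_packed_sequence(out, batch_first=True, total_length=x.size(1))
    return out


lstm = nn.LSTM(input_size=257, hidden_size=200, bidirectional=True, batch_first=True)
x = torch.randn(2, 10, 257)  # two padded utterances, 10 frames each
out = run_lstm_relative_lengths(lstm, x, torch.tensor([1.0, 0.5]))
print(out.shape)  # torch.Size([2, 10, 400])
```

Packing keeps the backward direction of the LSTM from reading padding frames, which would otherwise leak into the outputs of shorter utterances.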
class speechbrain.lobes.models.MetricGAN.MetricDiscriminator(kernel_size=(5, 5), base_channels=15, activation=torch.nn.LeakyReLU)

Bases: torch.nn.modules.module.Module

Metric estimator for enhancement training.

Consists of:
  • four 2-D convolutional layers

  • averaging of each channel over its time-frequency map

  • three linear layers

Parameters
  • kernel_size (tuple) – The dimensions of the 2-D kernel used for convolution.

  • base_channels (int) – Number of channels used in each conv layer.

  • activation (type) – Activation module class used between layers (default: torch.nn.LeakyReLU).
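The layout listed above (four 2-D conv layers, per-channel averaging, three linear layers) can be sketched in PyTorch as follows. This is a hypothetical, simplified module, not MetricDiscriminator itself: the input is assumed to hold two stacked spectrograms (in_channels=2), the linear widths 50 and 10 are assumptions, and any spectral normalization or custom initialization the real class applies is omitted.

```python
import torch
from torch import nn


class TinyMetricDiscriminator(nn.Module):
    """Hypothetical sketch: four 2-D convs, time-frequency averaging, three linears."""

    def __init__(self, in_channels=2, kernel_size=(5, 5), base_channels=15,
                 activation=nn.LeakyReLU):
        super().__init__()
        chans = [in_channels] + [base_channels] * 4
        # Four 2-D conv layers, each producing base_channels channels.
        self.convs = nn.ModuleList(
            [nn.Conv2d(chans[i], chans[i + 1], kernel_size) for i in range(4)]
        )
        # Three linear layers ending in a single score (widths 50 and 10 assumed).
        self.linears = nn.ModuleList(
            [nn.Linear(base_channels, 50), nn.Linear(50, 10), nn.Linear(10, 1)]
        )
        self.act = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_channels, time, freq)
        for conv in self.convs:
            x = self.act(conv(x))
        # Average every channel over its time-frequency map: (batch, base_channels).
        x = x.mean(dim=(2, 3))
        for linear in self.linears[:-1]:
            x = self.act(linear(x))
        # No activation on the last layer: one unbounded score per example.
        return self.linears[-1](x)


disc = TinyMetricDiscriminator()
scores = disc(torch.randn(3, 2, 40, 257))
print(scores.shape)  # torch.Size([3, 1])
```

Because the convolutional output is averaged over time and frequency before the linear layers, this sketch accepts inputs with any number of frames (above the receptive-field minimum) and always returns one score per example.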

forward(x)