speechbrain.lobes.models.MetricGAN_U module

Generator and discriminator used in MetricGAN-U

Authors:
  • Szu-Wei Fu 2020

Summary

Classes:

EnhancementGenerator

Simple LSTM for enhancement with custom initialization.

MetricDiscriminator

Metric estimator for enhancement training.

Functions:

xavier_init_layer

Create a layer with spectral norm, xavier uniform init and zero bias

Reference

speechbrain.lobes.models.MetricGAN_U.xavier_init_layer(in_size, out_size=None, spec_norm=True, layer_type=<class 'torch.nn.modules.linear.Linear'>, **kwargs)

Create a layer of type layer_type with Xavier uniform weight initialization and zero bias, wrapped in spectral normalization when spec_norm is True.
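The helper's behaviour can be approximated with standard PyTorch calls. The sketch below is not SpeechBrain's source: the name `make_xavier_layer` is hypothetical, and it assumes that `out_size=None` means "same as `in_size`" and that extra keyword arguments (such as `kernel_size` for `torch.nn.Conv2d`) are forwarded to `layer_type`. It initializes the weights first and then wraps the layer with `torch.nn.utils.spectral_norm`.

```python
import torch
import torch.nn as nn


def make_xavier_layer(in_size, out_size=None, spec_norm=True,
                      layer_type=nn.Linear, **kwargs):
    """Hypothetical re-creation of xavier_init_layer (not the library code)."""
    # Assumption: a missing out_size means a square layer.
    if out_size is None:
        out_size = in_size
    # Extra kwargs (e.g. kernel_size for Conv2d) go to the layer constructor.
    layer = layer_type(in_size, out_size, **kwargs)
    # Xavier (Glorot) uniform weights and zero bias.
    nn.init.xavier_uniform_(layer.weight, gain=1.0)
    nn.init.zeros_(layer.bias)
    # Optionally rescale the weight by an estimate of its largest singular value.
    if spec_norm:
        layer = nn.utils.spectral_norm(layer)
    return layer


lin = make_xavier_layer(257, 300, spec_norm=False)
conv = make_xavier_layer(1, 15, layer_type=nn.Conv2d, kernel_size=(5, 5))
print(lin(torch.rand(4, 257)).shape)           # torch.Size([4, 300])
print(conv(torch.rand(1, 1, 100, 257)).shape)  # torch.Size([1, 15, 96, 253])
```

Initializing before applying spectral normalization means the Xavier values land directly in the stored `weight_orig` parameter that the spectral-norm hook later rescales.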

class speechbrain.lobes.models.MetricGAN_U.EnhancementGenerator(input_size=257, hidden_size=200, num_layers=2, lin_dim=300, dropout=0)

Bases: torch.nn.modules.module.Module

Simple LSTM for enhancement with custom initialization.

Parameters
  • input_size (int) – Size of the input tensor’s last dimension.

  • hidden_size (int) – Number of neurons to use in the LSTM layers.

  • num_layers (int) – Number of layers to use in the LSTM.

  • lin_dim (int) – Number of neurons in the last two linear layers.

  • dropout (float) – Fraction of neurons to drop during training.

Example

>>> import torch
>>> from speechbrain.lobes.models.MetricGAN_U import EnhancementGenerator
>>> inputs = torch.rand([10, 100, 40])
>>> model = EnhancementGenerator(input_size=40, hidden_size=50)
>>> outputs = model(inputs, lengths=torch.ones([10]))
>>> outputs.shape
torch.Size([10, 100, 40])
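To make the data flow concrete, here is a rough, self-contained approximation of a generator of this kind; it is not the SpeechBrain implementation. The class name `SketchGenerator`, the use of plain `torch.nn.LSTM`, the LeakyReLU between the two linear layers, and the final sigmoid (which keeps outputs in (0, 1)) are assumptions. It treats `lengths` as relative lengths, as `torch.ones([10])` in the example above suggests.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class SketchGenerator(nn.Module):
    """Hypothetical BLSTM + two linear layers; not the library class."""

    def __init__(self, input_size=257, hidden_size=200, num_layers=2,
                 lin_dim=300, dropout=0.0):
        super().__init__()
        self.blstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                             dropout=dropout, bidirectional=True,
                             batch_first=True)
        # A bidirectional LSTM outputs 2 * hidden_size features per frame.
        self.linear1 = nn.Linear(2 * hidden_size, lin_dim)
        self.linear2 = nn.Linear(lin_dim, input_size)
        self.activation = nn.LeakyReLU()  # assumption: activation choice

    def forward(self, x, lengths):
        # Convert relative lengths (fractions of x.shape[1]) to frame counts.
        max_len = x.shape[1]
        abs_lens = torch.round(lengths * max_len).long().clamp(min=1).cpu()
        packed = pack_padded_sequence(x, abs_lens, batch_first=True,
                                      enforce_sorted=False)
        out, _ = self.blstm(packed)
        out, _ = pad_packed_sequence(out, batch_first=True,
                                     total_length=max_len)
        out = self.activation(self.linear1(out))
        # Assumption: a sigmoid keeps every output in (0, 1).
        return torch.sigmoid(self.linear2(out))


model = SketchGenerator(input_size=40, hidden_size=50)
outputs = model(torch.rand([10, 100, 40]), lengths=torch.ones([10]))
print(outputs.shape)  # torch.Size([10, 100, 40])
```

Packing the batch lets the LSTM skip padded frames, and `total_length` restores the original time dimension so the output shape matches the input.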
blstm

Uses orthogonal initialization for the recurrent (hidden-to-hidden) weights and Xavier uniform initialization for the input-to-hidden weights. All biases are initialized to 0.
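The initialization scheme described for `blstm` can be written as a loop over the LSTM's named parameters. This sketch applies it to a plain `torch.nn.LSTM`, whose parameters are named `weight_ih_l*`, `weight_hh_l*`, and `bias_*`; the helper name `init_blstm_` is hypothetical, and SpeechBrain's own LSTM wrapper may name its parameters differently.

```python
import torch
import torch.nn as nn


def init_blstm_(lstm: nn.LSTM) -> None:
    """Hypothetical helper: orthogonal recurrent, Xavier input, zero biases."""
    for name, param in lstm.named_parameters():
        if "bias" in name:
            nn.init.zeros_(param)
        elif "weight_ih" in name:  # input-to-hidden weights
            nn.init.xavier_uniform_(param)
        elif "weight_hh" in name:  # hidden-to-hidden (recurrent) weights
            nn.init.orthogonal_(param)


blstm = nn.LSTM(257, 200, num_layers=2, bidirectional=True, batch_first=True)
init_blstm_(blstm)
w_hh = blstm.weight_hh_l0  # shape (4 * 200, 200): one block per LSTM gate
# Orthogonal init gives orthonormal columns, so W^T W is the identity.
print(torch.allclose(w_hh.T @ w_hh, torch.eye(200), atol=1e-4))  # True
```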

forward(x, lengths)
class speechbrain.lobes.models.MetricGAN_U.MetricDiscriminator(kernel_size=(5, 5), base_channels=15, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, lin_dim1=50, lin_dim2=10)

Bases: torch.nn.modules.module.Module

Metric estimator for enhancement training.

Consists of:
  • four 2d conv layers

  • channel averaging

  • three linear layers

Parameters
  • kernel_size (tuple) – The dimensions of the 2-d kernel used for convolution.

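  • base_channels (int) – Number of channels used in each conv layer.

  • activation (class) – Activation function class used in the network (default: torch.nn.LeakyReLU).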

  • lin_dim1 (int) – Dimensionality of the first linear layer.

  • lin_dim2 (int) – Dimensionality of the second linear layer.

Example

>>> import torch
>>> from speechbrain.lobes.models.MetricGAN_U import MetricDiscriminator
>>> inputs = torch.rand([1, 1, 100, 257])
>>> model = MetricDiscriminator()
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([1, 1])
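A rough approximation of this architecture in plain PyTorch shows how a spectrogram is reduced to a single score. It is a sketch, not the SpeechBrain code: the class name `SketchDiscriminator`, the unpadded convolutions, spectral normalization on every layer, and reading "channel averaging" as the mean of each channel over time and frequency are all assumptions. `activation` is instantiated with its default arguments.

```python
import torch
import torch.nn as nn


class SketchDiscriminator(nn.Module):
    """Hypothetical metric estimator: 4 conv layers, averaging, 3 linear layers."""

    def __init__(self, kernel_size=(5, 5), base_channels=15,
                 activation=nn.LeakyReLU, lin_dim1=50, lin_dim2=10):
        super().__init__()
        self.activation = activation()
        sn = nn.utils.spectral_norm  # assumption: spectral norm on every layer
        self.convs = nn.ModuleList(
            [sn(nn.Conv2d(1, base_channels, kernel_size))]
            + [sn(nn.Conv2d(base_channels, base_channels, kernel_size))
               for _ in range(3)]
        )
        self.linears = nn.ModuleList([
            sn(nn.Linear(base_channels, lin_dim1)),
            sn(nn.Linear(lin_dim1, lin_dim2)),
            sn(nn.Linear(lin_dim2, 1)),
        ])

    def forward(self, x):
        # x: [batch, 1, time, freq]
        for conv in self.convs:
            x = self.activation(conv(x))
        # Average each channel over time and frequency -> [batch, channels].
        x = x.mean(dim=(2, 3))
        for linear in self.linears[:-1]:
            x = self.activation(linear(x))
        # One scalar metric estimate per input -> [batch, 1].
        return self.linears[-1](x)


model = SketchDiscriminator()
print(model(torch.rand([1, 1, 100, 257])).shape)  # torch.Size([1, 1])
```

Because the averaging step collapses the time and frequency axes, the score does not depend on the spectrogram's size, so inputs of any duration map to one value per batch item.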
forward(x)