speechbrain.lobes.models.EnhanceResnet module

Wide ResNet for Speech Enhancement.

Author
  • Peter Plantinga 2022

Summary

Classes:

  • ConvBlock – Convolution block, including squeeze-and-excitation.

  • EnhanceResnet – Model for enhancement based on Wide ResNet.

  • SEblock – Squeeze-and-excitation block.

Reference

class speechbrain.lobes.models.EnhanceResnet.EnhanceResnet(n_fft=512, win_length=32, hop_length=16, sample_rate=16000, channel_counts=[128, 128, 256, 256, 512, 512], dense_count=2, dense_nodes=1024, activation=GELU(), normalization=<class 'speechbrain.nnet.normalization.BatchNorm2d'>, dropout=0.1, mask_weight=0.99)[source]

Bases: Module

Model for enhancement based on Wide ResNet.

Full model description at: https://arxiv.org/pdf/2112.06068.pdf

Parameters
  • n_fft (int) – Number of points in the Fourier transform, see speechbrain.processing.features.STFT

  • win_length (int) – Length of the STFT window in ms, see speechbrain.processing.features.STFT

  • hop_length (int) – Time between windows in ms, see speechbrain.processing.features.STFT

  • sample_rate (int) – Number of samples per second of input audio.

  • channel_counts (list of ints) – Number of output channels in each CNN block. Determines number of blocks.

  • dense_count (int) – Number of dense layers.

  • dense_nodes (int) – Number of nodes in the dense layers.

  • activation (function) – Function to apply before convolution layers.

  • normalization (class) – Class used to construct the normalization layers.

  • dropout (float) – Portion of layer outputs to drop during training (between 0 and 1).

  • mask_weight (float) – Weight applied to the predicted mask: 0 means no masking, 1 means full masking.

Example

>>> import torch
>>> from speechbrain.lobes.models.EnhanceResnet import EnhanceResnet
>>> inputs = torch.rand([10, 16000])
>>> model = EnhanceResnet()
>>> outputs, feats = model(inputs)
>>> outputs.shape
torch.Size([10, 15872])
>>> feats.shape
torch.Size([10, 63, 257])
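
The constructor arguments can also be scaled down for a lighter model. A minimal sketch, assuming the default STFT settings; the specific channel_counts, dense_count, and dense_nodes values here are illustrative, not recommended settings:

>>> small_model = EnhanceResnet(
...     channel_counts=[64, 64, 128],
...     dense_count=1,
...     dense_nodes=512,
... )
>>> small_outputs, small_feats = small_model(torch.rand([2, 16000]))
>>> small_outputs.shape[0]
2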
forward(x)[source]

Processes the input tensor and outputs the enhanced speech.
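
The mask_weight parameter controls how strongly the predicted mask is applied during this step. A minimal sketch of that blending on a magnitude spectrogram; noisy_mag and mask are hypothetical tensors illustrating the weighting scheme, not the module's internal variables:

>>> import torch
>>> noisy_mag = torch.rand([10, 63, 257])  # hypothetical noisy magnitudes
>>> mask = torch.rand([10, 63, 257])       # hypothetical predicted mask in [0, 1]
>>> mask_weight = 0.99
>>> # mask_weight of 0 leaves the input unchanged, 1 applies the mask fully
>>> enhanced = mask_weight * mask * noisy_mag + (1 - mask_weight) * noisy_mag
>>> enhanced.shape
torch.Size([10, 63, 257])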

extract_feats(x)[source]

Takes the STFT output and produces features for computation.
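
The feats returned in the class example above have shape (batch, frames, n_fft // 2 + 1), i.e. one value per STFT bin. A minimal sketch of one plausible reading of this step, computing log-compressed magnitudes from a hypothetical complex STFT tensor (the actual transform inside extract_feats may differ):

>>> import torch
>>> stft_out = torch.randn([10, 63, 257], dtype=torch.complex64)  # hypothetical STFT
>>> log_mag = torch.log1p(stft_out.abs())  # magnitude, then dynamic-range compression
>>> log_mag.shape
torch.Size([10, 63, 257])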

training: bool
class speechbrain.lobes.models.EnhanceResnet.ConvBlock(input_shape, channels, activation=GELU(), normalization=<class 'speechbrain.nnet.normalization.LayerNorm'>, dropout=0.1)[source]

Bases: Module

Convolution block, including squeeze-and-excitation.

Parameters
  • input_shape (tuple of ints) – The expected size of the inputs.

  • channels (int) – Number of output channels.

  • activation (function) – Function applied before each block.

  • normalization (class) – Class used to construct the normalization layers.

  • dropout (float) – Portion of block outputs to drop during training.

Example

>>> import torch
>>> from speechbrain.lobes.models.EnhanceResnet import ConvBlock
>>> inputs = torch.rand([10, 20, 30, 128])
>>> block = ConvBlock(input_shape=inputs.shape, channels=256)
>>> outputs = block(inputs)
>>> outputs.shape
torch.Size([10, 20, 15, 256])
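
Note that the block halves the frequency dimension (30 to 15) while mapping channels from 128 to 256. Blocks can therefore be chained by passing each output shape as the next input_shape, which is how EnhanceResnet builds its stack from channel_counts; a minimal sketch (the downsampled frequency size depends on the block's internal striding, so only the channel count is checked):

>>> block2 = ConvBlock(input_shape=outputs.shape, channels=512)
>>> block2(outputs).shape[-1]
512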
forward(x)[source]

Processes the input tensor with a convolutional block.

training: bool
class speechbrain.lobes.models.EnhanceResnet.SEblock(input_size)[source]

Bases: Module

Squeeze-and-excitation block.

Defined in: https://arxiv.org/abs/1709.01507

Parameters

input_size (int) – Expected size of the last dimension of the input tensor (number of channels).

Example

>>> import torch
>>> from speechbrain.lobes.models.EnhanceResnet import SEblock
>>> inputs = torch.rand([10, 20, 30, 256])
>>> se_block = SEblock(input_size=inputs.shape[-1])
>>> outputs = se_block(inputs)
>>> outputs.shape
torch.Size([10, 1, 1, 256])
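
The output shape above, with singleton time and frequency dimensions, shows that the block returns per-channel scales rather than a rescaled input. Applying them is an ordinary broadcast multiply:

>>> rescaled = inputs * outputs
>>> rescaled.shape
torch.Size([10, 20, 30, 256])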
forward(x)[source]

Processes the input tensor with a squeeze-and-excitation block.

training: bool