speechbrain.lobes.models.resepformer module

Library for the Resource-Efficient Sepformer.

Authors
  • Cem Subakan 2022

Summary

Classes:

MemLSTM

The Mem-LSTM of SkiM.

ResourceEfficientSeparationPipeline

Resource-Efficient Separation Pipeline, used for both RE-SepFormer and SkiM.

ResourceEfficientSeparator

Resource-Efficient Source Separator; the class that implements RE-SepFormer.

SBRNNBlock

RNNBlock with output layer.

SBTransformerBlock_wnormandskip

A wrapper for the SpeechBrain implementation of the transformer encoder.

SegLSTM

The Segment-LSTM of SkiM.

Reference

class speechbrain.lobes.models.resepformer.MemLSTM(hidden_size, dropout=0.0, bidirectional=False, mem_type='hc', norm_type='cln')[source]

Bases: Module

The Mem-LSTM of SkiM.

Note: This is taken from the SkiM implementation in the ESPNet toolkit and modified for compatibility with SpeechBrain.

Parameters:
  • hidden_size (int) – Dimension of the hidden state.

  • dropout (float) – Dropout ratio. Default is 0.

  • bidirectional (bool) – Whether the LSTM layers are bidirectional. Default is False.

  • mem_type (str) – 'hc', 'h', 'c', or 'id'. This controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. In 'id' mode, both the hidden and cell states will be identically returned.

  • norm_type (str) – One of 'gln' or 'cln'. This selects the type of normalization; 'cln' is for the causal implementation.

Example

>>> x = (torch.randn(1, 5, 64), torch.randn(1, 5, 64))
>>> block = MemLSTM(64)
>>> x = block(x, 5)
>>> x[0].shape
torch.Size([1, 5, 64])
forward(hc, S)[source]

The forward function for the memory RNN

Parameters:
  • hc (tuple) – (h, c), the tuple of hidden and cell states from SegLSTM. The shape of h and c is (d, B*S, H), where d is the number of directions, B is the batch size, S is the number of chunks, and H is the latent dimensionality.

  • S (int) – The number of chunks.

Returns:

ret_val – The processed (h, c) tuple from the memory RNN.

Return type:

tuple
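
Per the parameter description above, 'id' mode returns the hidden and cell states identically, which makes it a convenient no-memory baseline. A minimal sketch checking that the (d, B*S, H) shapes are preserved (B=3, S=5 are illustrative values):

>>> h, c = torch.randn(1, 15, 64), torch.randn(1, 15, 64)
>>> id_block = MemLSTM(64, mem_type='id')
>>> h_out, c_out = id_block((h, c), 5)
>>> h_out.shape
torch.Size([1, 15, 64])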

class speechbrain.lobes.models.resepformer.SegLSTM(input_size, hidden_size, dropout=0.0, bidirectional=False, norm_type='cLN')[source]

Bases: Module

The Segment-LSTM of SkiM.

Note: This is taken from the SkiM implementation in the ESPNet toolkit and modified for compatibility with SpeechBrain.

Parameters:
  • input_size (int) – Dimension of the input feature. The input should have shape (batch, seq_len, input_size).

  • hidden_size (int) – Dimension of the hidden state.

  • dropout (float) – Dropout ratio. Default is 0.

  • bidirectional (bool) – Whether the LSTM layers are bidirectional. Default is False.

  • norm_type (str) – One of 'gln' or 'cln'. This selects the type of normalization; 'cln' is for the causal implementation.

Example

>>> x = torch.randn(3, 20, 64)
>>> hc = None
>>> seglstm = SegLSTM(64, 64)
>>> y = seglstm(x, hc)
>>> y[0].shape
torch.Size([3, 20, 64])
forward(input, hc)[source]

The forward function of the Segment LSTM

Parameters:
  • input (torch.Tensor) – Tensor of shape [B*S, T, H], where B is the batch size, S is the number of chunks, T is the chunk size, and H is the latent dimensionality.

  • hc (tuple) – (h, c), the tuple of hidden and cell states from SegLSTM. The shape of h and c is (d, B*S, H), where d is the number of directions.

Returns:

  • output (torch.Tensor) – Output of Segment LSTM

  • (h, c) (tuple) – The updated hidden and cell states, in the same format as the hc input.
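
In SkiM, SegLSTM and MemLSTM operate in tandem: the states produced on each chunk by SegLSTM are refined by MemLSTM before the next pass. A minimal sketch of this hand-off, assuming a batch of B=2 sequences split into S=4 chunks of T=10 steps:

>>> B, S, T, H = 2, 4, 10, 64
>>> x = torch.randn(B * S, T, H)
>>> seglstm = SegLSTM(H, H)
>>> memlstm = MemLSTM(H)
>>> out, hc = seglstm(x, None)
>>> hc = memlstm(hc, S)
>>> out.shape
torch.Size([8, 10, 64])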

class speechbrain.lobes.models.resepformer.SBRNNBlock(input_size, hidden_channels, num_layers, outsize, rnn_type='LSTM', dropout=0, bidirectional=True)[source]

Bases: Module

RNNBlock with output layer.

Parameters:
  • input_size (int) – Dimensionality of the input features.

  • hidden_channels (int) – Dimensionality of the latent layer of the RNN.

  • num_layers (int) – Number of RNN layers.

  • outsize (int) – Number of dimensions at the output of the linear layer.

  • rnn_type (str) – Type of the RNN cell.

  • dropout (float) – Dropout rate.

  • bidirectional (bool) – If True, bidirectional.

Example

>>> x = torch.randn(10, 100, 64)
>>> rnn = SBRNNBlock(64, 100, 1, 128, bidirectional=True)
>>> x = rnn(x)
>>> x.shape
torch.Size([10, 100, 128])
forward(x)[source]

Returns the transformed output.

Parameters:

x (torch.Tensor) – Tensor of shape [B, L, N], where B is the batch size, L is the number of time points, and N is the number of filters.

Returns:

out – The transformed output.

Return type:

torch.Tensor
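
Since the linear output layer maps the RNN features to outsize, the output dimensionality should not depend on bidirectional. A quick sketch of the unidirectional case:

>>> rnn_uni = SBRNNBlock(64, 100, 1, 128, bidirectional=False)
>>> rnn_uni(torch.randn(10, 100, 64)).shape
torch.Size([10, 100, 128])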

class speechbrain.lobes.models.resepformer.SBTransformerBlock_wnormandskip(num_layers, d_model, nhead, d_ffn=2048, input_shape=None, kdim=None, vdim=None, dropout=0.1, activation='relu', use_positional_encoding=False, norm_before=False, attention_type='regularMHA', causal=False, use_norm=True, use_skip=True, norm_type='gln')[source]

Bases: Module

A wrapper for the SpeechBrain implementation of the transformer encoder.

Parameters:
  • num_layers (int) – Number of layers.

  • d_model (int) – Dimensionality of the representation.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Dimensionality of positional feed forward.

  • input_shape (tuple) – Shape of input.

  • kdim (int) – Dimension of the key (Optional).

  • vdim (int) – Dimension of the value (Optional).

  • dropout (float) – Dropout rate.

  • activation (str) – Activation function.

  • use_positional_encoding (bool) – If True, a positional encoding is used.

  • norm_before (bool) – Use normalization before transformations.

  • attention_type (str) – Type of attention; default is "regularMHA".

  • causal (bool) – Whether to mask future information, default False

  • use_norm (bool) – Whether to include norm in the block.

  • use_skip (bool) – Whether to add skip connections in the block.

  • norm_type (str) – One of 'cln' or 'gln'.

Example

>>> x = torch.randn(10, 100, 64)
>>> block = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> x = block(x)
>>> x.shape
torch.Size([10, 100, 64])
forward(x)[source]

Returns the transformed output.

Parameters:

x (torch.Tensor) – Tensor of shape [B, L, N], where B is the batch size, L is the number of time points, and N is the number of filters.

Returns:

out – The transformed output.

Return type:

torch.Tensor
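
The wrapper preserves the input shape whatever optional features are enabled. A short sketch with positional encoding and pre-normalization switched on (settings chosen purely for illustration):

>>> block = SBTransformerBlock_wnormandskip(
...     2, 64, 8, use_positional_encoding=True, norm_before=True
... )
>>> block(torch.randn(10, 100, 64)).shape
torch.Size([10, 100, 64])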

class speechbrain.lobes.models.resepformer.ResourceEfficientSeparationPipeline(input_size, hidden_size, output_size, dropout=0.0, num_blocks=2, segment_size=20, bidirectional=True, mem_type='av', norm_type='gln', seg_model=None, mem_model=None)[source]

Bases: Module

Resource-Efficient Separation Pipeline, used for both RE-SepFormer and SkiM.

Note: This implementation is a generalization of the ESPNet implementation of SkiM.

Parameters:
  • input_size (int) – Dimension of the input feature. Input shape should be (batch, length, input_size)

  • hidden_size (int) – Dimension of the hidden state.

  • output_size (int) – Dimension of the output size.

  • dropout (float) – Dropout ratio. Default is 0.

  • num_blocks (int) – Number of basic SkiM blocks

  • segment_size (int) – Segmentation size for splitting long features

  • bidirectional (bool) – Whether the RNN layers are bidirectional.

  • mem_type (str) – 'hc', 'h', 'c', 'id', 'av', or None. This controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. In 'id' mode, both the hidden and cell states will be identically returned. When mem_type is None, the MemLSTM is removed.

  • norm_type (str) – One of 'gln' or 'cln'; 'cln' is for the causal implementation.

  • seg_model (class) – The model that processes the within-segment elements.

  • mem_model (class) – The memory model that ensures continuity between the segments.

Example

>>> x = torch.randn(10, 100, 64)
>>> seg_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> mem_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> resepf_pipeline = ResourceEfficientSeparationPipeline(64, 64, 128, seg_model=seg_mdl, mem_model=mem_mdl)
>>> out = resepf_pipeline.forward(x)
>>> out.shape
torch.Size([10, 100, 128])
forward(input)[source]

The forward function of the ResourceEfficientSeparationPipeline

This takes in a tensor of size [B, (S*K), D]

Parameters:

input (torch.Tensor) – Tensor of shape [B, (S*K), D], where B is the batch size, S is the number of chunks, K is the chunk size, and D is the number of features.

Returns:

output – The separated tensor.

Return type:

torch.Tensor
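
The time dimension is processed as S chunks of size K (the constructor's segment_size), and the output keeps the input length. A hedged sketch with segment_size=10, so a length-100 input is handled as S=10 chunks; the models are the same illustrative blocks as in the example above:

>>> x = torch.randn(10, 100, 64)
>>> seg_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> mem_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> pipeline = ResourceEfficientSeparationPipeline(
...     64, 64, 128, segment_size=10, seg_model=seg_mdl, mem_model=mem_mdl
... )
>>> pipeline.forward(x).shape
torch.Size([10, 100, 128])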

class speechbrain.lobes.models.resepformer.ResourceEfficientSeparator(input_dim: int, causal: bool = True, num_spk: int = 2, nonlinear: str = 'relu', layer: int = 3, unit: int = 512, segment_size: int = 20, dropout: float = 0.0, mem_type: str = 'hc', seg_model=None, mem_model=None)[source]

Bases: Module

Resource-Efficient Source Separator. This is the class that implements RE-SepFormer.

Parameters:
  • input_dim (int) – Input feature dimension

  • causal (bool) – Whether the system is causal.

  • num_spk (int) – Number of target speakers.

  • nonlinear (str) – The nonlinear function for mask estimation; one of 'relu', 'tanh', or 'sigmoid'.

  • layer (int) – Number of blocks. Default is 2 for RE-SepFormer.

  • unit (int) – Dimensionality of the hidden state.

  • segment_size (int) – Chunk size for splitting long features.

  • dropout (float) – Dropout ratio. Default is 0.

  • mem_type (str) – 'hc', 'h', 'c', 'id', 'av', or None. This controls whether a memory representation is used to ensure continuity between segments. In 'av' mode, the summary state is calculated by simply averaging over the time dimension of each segment. In 'id' mode, both the hidden and cell states will be identically returned. When mem_type is None, the memory model is removed.

  • seg_model (class) – The model that processes the within-segment elements.

  • mem_model (class) – The memory model that ensures continuity between the segments.

Example

>>> x = torch.randn(10, 64, 100)
>>> seg_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> mem_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> resepformer = ResourceEfficientSeparator(64, num_spk=3, mem_type='av', seg_model=seg_mdl, mem_model=mem_mdl)
>>> out = resepformer.forward(x)
>>> out.shape
torch.Size([3, 10, 64, 100])
forward(inpt: Tensor)[source]

The forward function of the separator.

Parameters:

inpt (torch.Tensor) – Encoded feature [B, T, N]

Returns:

mask_tensor – The estimated masks, shaped [num_spk, B, T, N] (one mask per target speaker).

Return type:

torch.Tensor
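
The returned tensor stacks one mask per speaker along the leading dimension. A hedged sketch of how these masks could be applied multiplicatively to the encoded mixture (the masking step itself lives outside this class, so this is illustrative only):

>>> mixture = torch.randn(10, 64, 100)
>>> seg_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> mem_mdl = SBTransformerBlock_wnormandskip(1, 64, 8)
>>> separator = ResourceEfficientSeparator(64, num_spk=2, mem_type='av', seg_model=seg_mdl, mem_model=mem_mdl)
>>> masks = separator(mixture)
>>> separated = masks * mixture.unsqueeze(0)  # broadcast over the speaker axis
>>> separated.shape
torch.Size([2, 10, 64, 100])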