speechbrain.lobes.models.transformer.Branchformer module

Branchformer implementation.

Ref: “Branchformer: Parallel MLP-Attention Architectures to Capture Local and Global Context for Speech Recognition and Understanding”

Source: Some parts of the code may be adapted from ESPNet.

Authors
  • Titouan Parcollet, 2023

Summary

Classes:

BranchformerEncoder

This class implements the Branchformer encoder.

BranchformerEncoderLayer

This is an implementation of the Branchformer encoder layer.

ConvolutionBranch

This is an implementation of the convolution branch in Branchformer.

Reference

class speechbrain.lobes.models.transformer.Branchformer.ConvolutionBranch(input_size, linear_units=3072, kernel_size=31, activation=<class 'torch.nn.modules.activation.GELU'>, gate_activation=<class 'torch.nn.modules.linear.Identity'>, dropout=0.0, use_linear_after_conv=False)[source]

Bases: Module

This is an implementation of the convolution branch in Branchformer.

The default structure is: LN -> Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout

Parameters:
  • input_size (int) – The expected size of the feature (channel) dimension.

  • linear_units (int, optional) – Number of neurons in the hidden linear units.

  • kernel_size (int, optional) – Kernel size of non-bottleneck convolutional layer.

  • activation (torch.nn.Module, optional) – Activation function applied after the pre-projection.

  • gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module.

  • dropout (float, optional) – Dropout rate.

  • use_linear_after_conv (bool, optional) – If True, applies a linear transformation of size input_size//2 after the convolution.

Example

>>> x = torch.rand((8, 60, 512))
>>> net = ConvolutionBranch(512, 1024)
>>> output = net(x)
>>> output.shape
torch.Size([8, 60, 512])
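
A further, hedged sketch (not part of the upstream docstring) of a non-default configuration; the choice of torch.nn.Tanh for gate_activation and the use of use_linear_after_conv=True are purely illustrative, assuming both arguments accept the same kind of values as their documented defaults:

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> # Illustrative setup: Tanh gating and an extra linear layer after the convolution
>>> net = ConvolutionBranch(
...     input_size=512,
...     linear_units=1024,
...     kernel_size=31,
...     gate_activation=torch.nn.Tanh,
...     use_linear_after_conv=True,
... )
>>> output = net(x)
>>> output.shape
torch.Size([8, 60, 512])

The output keeps the (B, T, input_size) shape of the input, since the final channel projection maps the gated features back to input_size.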
forward(x)[source]
Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, T, D).

training: bool
class speechbrain.lobes.models.transformer.Branchformer.BranchformerEncoderLayer(d_model, nhead, kernel_size=31, kdim=None, vdim=None, activation=<class 'torch.nn.modules.activation.GELU'>, dropout=0.0, attention_type='RelPosMHAXL', csgu_linear_units=3072, gate_activation=<class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv=False)[source]

Bases: Module

This is an implementation of the Branchformer encoder layer.

Parameters:
  • d_model (int) – The expected size of the input embedding.

  • nhead (int) – Number of attention heads.

  • kernel_size (int, optional) – Kernel size of the convolution module.

  • kdim (int, optional) – Dimension of the key.

  • vdim (int, optional) – Dimension of the value.

  • activation (torch.nn.Module) – Activation function used in each Branchformer layer.

  • dropout (float, optional) – Dropout rate for the encoder.

  • attention_type (str, optional) – Type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

  • csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU Module.

  • gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module.

  • use_linear_after_conv (bool, optional) – If True, applies a linear transformation of size input_size//2 after the convolution.

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_embs = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
>>> output = net(x, pos_embs=pos_embs)
>>> output[0].shape
torch.Size([8, 60, 512])
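
A hedged sketch (not from the upstream docstring) of a forward pass with a key padding mask; the lengths tensor is illustrative, and it assumes src_key_padding_mask follows the usual PyTorch convention where True marks padded time steps to be ignored:

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_embs = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
>>> lengths = torch.tensor([60, 52, 47, 60, 31, 60, 58, 25])
>>> # Boolean mask, True at padded (ignored) positions (assumed convention)
>>> pad_mask = torch.arange(60)[None, :] >= lengths[:, None]
>>> output = net(x, src_key_padding_mask=pad_mask, pos_embs=pos_embs)
>>> output[0].shape
torch.Size([8, 60, 512])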
forward(x, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None)[source]
Parameters:
  • x (torch.Tensor) – The sequence to the encoder layer.

  • src_mask (torch.Tensor, optional) – The mask for the src sequence.

  • src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys per batch.

  • pos_embs (torch.Tensor, torch.nn.Module, optional) – Module or tensor containing the input sequence positional embeddings.

training: bool
class speechbrain.lobes.models.transformer.Branchformer.BranchformerEncoder(num_layers, d_model, nhead, kernel_size=31, kdim=None, vdim=None, activation=<class 'torch.nn.modules.activation.GELU'>, dropout=0.0, attention_type='RelPosMHAXL', csgu_linear_units=3072, gate_activation=<class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv=False)[source]

Bases: Module

This class implements the Branchformer encoder.

Parameters:
  • num_layers (int) – Number of layers.

  • d_model (int) – Embedding dimension size.

  • nhead (int) – Number of attention heads.

  • kernel_size (int, optional) – Kernel size of the convolution module.

  • kdim (int, optional) – Dimension of the key.

  • vdim (int, optional) – Dimension of the value.

  • activation (torch.nn.Module) – Activation function used in each Branchformer layer.

  • dropout (float, optional) – Dropout rate for the encoder.

  • attention_type (str, optional) – Type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

  • csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU Module.

  • gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module.

  • use_linear_after_conv (bool, optional) – If True, applies a linear transformation of size input_size//2 after the convolution.

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_emb = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoder(1, 512, 8)
>>> output, _ = net(x, pos_embs=pos_emb)
>>> output.shape
torch.Size([8, 60, 512])
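
A hedged sketch (not from the upstream docstring) of a deeper encoder combined with a key padding mask; the hyperparameters and the lengths tensor are illustrative, and the mask convention (True marks padded time steps) is assumed:

>>> import torch
>>> x = torch.rand((8, 60, 256))
>>> pos_emb = torch.rand((1, 2*60-1, 256))
>>> net = BranchformerEncoder(num_layers=4, d_model=256, nhead=4, csgu_linear_units=1024)
>>> lengths = torch.tensor([60, 52, 47, 60, 31, 60, 58, 25])
>>> # Boolean mask, True at padded (ignored) positions (assumed convention)
>>> pad_mask = torch.arange(60)[None, :] >= lengths[:, None]
>>> output, _ = net(x, src_key_padding_mask=pad_mask, pos_embs=pos_emb)
>>> output.shape
torch.Size([8, 60, 256])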
forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None)[source]
Parameters:
  • src (torch.Tensor) – The sequence to the encoder layer.

  • src_mask (torch.Tensor, optional) – The mask for the src sequence.

  • src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys per batch.

  • pos_embs (torch.Tensor, torch.nn.Module, optional) – Module or tensor containing the input sequence positional embeddings. If custom pos_embs are given, they need to have the shape (1, 2*S-1, E), where S is the sequence length and E is the embedding dimension.

training: bool