speechbrain.lobes.models.transformer.Branchformer module
Branchformer implementation.
Ref: "Branchformer: Parallel MLP-Attention Architectures to Capture Local and Global Context for Speech Recognition and Understanding"
Source: Some parts of the code may be adapted from ESPNet.
Authors: Titouan Parcollet, 2023
Summary
Classes:
- BranchformerEncoder: This class implements the Branchformer encoder.
- BranchformerEncoderLayer: This is an implementation of the Branchformer encoder layer.
- ConvolutionBranch: This is an implementation of the convolution branch in Branchformer.
Reference
- class speechbrain.lobes.models.transformer.Branchformer.ConvolutionBranch(input_size, linear_units=3072, kernel_size=31, activation=<class 'torch.nn.modules.activation.GELU'>, gate_activation=<class 'torch.nn.modules.linear.Identity'>, dropout=0.0, use_linear_after_conv=False)[source]
Bases:
Module
This is an implementation of the convolution branch in Branchformer.
The default structure is: LN -> Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout. A minimal sketch of this structure follows the example below.
- Parameters:
input_size (int) – The expected size of the feature (channel) dimension.
linear_units (int, optional) – Number of neurons in the hidden linear units.
kernel_size (int, optional) – Kernel size of the non-bottleneck convolutional layer.
activation (torch.nn.Module, optional) – Activation function used after the pre-projection.
gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module.
dropout (float, optional) – Dropout rate.
use_linear_after_conv (bool, optional) – If True, will apply a linear transformation of size input_size//2.
Example
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = ConvolutionBranch(512, 1024)
>>> output = net(x)
>>> output.shape
torch.Size([8, 60, 512])
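To make the documented chain concrete, here is a minimal, illustrative PyTorch sketch of a branch with this structure. It is not the SpeechBrain implementation: the CSGU internals assumed here (splitting the projected channels in half, normalizing the gate half, and gating via a depthwise convolution over time) follow the cgMLP design described in the Branchformer paper, and ToyConvolutionBranch is a hypothetical name.

import torch
import torch.nn as nn

class ToyConvolutionBranch(nn.Module):
    """Illustrative sketch only: LN -> Channel Proj -> GeLU
    -> (CNN Spatial Gating) -> Channel Proj -> Dropout."""

    def __init__(self, input_size, linear_units=3072, kernel_size=31, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(input_size)
        self.proj_in = nn.Linear(input_size, linear_units)  # channel projection (up)
        self.act = nn.GELU()
        # CSGU-style spatial gating (assumed internals): split channels in two
        # and gate one half with a depthwise convolution over the time axis.
        self.gate_norm = nn.LayerNorm(linear_units // 2)
        self.depthwise = nn.Conv1d(
            linear_units // 2, linear_units // 2, kernel_size,
            padding=(kernel_size - 1) // 2,  # odd kernel keeps the time length
            groups=linear_units // 2,
        )
        self.proj_out = nn.Linear(linear_units // 2, input_size)  # channel projection (down)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):  # x: (batch, time, input_size)
        x = self.act(self.proj_in(self.norm(x)))
        a, b = x.chunk(2, dim=-1)  # split along the channel dimension
        b = self.depthwise(self.gate_norm(b).transpose(1, 2)).transpose(1, 2)
        return self.drop(self.proj_out(a * b))  # element-wise gating, project back

x = torch.rand(8, 60, 512)
print(ToyConvolutionBranch(512, 1024)(x).shape)  # torch.Size([8, 60, 512])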
- class speechbrain.lobes.models.transformer.Branchformer.BranchformerEncoderLayer(d_model, nhead, kernel_size=31, kdim=None, vdim=None, activation=<class 'torch.nn.modules.activation.GELU'>, dropout=0.0, attention_type='RelPosMHAXL', csgu_linear_units=3072, gate_activation=<class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv=False)[source]
Bases:
Module
This is an implementation of the Branchformer encoder layer.
- Parameters:
d_model (int) – The expected size of the input embedding.
nhead (int) – Number of attention heads.
kernel_size (int, optional) – Kernel size of the convolution module.
kdim (int, optional) – Dimension of the key.
vdim (int, optional) – Dimension of the value.
activation (torch.nn.Module) – Activation function used in each Branchformer layer.
dropout (float, optional) – Dropout for the encoder.
attention_type (str, optional) – Type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU module.
gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module.
use_linear_after_conv (bool, optional) – If True, will apply a linear transformation of size input_size//2.
Example
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_embs = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
>>> output = net(x, pos_embs=pos_embs)
>>> output[0].shape
torch.Size([8, 60, 512])
- forward(x, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None)[source]
- Parameters:
x (torch.Tensor) – The sequence to the encoder layer.
src_mask (torch.Tensor, optional) – The mask for the src sequence.
src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys per batch.
pos_embs (torch.Tensor, torch.nn.Module, optional) – Module or tensor containing the input sequence positional embeddings.
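A hedged usage sketch of this forward: the boolean key padding mask below marks padded frames with True (the usual PyTorch convention, assumed here), and unpacking the result into (output, attention) follows the tuple indexing shown in the class example above.

>>> import torch
>>> from speechbrain.lobes.models.transformer.Branchformer import BranchformerEncoderLayer
>>> x = torch.rand(8, 60, 512)
>>> pos_embs = torch.rand(1, 2*60-1, 512)  # (1, 2*S-1, E) for RelPosMHAXL
>>> src_key_padding_mask = torch.zeros(8, 60, dtype=torch.bool)
>>> src_key_padding_mask[:, 50:] = True  # treat the last 10 frames as padding
>>> layer = BranchformerEncoderLayer(d_model=512, nhead=8, kernel_size=3)
>>> output, attn = layer(x, src_key_padding_mask=src_key_padding_mask, pos_embs=pos_embs)
>>> output.shape
torch.Size([8, 60, 512])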
- class speechbrain.lobes.models.transformer.Branchformer.BranchformerEncoder(num_layers, d_model, nhead, kernel_size=31, kdim=None, vdim=None, activation=<class 'torch.nn.modules.activation.GELU'>, dropout=0.0, attention_type='RelPosMHAXL', csgu_linear_units=3072, gate_activation=<class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv=False, output_hidden_states=False, layerdrop_prob=0.0)[source]
Bases:
Module
This class implements the Branchformer encoder.
- Parameters:
num_layers (int) – Number of layers.
d_model (int) – Embedding dimension size.
nhead (int) – Number of attention heads.
kernel_size (int, optional) – Kernel size of the convolution module.
kdim (int, optional) – Dimension of the key.
vdim (int, optional) – Dimension of the value.
activation (torch.nn.Module) – Activation function used in each Branchformer layer.
dropout (float, optional) – Dropout for the encoder.
attention_type (str, optional) – Type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU module.
gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module.
use_linear_after_conv (bool, optional) – If True, will apply a linear transformation of size input_size//2.
output_hidden_states (bool, optional) – Whether the model should output the hidden states as a list of tensors.
layerdrop_prob (float) – The probability to drop an entire layer.
Example
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_emb = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoder(1, 512, 8)
>>> output, _ = net(x, pos_embs=pos_emb)
>>> output.shape
torch.Size([8, 60, 512])
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_emb = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoder(1, 512, 8, output_hidden_states=True)
>>> output, attn_list, hidden_list = net(x, pos_embs=pos_emb)
>>> hidden_list[0].shape
torch.Size([8, 60, 512])
>>> len(hidden_list)
2
- forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None, dynchunktrain_config=None)[source]
- Parameters:
src (torch.Tensor) – The sequence to the encoder layer.
src_mask (torch.Tensor, optional) – The mask for the src sequence.
src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys per batch.
pos_embs (torch.Tensor, torch.nn.Module, optional) – Module or tensor containing the input sequence positional embeddings. If custom pos_embs are given, they need to have the shape (1, 2*S-1, E), where S is the sequence length and E is the embedding dimension.
dynchunktrain_config (None) – This configuration is unsupported for this encoder.
- Returns:
output (torch.Tensor) – The output of the Branchformer encoder.
attention_lst (list) – The attention values.
hidden_state_lst (list, optional) – The outputs of the hidden layers of the encoder. Only returned if output_hidden_states is set to True.
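As a hedged end-to-end sketch, a variable-length batch can be handled by deriving src_key_padding_mask from per-utterance lengths. The True-means-padded convention and the expectation of one attention tensor per layer in attention_lst are assumptions consistent with the docstrings above, not guarantees from the source.

>>> import torch
>>> from speechbrain.lobes.models.transformer.Branchformer import BranchformerEncoder
>>> x = torch.rand(4, 100, 256)
>>> pos_embs = torch.rand(1, 2*100-1, 256)  # custom pos_embs: (1, 2*S-1, E)
>>> lengths = torch.tensor([100, 80, 60, 40])
>>> src_key_padding_mask = torch.arange(100)[None, :] >= lengths[:, None]  # True = padded
>>> net = BranchformerEncoder(num_layers=4, d_model=256, nhead=4)
>>> output, attn_lst = net(x, src_key_padding_mask=src_key_padding_mask, pos_embs=pos_embs)
>>> output.shape
torch.Size([4, 100, 256])
>>> len(attn_lst)  # assumed: one attention tensor per layer
4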