speechbrain.lobes.models.transformer.Transformer module

Transformer implementation in the SpeechBrain style.

Authors:
  • Jianyuan Zhong 2020
  • Samuele Cornell 2021
  • Shucong Zhang 2024

Summary

Classes:

NormalizedEmbedding

This class implements the normalized embedding layer for the transformer.

PositionalEncoding

This class implements the absolute sinusoidal positional encoding function.

TransformerDecoder

This class implements the Transformer decoder.

TransformerDecoderLayer

This class implements the self-attention decoder layer.

TransformerEncoder

This class implements the transformer encoder.

TransformerEncoderLayer

This is an implementation of a self-attention encoder layer.

TransformerInterface

This is an interface for transformer models.

Functions:

get_key_padding_mask

Creates a binary mask to prevent attention to padded locations.

get_lookahead_mask

Creates a binary mask for each sequence which masks future frames.

get_mask_from_lengths

Creates a binary mask from sequence lengths.

Reference

class speechbrain.lobes.models.transformer.Transformer.TransformerInterface(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation: type = <class 'torch.nn.modules.activation.ReLU'>, custom_src_module=None, custom_tgt_module=None, positional_encoding='fixed_abs_sine', normalize_before=True, kernel_size: int = 31, bias: bool = True, encoder_module: str = 'transformer', conformer_activation: type = <class 'speechbrain.nnet.activations.Swish'>, branchformer_activation: type = <class 'torch.nn.modules.activation.GELU'>, attention_type: str = 'regularMHA', max_length: int = 2500, causal: bool = False, encoder_kdim: int | None = None, encoder_vdim: int | None = None, decoder_kdim: int | None = None, decoder_vdim: int | None = None, csgu_linear_units: int = 3072, gate_activation: type = <class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv: bool = False, output_hidden_states=False, layerdrop_prob=0.0)[source]

Bases: Module

This is an interface for transformer models. Users can modify the attributes and define the forward function as needed according to their own tasks. The architecture is based on the paper “Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf

Parameters:
  • d_model (int) – The number of expected features in the encoder/decoder inputs (default=512).

  • nhead (int) – The number of heads in the multi-head attention models (default=8).

  • num_encoder_layers (int, optional) – The number of encoder layers in the encoder.

  • num_decoder_layers (int, optional) – The number of decoder layers in the decoder.

  • d_ffn (int, optional) – The dimension of the feedforward network model hidden layer.

  • dropout (float, optional) – The dropout value.

  • activation (torch.nn.Module, optional) – The activation function for Feed-Forward Network layer, e.g., relu or gelu or swish.

  • custom_src_module (torch.nn.Module, optional) – Module that processes the src features to expected feature dim.

  • custom_tgt_module (torch.nn.Module, optional) – Module that processes the tgt features to expected feature dim.

  • positional_encoding (str, optional) – Type of positional encoding used, e.g. ‘fixed_abs_sine’ for fixed absolute positional encodings.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to True as this was shown to lead to better performance and training stability.

  • kernel_size (int, optional) – Kernel size in convolutional layers when Conformer is used.

  • bias (bool, optional) – Whether to use bias in Conformer convolutional layers.

  • encoder_module (str, optional) – Choose between Branchformer, Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.

  • conformer_activation (torch.nn.Module, optional) – Activation module used after Conformer convolutional layers, e.g. Swish or ReLU. It has to be a torch Module.

  • branchformer_activation (torch.nn.Module, optional) – Activation module used within the Branchformer encoder, e.g. Swish or ReLU. It has to be a torch Module.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • max_length (int, optional) – Max length for the target and source sequence in input. Used for positional encodings.

  • causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal the Conformer convolutional layer is causal.

  • encoder_kdim (int, optional) – Dimension of the key for the encoder.

  • encoder_vdim (int, optional) – Dimension of the value for the encoder.

  • decoder_kdim (int, optional) – Dimension of the key for the decoder.

  • decoder_vdim (int, optional) – Dimension of the value for the decoder.

  • csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU module (Branchformer only).

  • gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module (Branchformer only).

  • use_linear_after_conv (bool, optional) – If True, will apply a linear transformation of size input_size//2 (Branchformer only).

  • output_hidden_states (bool, optional) – Whether the model should output the hidden states as a list of tensors.

  • layerdrop_prob (float) – The probability to drop an entire layer.

forward(**kwags)[source]

Users should modify this function according to their own tasks.
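
The normalize_before parameter above selects between pre-norm and post-norm residual blocks. The following is a minimal, dependency-free sketch of the two orderings rather than the SpeechBrain implementation: layer_norm is a toy normalization without learned scale and bias, and residual_block and double are hypothetical names introduced only for illustration.

```python
import math


def layer_norm(x, eps=1e-5):
    """Toy layer norm over a list of floats (no learned scale or bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def residual_block(x, sublayer, normalize_before):
    """Apply one residual sub-block in pre-norm or post-norm order."""
    if normalize_before:
        # Pre-norm: normalize the sublayer input, then add the residual.
        y = sublayer(layer_norm(x))
        return [a + b for a, b in zip(x, y)]
    # Post-norm: add the residual first, then normalize the sum.
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])


def double(v):
    """Hypothetical stand-in for an attention or feed-forward sublayer."""
    return [2.0 * t for t in v]


x = [1.0, 2.0, 3.0, 4.0]
pre = residual_block(x, double, normalize_before=True)
post = residual_block(x, double, normalize_before=False)
```

In the pre-norm case the residual path carries x through unnormalized, so the mean of pre equals the mean of x, while the post-norm output is itself normalized and has zero mean.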

class speechbrain.lobes.models.transformer.Transformer.PositionalEncoding(input_size, max_len=2500)[source]

Bases: Module

This class implements the absolute sinusoidal positional encoding function.

PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

Parameters:
  • input_size (int) – Embedding dimension.

  • max_len (int, optional) – Max length of the input sequences (default 2500).

Example

>>> import torch
>>> a = torch.rand((8, 120, 512))
>>> enc = PositionalEncoding(input_size=a.shape[-1])
>>> b = enc(a)
>>> b.shape
torch.Size([1, 120, 512])
forward(x)[source]
Parameters:

x (torch.Tensor) – Input features of shape (batch, time, fea).

Returns:

The positional encoding, of shape (1, time, fea).
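
The sinusoidal formula can be checked by hand with a small table. This is a minimal pure-Python sketch, not the module's tensor code; sinusoidal_positional_encoding is a hypothetical helper name, and the loop index i plays the role of the even index 2i in the formula.

```python
import math


def sinusoidal_positional_encoding(max_len, input_size):
    """Return a max_len x input_size table of sinusoidal encodings."""
    table = [[0.0] * input_size for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, input_size, 2):
            # i is the even feature index (2i in the formula).
            angle = pos / (10000 ** (i / input_size))
            table[pos][i] = math.sin(angle)
            if i + 1 < input_size:
                table[pos][i + 1] = math.cos(angle)
    return table


pe = sinusoidal_positional_encoding(max_len=4, input_size=6)
# Position 0 alternates sin(0) = 0 and cos(0) = 1.
print(pe[0])
```

The table depends only on the position and the feature index, not on the batch content, which is consistent with the leading dimension of 1 in the example output shown earlier.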

class speechbrain.lobes.models.transformer.Transformer.TransformerEncoderLayer(d_ffn, nhead, d_model, kdim=None, vdim=None, dropout=0.0, activation: type = <class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, attention_type='regularMHA', ffn_type='regularFFN', ffn_cnn_kernel_size_list=[3, 3], causal=False)[source]

Bases: Module

This is an implementation of a self-attention encoder layer.

Parameters:
  • d_ffn (int) – The dimension of the feedforward network model hidden layer.

  • nhead (int) – The number of heads in the multi-head attention models.

  • d_model (int) – The number of expected features in the encoder/decoder inputs.

  • kdim (int, optional) – Dimension of the key.

  • vdim (int, optional) – Dimension of the value.

  • dropout (float, optional) – The dropout value.

  • activation (torch.nn.Module, optional) – The activation function for Feed-Forward Network layer, e.g., relu or gelu or swish.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to False for this class, although setting it to True was shown to lead to better performance and training stability.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • ffn_type (str) – Type of feed-forward network: regularFFN or 1dcnn.

  • ffn_cnn_kernel_size_list (list of int) – Kernel sizes of the two 1D convolutions used when ffn_type is 1dcnn.

  • causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal the Conformer convolutional layer is causal.

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoderLayer(512, 8, d_model=512)
>>> output = net(x)
>>> output[0].shape
torch.Size([8, 60, 512])
forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None)[source]
Parameters:
  • src (torch.Tensor) – The sequence to the encoder layer.

  • src_mask (torch.Tensor) – The mask for the src query for each example in the batch.

  • src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys for each example in the batch.

  • pos_embs (torch.Tensor, optional) – The positional embeddings tensor.

Returns:

output – The output of the transformer encoder layer.

Return type:

torch.Tensor

class speechbrain.lobes.models.transformer.Transformer.TransformerEncoder(num_layers, nhead, d_ffn, input_shape=None, d_model=None, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, causal=False, layerdrop_prob=0.0, attention_type='regularMHA', ffn_type='regularFFN', ffn_cnn_kernel_size_list=[3, 3], output_hidden_states=False)[source]

Bases: Module

This class implements the transformer encoder.

Parameters:
  • num_layers (int) – Number of transformer layers to include.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Hidden size of self-attention Feed Forward layer.

  • input_shape (tuple) – Expected shape of the input.

  • d_model (int) – The dimension of the input embedding.

  • kdim (int) – Dimension for key (Optional).

  • vdim (int) – Dimension for value (Optional).

  • dropout (float) – Dropout for the encoder (Optional).

  • activation (torch.nn.Module, optional) – The activation function for Feed-Forward Network layer, e.g., relu or gelu or swish.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to False for this class, although setting it to True was shown to lead to better performance and training stability.

  • causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal the Conformer convolutional layer is causal.

  • layerdrop_prob (float) – The probability to drop an entire layer.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • ffn_type (str) – Type of feed-forward network: regularFFN or 1dcnn.

  • ffn_cnn_kernel_size_list (list of int) – Kernel sizes of the two 1D convolutions used when ffn_type is 1dcnn.

  • output_hidden_states (bool, optional) – Whether the model should output the hidden states as a list of tensors.

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder(1, 8, 512, d_model=512)
>>> output, _ = net(x)
>>> output.shape
torch.Size([8, 60, 512])
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder(
...     1, 8, 512, d_model=512, output_hidden_states=True
... )
>>> output, attn_list, hidden_list = net(x)
>>> hidden_list[0].shape
torch.Size([8, 60, 512])
>>> len(hidden_list)
2
forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None, dynchunktrain_config=None)[source]
Parameters:
  • src (torch.Tensor) – The sequence to the encoder layer (required).

  • src_mask (torch.Tensor) – The mask for the src sequence (optional).

  • src_key_padding_mask (torch.Tensor) – The mask for the src keys per batch (optional).

  • pos_embs (torch.Tensor) – The positional embedding tensor

  • dynchunktrain_config (config) – Not supported for this encoder.

Returns:

  • output (torch.Tensor) – The output of the transformer.

  • attention_lst (list) – The attention values.

  • hidden_state_lst (list, optional) – The output of the hidden layers of the encoder. Only works if output_hidden_states is set to true.
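
The interaction between layerdrop_prob and the hidden-state list can be pictured with a small stand-alone sketch. It is an illustration of the general LayerDrop idea under stated assumptions, not the SpeechBrain code: run_encoder is a hypothetical function, layers are plain callables, layers are only dropped in training mode, and the hidden-state list is assumed to hold the input followed by one entry per applied layer.

```python
import random


def run_encoder(x, layers, layerdrop_prob=0.0, training=True, rng=None):
    """Apply layers in order, skipping each with probability layerdrop_prob.

    Returns the final output and the list of hidden states (assumed here to
    be the input followed by the output of every layer that was applied).
    """
    rng = rng or random.Random(0)
    hidden_states = [x]
    for layer in layers:
        # A layer can only be dropped while training.
        drop = (
            training
            and layerdrop_prob > 0.0
            and rng.random() < layerdrop_prob
        )
        if not drop:
            x = layer(x)
            hidden_states.append(x)
    return x, hidden_states


# Two toy "layers" acting on a number instead of a tensor.
toy_layers = [lambda v: v + 1, lambda v: v * 2]
out, hidden = run_encoder(1, toy_layers)
print(out, hidden)
```

With a single layer and no layer drop, the list has two entries (the input and the layer output), which agrees with len(hidden_list) being 2 in the example above.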

class speechbrain.lobes.models.transformer.Transformer.TransformerDecoderLayer(d_ffn, nhead, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, attention_type='regularMHA', causal=None)[source]

Bases: Module

This class implements the self-attention decoder layer.

Parameters:
  • d_ffn (int) – Hidden size of self-attention Feed Forward layer.

  • nhead (int) – Number of attention heads.

  • d_model (int) – Dimension of the model.

  • kdim (int) – Dimension for key (optional).

  • vdim (int) – Dimension for value (optional).

  • dropout (float) – Dropout for the decoder (optional).

  • activation (Callable) – Function to use between layers, default nn.ReLU

  • normalize_before (bool) – Whether to normalize before layers.

  • attention_type (str) – Type of attention to use, “regularMHA” or “RelPosMHAXL”

  • causal (bool) – Whether to mask future positions.

Example

>>> import torch
>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoderLayer(1024, 8, d_model=512)
>>> output, self_attn, multihead_attn = net(tgt, src)
>>> output.shape
torch.Size([8, 60, 512])
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, pos_embs_tgt=None, pos_embs_src=None)[source]
Parameters:
  • tgt (torch.Tensor) – The sequence to the decoder layer (required).

  • memory (torch.Tensor) – The sequence from the last layer of the encoder (required).

  • tgt_mask (torch.Tensor) – The mask for the tgt sequence (optional).

  • memory_mask (torch.Tensor) – The mask for the memory sequence (optional).

  • tgt_key_padding_mask (torch.Tensor) – The mask for the tgt keys per batch (optional).

  • memory_key_padding_mask (torch.Tensor) – The mask for the memory keys per batch (optional).

  • pos_embs_tgt (torch.Tensor) – The positional embeddings for the target (optional).

  • pos_embs_src (torch.Tensor) – The positional embeddings for the source (optional).

class speechbrain.lobes.models.transformer.Transformer.TransformerDecoder(num_layers, nhead, d_ffn, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, causal=False, attention_type='regularMHA')[source]

Bases: Module

This class implements the Transformer decoder.

Parameters:
  • num_layers (int) – Number of transformer layers for the decoder.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Hidden size of self-attention Feed Forward layer.

  • d_model (int) – Dimension of the model.

  • kdim (int, optional) – Dimension for key (Optional).

  • vdim (int, optional) – Dimension for value (Optional).

  • dropout (float, optional) – Dropout for the decoder (Optional).

  • activation (Callable) – The function to apply between layers, default nn.ReLU

  • normalize_before (bool) – Whether to normalize before layers.

  • causal (bool) – Whether to mask future positions, so that decoding cannot use future information.

  • attention_type (str) – Type of attention to use, “regularMHA” or “RelPosMHAXL”

Example

>>> import torch
>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoder(1, 8, 1024, d_model=512)
>>> output, _, _ = net(tgt, src)
>>> output.shape
torch.Size([8, 60, 512])
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, pos_embs_tgt=None, pos_embs_src=None)[source]
Parameters:
  • tgt (torch.Tensor) – The sequence to the decoder layer (required).

  • memory (torch.Tensor) – The sequence from the last layer of the encoder (required).

  • tgt_mask (torch.Tensor) – The mask for the tgt sequence (optional).

  • memory_mask (torch.Tensor) – The mask for the memory sequence (optional).

  • tgt_key_padding_mask (torch.Tensor) – The mask for the tgt keys per batch (optional).

  • memory_key_padding_mask (torch.Tensor) – The mask for the memory keys per batch (optional).

  • pos_embs_tgt (torch.Tensor) – The positional embeddings for the target (optional).

  • pos_embs_src (torch.Tensor) – The positional embeddings for the source (optional).

class speechbrain.lobes.models.transformer.Transformer.NormalizedEmbedding(d_model, vocab)[source]

Bases: Module

This class implements the normalized embedding layer for the transformer. Since the dot product of the self-attention is always normalized by sqrt(d_model) and the final linear projection for prediction shares weight with the embedding layer, we multiply the output of the embedding by sqrt(d_model).

Parameters:
  • d_model (int) – The number of expected features in the encoder/decoder inputs.

  • vocab (int) – The vocab size.

Example

>>> import torch
>>> emb = NormalizedEmbedding(512, 1000)
>>> trg = torch.randint(0, 999, (8, 50))
>>> emb_fea = emb(trg)
forward(x)[source]

Processes the input tensor x and returns an output tensor.
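
The scaling described above amounts to an embedding lookup followed by a multiplication by sqrt(d_model). A minimal pure-Python sketch, assuming a hand-written lookup table in place of learned weights (normalized_embedding and table are hypothetical names used only for illustration):

```python
import math


def normalized_embedding(token_ids, table, d_model):
    """Look up each token's row and scale it by sqrt(d_model)."""
    scale = math.sqrt(d_model)
    return [[scale * v for v in table[t]] for t in token_ids]


# Hypothetical 3-token vocabulary with d_model = 4, so the scale is 2.
table = [
    [0.0, 0.0, 0.0, 0.0],
    [0.5, -0.5, 1.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],
]
out = normalized_embedding([2, 1], table, d_model=4)
print(out)
```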

speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask(padded_input, pad_idx)[source]

Creates a binary mask to prevent attention to padded locations. We suggest using get_mask_from_lengths instead of this function.

Parameters:
  • padded_input (torch.Tensor) – Padded input.

  • pad_idx (int) – Index of the padding element.

Returns:

key_padded_mask – Binary mask to prevent attention to padding.

Return type:

torch.Tensor

Example

>>> import torch
>>> a = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
>>> get_key_padding_mask(a, pad_idx=0)
tensor([[False, False,  True],
        [False, False,  True],
        [False, False,  True]])
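
The semantics of the mask in the example above can be restated without tensors: an entry is True exactly where the token equals pad_idx. The following is a pure-Python sketch on nested lists, with key_padding_mask as a hypothetical stand-in name rather than the library function.

```python
def key_padding_mask(padded_input, pad_idx):
    """Return True where a token equals pad_idx, False elsewhere."""
    return [[token == pad_idx for token in row] for row in padded_input]


mask = key_padding_mask([[1, 1, 0], [2, 3, 0], [4, 5, 0]], pad_idx=0)
print(mask)
```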
speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask(padded_input)[source]

Creates a binary mask for each sequence which masks future frames.

Parameters:

padded_input (torch.Tensor) – Padded input tensor.

Returns:

mask – Binary mask for masking future frames.

Return type:

torch.Tensor

Example

>>> import torch
>>> a = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
>>> get_lookahead_mask(a)
tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
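
The mask in the example above is additive: 0 where attention is allowed and -inf for future frames. A minimal pure-Python sketch of how such a mask acts when added to attention scores before a softmax (lookahead_mask and masked_softmax are hypothetical helper names, not library functions):

```python
import math


def lookahead_mask(seq_len):
    """Additive mask: 0.0 for current and past frames, -inf for future ones."""
    neg_inf = float("-inf")
    return [
        [0.0 if key <= query else neg_inf for key in range(seq_len)]
        for query in range(seq_len)
    ]


def masked_softmax(scores, mask_row):
    """Softmax over one query's scores after adding its mask row."""
    shifted = [s + m for s, m in zip(scores, mask_row)]
    peak = max(shifted)
    exps = [math.exp(v - peak) for v in shifted]  # exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]


mask = lookahead_mask(3)
# Uniform scores: each query spreads attention over the frames it may see.
weights = [masked_softmax([0.0, 0.0, 0.0], row) for row in mask]
print(weights)
```

Every -inf entry receives zero attention weight, so the first query attends only to frame 0 and the last query attends to all three frames.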
speechbrain.lobes.models.transformer.Transformer.get_mask_from_lengths(lengths, max_len=None)[source]

Creates a binary mask from sequence lengths.

Parameters:
  • lengths (torch.Tensor) – A tensor of sequence lengths.

  • max_len (int, optional) – Maximum sequence length, defaults to None.

Returns:

mask – The mask in which padded elements are set to True. One can then call tensor.masked_fill_(mask, 0) to apply the masking.

Return type:

torch.Tensor

Example

>>> import torch
>>> lengths = torch.tensor([3, 2, 4])
>>> get_mask_from_lengths(lengths)
tensor([[False, False, False,  True],
        [False, False,  True,  True],
        [False, False, False, False]])
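
The same mask, and the masked_fill_ step mentioned above, can be sketched in pure Python on nested lists. Here mask_from_lengths and zero_padded are hypothetical names, and the sketch assumes that max_len falls back to the longest length when it is None, as the example output suggests.

```python
def mask_from_lengths(lengths, max_len=None):
    """Return True at padded positions, i.e. where pos >= sequence length."""
    if max_len is None:
        # Assumed fallback: use the longest sequence in the batch.
        max_len = max(lengths)
    return [[pos >= n for pos in range(max_len)] for n in lengths]


def zero_padded(batch, mask):
    """List-based analogue of masked_fill_(mask, 0)."""
    return [
        [0 if is_pad else value for value, is_pad in zip(row, mask_row)]
        for row, mask_row in zip(batch, mask)
    ]


mask = mask_from_lengths([3, 2, 4])
batch = [[5, 6, 7, 9], [1, 2, 9, 9], [3, 4, 5, 6]]
cleaned = zero_padded(batch, mask)
print(cleaned)
```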