speechbrain.lobes.models.transformer.Transformer module

Transformer implementation in the SpeechBrain style.

Authors
  • Jianyuan Zhong 2020
  • Samuele Cornell 2021

Summary

Classes:

NormalizedEmbedding

This class implements the normalized embedding layer for the transformer.

PositionalEncoding

This class implements the absolute sinusoidal positional encoding function.

TransformerDecoder

This class implements the Transformer decoder.

TransformerDecoderLayer

This class implements the self-attention decoder layer.

TransformerEncoder

This class implements the transformer encoder.

TransformerEncoderLayer

This is an implementation of a self-attention encoder layer.

TransformerInterface

This is an interface for transformer models. Users can modify the attributes and define the forward function as needed according to their own tasks.

Functions:

get_key_padding_mask

Creates a binary mask to prevent attention to padded locations.

get_lookahead_mask

Creates a binary mask for each sequence which masks future frames.

Reference

class speechbrain.lobes.models.transformer.Transformer.TransformerInterface(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, custom_src_module=None, custom_tgt_module=None, positional_encoding='fixed_abs_sine', normalize_before=True, kernel_size: int | None = 31, bias: bool | None = True, encoder_module: str | None = 'transformer', conformer_activation: ~torch.nn.modules.module.Module | None = <class 'speechbrain.nnet.activations.Swish'>, branchformer_activation: ~torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.GELU'>, attention_type: str | None = 'regularMHA', max_length: int | None = 2500, causal: bool | None = False, encoder_kdim: int | None = None, encoder_vdim: int | None = None, decoder_kdim: int | None = None, decoder_vdim: int | None = None, csgu_linear_units: int | None = 3072, gate_activation: ~torch.nn.modules.module.Module | None = <class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv: bool | None = False)[source]

Bases: Module

This is an interface for transformer models. Users can modify the attributes and define the forward function as needed according to their own tasks. The architecture is based on the paper “Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf

Parameters:
  • d_model (int) – The number of expected features in the encoder/decoder inputs (default=512).

  • nhead (int) – The number of heads in the multi-head attention models (default=8).

  • num_encoder_layers (int, optional) – The number of encoder layers in the encoder.

  • num_decoder_layers (int, optional) – The number of decoder layers in the decoder.

  • d_ffn (int, optional) – The dimension of the feedforward network model hidden layer.

  • dropout (float, optional) – The dropout value.

  • activation (torch.nn.Module, optional) – The activation function for the Feed-Forward Network layer, e.g., relu or gelu or swish.

  • custom_src_module (torch.nn.Module, optional) – Module that processes the src features to expected feature dim.

  • custom_tgt_module (torch.nn.Module, optional) – Module that processes the src features to expected feature dim.

  • positional_encoding (str, optional) – Type of positional encoding used. e.g. ‘fixed_abs_sine’ for fixed absolute positional encodings.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to True as this was shown to lead to better performance and training stability.

  • kernel_size (int, optional) – Kernel size in convolutional layers when Conformer is used.

  • bias (bool, optional) – Whether to use bias in Conformer convolutional layers.

  • encoder_module (str, optional) – Choose between Branchformer, Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.

  • conformer_activation (torch.nn.Module, optional) – Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module.

  • branchformer_activation (torch.nn.Module, optional) – Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • max_length (int, optional) – Max length for the target and source sequence in input. Used for positional encodings.

  • causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal the Conformer convolutional layer is causal.

  • encoder_kdim (int, optional) – Dimension of the key for the encoder.

  • encoder_vdim (int, optional) – Dimension of the value for the encoder.

  • decoder_kdim (int, optional) – Dimension of the key for the decoder.

  • decoder_vdim (int, optional) – Dimension of the value for the decoder.

  • csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU Module. -> Branchformer

  • gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module. -> Branchformer

  • use_linear_after_conv (bool, optional) – If True, will apply a linear transformation of size input_size//2. -> Branchformer

forward(**kwags)[source]

Users should modify this function according to their own tasks.
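
For illustration, here is a minimal sketch of such a subclass (the class name is hypothetical; it assumes, as in typical SpeechBrain recipes, that the base class exposes the built submodules as self.encoder and self.decoder, and that src and tgt are already projected to d_model):

>>> class MySeq2SeqTransformer(TransformerInterface):
...     def forward(self, src, tgt):
...         # encode the source sequence, then decode the targets against the encoder output
...         enc_out, _ = self.encoder(src)
...         dec_out, _, _ = self.decoder(tgt, enc_out)
...         return enc_out, dec_out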

training: bool
class speechbrain.lobes.models.transformer.Transformer.PositionalEncoding(input_size, max_len=2500)[source]

Bases: Module

This class implements the absolute sinusoidal positional encoding function:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Parameters:
  • input_size (int) – Embedding dimension.

  • max_len (int, optional) – Max length of the input sequences (default 2500).

Example

>>> a = torch.rand((8, 120, 512))
>>> enc = PositionalEncoding(input_size=a.shape[-1])
>>> b = enc(a)
>>> b.shape
torch.Size([1, 120, 512])
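
The returned tensor is the encoding itself, with a singleton batch dimension so it broadcasts over the batch; a typical usage (a sketch, not part of the class) is to add it to the input features:

>>> a = a + enc(a)   # inject absolute positional information into the features
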
forward(x)[source]
Parameters:

x (torch.Tensor) – Input features, shape (batch, time, fea).

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerEncoderLayer(d_ffn, nhead, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, attention_type='regularMHA', ffn_type='regularFFN', ffn_cnn_kernel_size_list=[3, 3], causal=False)[source]

Bases: Module

This is an implementation of a self-attention encoder layer.

Parameters:
  • d_ffn (int, optional) – The dimension of the feedforward network model hidden layer.

  • nhead (int) – The number of heads in the multi-head attention models (default=8).

  • d_model (int) – The number of expected features in the encoder/decoder inputs (default=512).

  • kdim (int, optional) – Dimension of the key.

  • vdim (int, optional) – Dimension of the value.

  • dropout (float, optional) – The dropout value.

  • activation (torch.nn.Module, optional) – The activation function for the Feed-Forward Network layer, e.g., relu or gelu or swish.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to True as this was shown to lead to better performance and training stability.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • ffn_type (str) – Type of feed-forward network: 'regularFFN' or '1dcnn'.

  • ffn_cnn_kernel_size_list (list of int) – Kernel sizes of the two 1D convolutions when ffn_type is '1dcnn'.

  • causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal the Conformer convolutional layer is causal.

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoderLayer(512, 8, d_model=512)
>>> output = net(x)
>>> output[0].shape
torch.Size([8, 60, 512])
forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None)[source]
Parameters:
  • src (torch.Tensor) – The sequence to the encoder layer.

  • src_mask (torch.Tensor) – The mask for the src query for each example in the batch.

  • src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys for each example in the batch.
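
For instance, a causal mask built with get_lookahead_mask (documented at the end of this module) can be passed as src_mask. A minimal sketch continuing the example above:

>>> causal_mask = get_lookahead_mask(x)          # (60, 60) additive mask, -inf above the diagonal
>>> output, attn = net(x, src_mask=causal_mask)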

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerEncoder(num_layers, nhead, d_ffn, input_shape=None, d_model=None, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, causal=False, layerdrop_prob=0.0, attention_type='regularMHA', ffn_type='regularFFN', ffn_cnn_kernel_size_list=[3, 3])[source]

Bases: Module

This class implements the transformer encoder.

Parameters:
  • num_layers (int) – Number of transformer layers to include.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Hidden size of the self-attention Feed Forward layer.

  • d_model (int) – The dimension of the input embedding.

  • kdim (int, optional) – Dimension for the key.

  • vdim (int, optional) – Dimension for the value.

  • dropout (float, optional) – Dropout for the encoder.

  • input_module (torch class, optional) – The module to process the source input feature to the expected feature dimension.

  • activation (torch.nn.Module, optional) – The activation function for the Feed-Forward Network layer, e.g., relu or gelu or swish.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to True as this was shown to lead to better performance and training stability.

  • causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal the Conformer convolutional layer is causal.

  • layerdrop_prob (float) – The probability of dropping an entire layer.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • ffn_type (str) – Type of feed-forward network: 'regularFFN' or '1dcnn'.

  • ffn_cnn_kernel_size_list (list of int) – Kernel sizes of the two 1D convolutions when ffn_type is '1dcnn'.

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder(1, 8, 512, d_model=512)
>>> output, _ = net(x)
>>> output.shape
torch.Size([8, 60, 512])
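
Padded batches can be handled via src_key_padding_mask, a boolean mask that is True at padded time steps. A minimal sketch continuing the example above, with hypothetical utterance lengths (get_key_padding_mask, documented below, builds such a mask from token indices instead):

>>> lengths = torch.tensor([60, 52, 47, 60, 35, 58, 60, 41])
>>> pad_mask = torch.arange(60)[None, :] >= lengths[:, None]   # (8, 60), True where padded
>>> output, _ = net(x, src_key_padding_mask=pad_mask)
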
forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None)[source]
Parameters:
  • src (tensor) – The sequence to the encoder layer (required).

  • src_mask (tensor) – The mask for the src sequence (optional).

  • src_key_padding_mask (tensor) – The mask for the src keys per batch (optional).

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerDecoderLayer(d_ffn, nhead, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, attention_type='regularMHA', causal=None)[source]

Bases: Module

This class implements the self-attention decoder layer.

Parameters:
  • d_ffn (int) – Hidden size of the self-attention Feed Forward layer.

  • nhead (int) – Number of attention heads.

  • d_model (int) – Dimension of the model.

  • kdim (int, optional) – Dimension for the key.

  • vdim (int, optional) – Dimension for the value.

  • dropout (float, optional) – Dropout for the decoder.

Example

>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoderLayer(1024, 8, d_model=512)
>>> output, self_attn, multihead_attn = net(src, tgt)
>>> output.shape
torch.Size([8, 60, 512])
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, pos_embs_tgt=None, pos_embs_src=None)[source]
Parameters:
  • tgt (tensor) – The sequence to the decoder layer (required).

  • memory (tensor) – The sequence from the last layer of the encoder (required).

  • tgt_mask (tensor) – The mask for the tgt sequence (optional).

  • memory_mask (tensor) – The mask for the memory sequence (optional).

  • tgt_key_padding_mask (tensor) – The mask for the tgt keys per batch (optional).

  • memory_key_padding_mask (tensor) – The mask for the memory keys per batch (optional).
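
For standard autoregressive decoding, the target mask is typically the look-ahead mask produced by get_lookahead_mask (documented at the end of this module). A minimal sketch continuing the example above, treating tgt as the decoder input and src as the encoder memory:

>>> causal = get_lookahead_mask(tgt)             # (60, 60) additive mask, -inf above the diagonal
>>> output, self_attn, cross_attn = net(tgt, src, tgt_mask=causal)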

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerDecoder(num_layers, nhead, d_ffn, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, causal=False, attention_type='regularMHA')[source]

Bases: Module

This class implements the Transformer decoder.

Parameters:
  • num_layers (int) – Number of transformer layers to include.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Hidden size of the self-attention Feed Forward layer.

  • d_model (int) – Dimension of the model.

  • kdim (int, optional) – Dimension for the key.

  • vdim (int, optional) – Dimension for the value.

  • dropout (float, optional) – Dropout for the decoder.

Example

>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoder(1, 8, 1024, d_model=512)
>>> output, _, _ = net(src, tgt)
>>> output.shape
torch.Size([8, 60, 512])
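
In practice a causal target mask and, when the encoder output is padded, a memory_key_padding_mask are usually passed as well. A minimal sketch continuing the example above, treating tgt as the decoder input and src as the encoder memory:

>>> tgt_mask = get_lookahead_mask(tgt)                         # (60, 60), -inf above the diagonal
>>> mem_pad = torch.zeros(8, 60, dtype=torch.bool)             # no padded positions in this toy batch
>>> output, _, _ = net(tgt, src, tgt_mask=tgt_mask, memory_key_padding_mask=mem_pad)
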
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, pos_embs_tgt=None, pos_embs_src=None)[source]
Parameters:
  • tgt (tensor) – The sequence to the decoder layer (required).

  • memory (tensor) – The sequence from the last layer of the encoder (required).

  • tgt_mask (tensor) – The mask for the tgt sequence (optional).

  • memory_mask (tensor) – The mask for the memory sequence (optional).

  • tgt_key_padding_mask (tensor) – The mask for the tgt keys per batch (optional).

  • memory_key_padding_mask (tensor) – The mask for the memory keys per batch (optional).

training: bool
class speechbrain.lobes.models.transformer.Transformer.NormalizedEmbedding(d_model, vocab)[source]

Bases: Module

This class implements the normalized embedding layer for the transformer. Since the dot product in self-attention is always normalized by sqrt(d_model) and the final linear projection for prediction shares weights with the embedding layer, we multiply the output of the embedding by sqrt(d_model).

Parameters:
  • d_model (int) – The number of expected features in the encoder/decoder inputs (default=512).

  • vocab (int) – The vocab size.

Example

>>> emb = NormalizedEmbedding(512, 1000)
>>> trg = torch.randint(0, 999, (8, 50))
>>> emb_fea = emb(trg)
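
Continuing the example, the output has the usual (batch, time, d_model) shape; its values are the underlying embeddings scaled by sqrt(d_model) = sqrt(512):

>>> emb_fea.shape
torch.Size([8, 50, 512])
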
training: bool
forward(x)[source]

Processes the input tensor x and returns an output tensor.

speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask(padded_input, pad_idx)[source]

Creates a binary mask to prevent attention to padded locations.

Parameters:
  • padded_input (torch.Tensor) – Padded input.

  • pad_idx (int) – Index of the padding element.

Example

>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_key_padding_mask(a, pad_idx=0)
tensor([[False, False,  True],
        [False, False,  True],
        [False, False,  True]])
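
The resulting mask is typically passed to a transformer as a *_key_padding_mask. A minimal sketch with hypothetical sizes, embedding the padded token batch from above and decoding it against a random memory:

>>> pad_mask = get_key_padding_mask(a, pad_idx=0)              # True at padded positions
>>> emb = NormalizedEmbedding(512, 10)
>>> dec = TransformerDecoder(1, 8, 1024, d_model=512)
>>> memory = torch.rand(3, 3, 512)
>>> out, _, _ = dec(emb(a), memory, tgt_key_padding_mask=pad_mask)
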
speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask(padded_input)[source]

Creates a binary mask for each sequence which masks future frames.

Parameters:

padded_input (torch.Tensor) – Padded input tensor.

Example

>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_lookahead_mask(a)
tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])