speechbrain.lobes.models.transformer.Transformer module

Transformer implementation in the SpeechBrain style.

Authors

  • Jianyuan Zhong 2020

  • Samuele Cornell 2021

Summary

Classes:

NormalizedEmbedding

This class implements the normalized embedding layer for the transformer.

PositionalEncoding

This class implements the absolute sinusoidal positional encoding function.

TransformerDecoder

This class implements the Transformer decoder.

TransformerDecoderLayer

This class implements the self-attention decoder layer.

TransformerEncoder

This class implements the transformer encoder.

TransformerEncoderLayer

This is an implementation of a self-attention encoder layer.

TransformerInterface

This is an interface for a transformer model.

Functions:

get_key_padding_mask

Creates a binary mask to prevent attention to padded locations.

get_lookahead_mask

Creates a mask for each sequence that masks future frames.

Reference

class speechbrain.lobes.models.transformer.Transformer.TransformerInterface(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, custom_src_module=None, custom_tgt_module=None, positional_encoding='fixed_abs_sine', normalize_before=True, kernel_size: ~typing.Optional[int] = 31, bias: ~typing.Optional[bool] = True, encoder_module: ~typing.Optional[str] = 'transformer', conformer_activation: ~typing.Optional[~torch.nn.modules.module.Module] = <class 'speechbrain.nnet.activations.Swish'>, attention_type: ~typing.Optional[str] = 'regularMHA', max_length: ~typing.Optional[int] = 2500, causal: ~typing.Optional[bool] = False, encoder_kdim: ~typing.Optional[int] = None, encoder_vdim: ~typing.Optional[int] = None, decoder_kdim: ~typing.Optional[int] = None, decoder_vdim: ~typing.Optional[int] = None)

Bases: Module

This is an interface for a transformer model.

Users can modify the attributes and define the forward function as needed according to their own tasks.

The architecture is based on the paper “Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf

Parameters
  • d_model (int) – The number of expected features in the encoder/decoder inputs (default=512).

  • nhead (int) – The number of heads in the multi-head attention models (default=8).

  • num_encoder_layers (int, optional) – The number of encoder layers in the encoder.

  • num_decoder_layers (int, optional) – The number of decoder layers in the decoder.

  • d_ffn (int, optional) – The dimension of the hidden layer of the feed-forward network.

  • dropout (float, optional) – The dropout value.

  • activation (torch.nn.Module, optional) – The activation function for the Feed-Forward Network layer, e.g., relu, gelu, or swish.

  • custom_src_module (torch.nn.Module, optional) – Module that maps the src features to the expected feature dimension.

  • custom_tgt_module (torch.nn.Module, optional) – Module that maps the tgt features to the expected feature dimension.

  • positional_encoding (str, optional) – Type of positional encoding used. e.g. ‘fixed_abs_sine’ for fixed absolute positional encodings.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to True as this was shown to lead to better performance and training stability.

  • kernel_size (int, optional) – Kernel size in convolutional layers when Conformer is used.

  • bias (bool, optional) – Whether to use bias in Conformer convolutional layers.

  • encoder_module (str, optional) – Choose between Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.

  • conformer_activation (torch.nn.Module, optional) – Activation module used after the Conformer convolutional layers, e.g. Swish or ReLU. It must be a torch Module.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.

  • max_length (int, optional) – Maximum length of the source and target input sequences. Used by the positional encodings.

  • causal (bool, optional) – Whether the encoder should be causal (the decoder is always causal). If True, the Conformer convolutional layers are also causal.

  • encoder_kdim (int, optional) – Dimension of the key for the encoder.

  • encoder_vdim (int, optional) – Dimension of the value for the encoder.

  • decoder_kdim (int, optional) – Dimension of the key for the decoder.

  • decoder_vdim (int, optional) – Dimension of the value for the decoder.
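The `causal` option above makes the Conformer convolutional layers causal. One common way to achieve this, assumed here rather than taken from SpeechBrain's code, is to pad a 1-D convolution only on the left by `kernel_size - 1` frames, so that output frame t depends only on input frames up to t. The helper `causal_depthwise_conv` is hypothetical; the depthwise layout (`groups=d_model`) mirrors a Conformer convolution module, and `kernel_size=31` matches the default above.

```python
import torch
import torch.nn.functional as F


def causal_depthwise_conv(x: torch.Tensor, conv: torch.nn.Conv1d) -> torch.Tensor:
    """Apply a 1-D convolution causally (hypothetical helper).

    x has shape (batch, time, channels); conv must be built with padding=0.
    """
    kernel_size = conv.kernel_size[0]
    # Conv1d expects (batch, channels, time).
    x = x.transpose(1, 2)
    # Left-pad the time axis only, so frame t never sees frames > t.
    x = F.pad(x, (kernel_size - 1, 0))
    return conv(x).transpose(1, 2)


torch.manual_seed(0)
d_model, kernel_size = 16, 31
# groups=d_model makes the convolution depthwise (one filter per channel).
conv = torch.nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
x = torch.rand(2, 50, d_model)
y = causal_depthwise_conv(x, conv)
print(y.shape)  # torch.Size([2, 50, 16])
```

Because all the padding sits on the past side, the output keeps the input length and no frames need trimming afterwards.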

forward(**kwags)

Users should modify this function according to their own tasks.
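A minimal sketch of that pattern, assuming a base class whose `forward` raises `NotImplementedError` until a subclass overrides it. `ToyInterface` and `ToySeq2Seq` are hypothetical stand-ins built on `torch.nn.Transformer`, not SpeechBrain classes; the subclass fixes a task-specific signature (`src`, `tgt`) and applies a causal mask on the decoder side.

```python
import torch
from torch import nn


class ToyInterface(nn.Module):
    """Hypothetical stand-in for an interface that leaves forward() to users."""

    def __init__(self, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=4 * d_model,
            batch_first=True,
        )

    def forward(self, **kwargs):
        raise NotImplementedError("Subclasses define forward for their task.")


class ToySeq2Seq(ToyInterface):
    """Task-specific subclass: encode src, decode tgt under a causal mask."""

    def forward(self, src, tgt):
        tgt_len = tgt.shape[1]
        # Additive mask: -inf above the diagonal blocks future target frames.
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len)
        return self.transformer(src, tgt, tgt_mask=tgt_mask)


model = ToySeq2Seq()
src = torch.rand(8, 60, 64)
tgt = torch.rand(8, 20, 64)
out = model(src, tgt)
print(out.shape)  # torch.Size([8, 20, 64])
```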

training: bool
class speechbrain.lobes.models.transformer.Transformer.PositionalEncoding(input_size, max_len=2500)

Bases: Module

This class implements the absolute sinusoidal positional encoding function.

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
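The two formulas can be evaluated directly to build a full encoding table. The sketch below is one straightforward way to do it in PyTorch, assuming an even `input_size` (playing the role of d_model); the function name `sinusoidal_table` is hypothetical. The `(1, time, input_size)` shape in the example below suggests the module returns only the encoding, so here the caller slices the table and adds it to the features.

```python
import math

import torch


def sinusoidal_table(input_size: int, max_len: int = 2500) -> torch.Tensor:
    """Return PE with shape (1, max_len, input_size); input_size must be even."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    # 1 / 10000^(2i/d_model) for i = 0 .. input_size/2 - 1
    inv_freq = torch.exp(
        torch.arange(0, input_size, 2, dtype=torch.float32)
        * (-math.log(10000.0) / input_size)
    )
    pe = torch.zeros(max_len, input_size)
    pe[:, 0::2] = torch.sin(pos * inv_freq)  # even channels: PE(pos, 2i)
    pe[:, 1::2] = torch.cos(pos * inv_freq)  # odd channels: PE(pos, 2i+1)
    return pe.unsqueeze(0)


pe = sinusoidal_table(512)
x = torch.rand(8, 120, 512)
# Slice to the input length and broadcast-add over the batch dimension.
y = x + pe[:, : x.shape[1]]
print(pe[:, :120].shape)  # torch.Size([1, 120, 512])
```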

Parameters
  • input_size (int) – Embedding dimension.

  • max_len (int, optional) – Max length of the input sequences (default 2500).

Example

>>> import torch
>>> a = torch.rand((8, 120, 512))
>>> enc = PositionalEncoding(input_size=a.shape[-1])
>>> b = enc(a)
>>> b.shape
torch.Size([1, 120, 512])
forward(x)
Parameters

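x (torch.Tensor) – Input features with shape (batch, time, fea). The returned encoding has shape (1, time, fea), as in the example above; it does not include x itself.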

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerEncoderLayer(d_ffn, nhead, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, attention_type='regularMHA', causal=False)

Bases: Module

This is an implementation of a self-attention encoder layer.

Parameters
  • d_ffn (int) – The dimension of the hidden layer of the feed-forward network.

  • nhead (int) – The number of heads in the multi-head attention model.

  • d_model (int) – The number of expected features in the layer inputs.

  • kdim (int, optional) – Dimension of the key.

  • vdim (int, optional) – Dimension of the value.

  • dropout (float, optional) – The dropout value (default=0.0).

  • activation (torch.nn.Module, optional) – The activation function for the Feed-Forward Network layer, e.g., relu, gelu, or swish.

  • normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Defaults to False for this layer; TransformerInterface defaults to True, as pre-normalization was shown to lead to better performance and training stability.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers. e.g. regularMHA or RelPosMHA.
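To make `normalize_before` concrete, here is a minimal sketch of a self-attention encoder layer built from `torch.nn.MultiheadAttention`, `torch.nn.LayerNorm`, and a two-layer feed-forward block. The class name `SketchEncoderLayer` and its exact layout are assumptions for illustration, not SpeechBrain's code; it only shows where normalization sits in the two modes, and it returns `(output, attention)` like the layer documented here.

```python
import torch
from torch import nn


class SketchEncoderLayer(nn.Module):
    def __init__(self, d_ffn, nhead, d_model, dropout=0.0,
                 activation=nn.ReLU, normalize_before=False):
        super().__init__()
        self.self_att = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(d_ffn, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Pre-norm normalizes the sublayer input; post-norm normalizes after the residual add.
        x = self.norm1(src) if self.normalize_before else src
        att, att_weights = self.self_att(
            x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout(att)
        if not self.normalize_before:
            src = self.norm1(src)

        x = self.norm2(src) if self.normalize_before else src
        src = src + self.dropout(self.ffn(x))
        if not self.normalize_before:
            src = self.norm2(src)
        return src, att_weights


x = torch.rand(8, 60, 512)
for pre_norm in (True, False):
    layer = SketchEncoderLayer(512, 8, d_model=512, normalize_before=pre_norm)
    out, att = layer(x)
    print(pre_norm, out.shape, att.shape)
```

With post-normalization every output frame is layer-normalized; with pre-normalization the residual path stays unnormalized, which is commonly credited for the more stable training mentioned above.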

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoderLayer(512, 8, d_model=512)
>>> output = net(x)
>>> output[0].shape
torch.Size([8, 60, 512])
forward(src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos_embs: Optional[Tensor] = None)
Parameters
  • src (torch.Tensor) – The sequence to the encoder layer.

  • src_mask (torch.Tensor) – The mask for the src query for each example in the batch.

  • src_key_padding_mask (torch.Tensor, optional) – The mask for the src keys for each example in the batch.

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerEncoder(num_layers, nhead, d_ffn, input_shape=None, d_model=None, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, causal=False, attention_type='regularMHA')

Bases: Module

This class implements the transformer encoder.

Parameters
  • num_layers (int) – Number of transformer layers to include.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Hidden size of self-attention Feed Forward layer.

  • d_model (int) – The dimension of the input embedding.

  • kdim (int) – Dimension for key (Optional).

  • vdim (int) – Dimension for value (Optional).

  • dropout (float) – Dropout for the encoder (Optional).

  • input_module (torch class) – The module to process the source input feature to expected feature dimension (Optional).

Example

>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder(1, 8, 512, d_model=512)
>>> output, _ = net(x)
>>> output.shape
torch.Size([8, 60, 512])
forward(src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos_embs: Optional[Tensor] = None)
Parameters
  • src (tensor) – The sequence to the encoder layer (required).

  • src_mask (tensor) – The mask for the src sequence (optional).

  • src_key_padding_mask (tensor) – The mask for the src keys per batch (optional).
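How a key padding mask is consumed can be illustrated with PyTorch's own `torch.nn.TransformerEncoder`, used here only as a stand-in for this encoder (an assumption for illustration; the two are separate implementations). Both follow the convention of `get_key_padding_mask`, where `True` marks a padded key. The sketch builds the mask from sequence lengths and checks that overwriting padded frames leaves the outputs at real frames unchanged.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, nhead, d_ffn = 64, 4, 256
layer = nn.TransformerEncoderLayer(d_model, nhead, d_ffn, dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

lengths = torch.tensor([60, 45])  # true lengths of two padded sequences
src = torch.rand(2, 60, d_model)
# True marks padded frames (time index >= length), which attention must ignore.
src_key_padding_mask = torch.arange(60).unsqueeze(0) >= lengths.unsqueeze(1)

with torch.no_grad():
    out = encoder(src, src_key_padding_mask=src_key_padding_mask)
    # Overwrite the padded frames of the second sequence with noise.
    src_noisy = src.clone()
    src_noisy[1, 45:] = torch.rand(15, d_model)
    out_noisy = encoder(src_noisy, src_key_padding_mask=src_key_padding_mask)

print(out.shape)  # torch.Size([2, 60, 64])
```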

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerDecoderLayer(d_ffn, nhead, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, attention_type='regularMHA', causal=None)

Bases: Module

This class implements the self-attention decoder layer.

Parameters
  • d_ffn (int) – Hidden size of self-attention Feed Forward layer.

  • nhead (int) – Number of attention heads.

  • d_model (int) – Dimension of the model.

  • kdim (int) – Dimension for key (optional).

  • vdim (int) – Dimension for value (optional).

  • dropout (float) – Dropout for the decoder (optional).

Example

>>> import torch
>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoderLayer(1024, 8, d_model=512)
>>> output, self_attn, multihead_attn = net(tgt, src)
>>> output.shape
torch.Size([8, 60, 512])
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, pos_embs_tgt=None, pos_embs_src=None)
Parameters
  • tgt (tensor) – The sequence to the decoder layer (required).

  • memory (tensor) – The sequence from the last layer of the encoder (required).

  • tgt_mask (tensor) – The mask for the tgt sequence (optional).

  • memory_mask (tensor) – The mask for the memory sequence (optional).

  • tgt_key_padding_mask (tensor) – The mask for the tgt keys per batch (optional).

  • memory_key_padding_mask (tensor) – The mask for the memory keys per batch (optional).
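The sketch below shows where `tgt_mask` and `memory_key_padding_mask` enter the two attention blocks of a decoder layer, using `torch.nn.MultiheadAttention` directly. The layout, the omitted feed-forward block and normalization, and the helper name `decoder_attention` are assumptions for illustration. The additive causal mask has the same form as the output of `get_lookahead_mask`, so changing a future target frame leaves earlier outputs untouched.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, nhead = 64, 4
self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)


def decoder_attention(tgt, memory, tgt_mask=None, memory_key_padding_mask=None):
    """Masked self-attention over tgt, then cross-attention to memory (FFN omitted)."""
    x, _ = self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
    x = tgt + x
    y, _ = cross_attn(x, memory, memory, key_padding_mask=memory_key_padding_mask)
    return x + y


T, S = 10, 30
tgt = torch.rand(2, T, d_model)
memory = torch.rand(2, S, d_model)
# Additive causal mask: 0 on and below the diagonal, -inf above it.
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
# True marks padded memory frames; the second utterance has 20 valid frames.
memory_key_padding_mask = torch.arange(S).unsqueeze(0) >= torch.tensor([[30], [20]])

with torch.no_grad():
    out = decoder_attention(tgt, memory, tgt_mask, memory_key_padding_mask)
    tgt_future = tgt.clone()
    tgt_future[:, 5:] = torch.rand(2, T - 5, d_model)  # change frames 5..9 only
    out_future = decoder_attention(tgt_future, memory, tgt_mask, memory_key_padding_mask)
```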

training: bool
class speechbrain.lobes.models.transformer.Transformer.TransformerDecoder(num_layers, nhead, d_ffn, d_model, kdim=None, vdim=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>, normalize_before=False, causal=False, attention_type='regularMHA')

Bases: Module

This class implements the Transformer decoder.

Parameters
  • num_layers (int) – Number of transformer layers to include.

  • nhead (int) – Number of attention heads.

  • d_ffn (int) – Hidden size of self-attention Feed Forward layer.

  • d_model (int) – Dimension of the model.

  • kdim (int, optional) – Dimension for key.

  • vdim (int, optional) – Dimension for value.

  • dropout (float, optional) – Dropout for the decoder.

Example

>>> import torch
>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoder(1, 8, 1024, d_model=512)
>>> output, _, _ = net(tgt, src)
>>> output.shape
torch.Size([8, 60, 512])
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, pos_embs_tgt=None, pos_embs_src=None)
Parameters
  • tgt (tensor) – The sequence to the decoder layer (required).

  • memory (tensor) – The sequence from the last layer of the encoder (required).

  • tgt_mask (tensor) – The mask for the tgt sequence (optional).

  • memory_mask (tensor) – The mask for the memory sequence (optional).

  • tgt_key_padding_mask (tensor) – The mask for the tgt keys per batch (optional).

  • memory_key_padding_mask (tensor) – The mask for the memory keys per batch (optional).

training: bool
class speechbrain.lobes.models.transformer.Transformer.NormalizedEmbedding(d_model, vocab)

Bases: Module

This class implements the normalized embedding layer for the transformer.

Since the dot product of the self-attention is always normalized by sqrt(d_model) and the final linear projection for prediction shares its weights with the embedding layer, we multiply the output of the embedding by sqrt(d_model).
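A minimal sketch of that scaling, assuming the output projection reuses the embedding matrix (weight tying) as the paragraph describes. `ScaledEmbedding` and `out_proj` are hypothetical names built on plain `torch.nn.Embedding` and `torch.nn.Linear`, not part of this module.

```python
import math

import torch
from torch import nn


class ScaledEmbedding(nn.Module):
    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale embeddings by sqrt(d_model) before they enter the transformer.
        return self.emb(x) * math.sqrt(self.d_model)


d_model, vocab = 512, 1000
emb = ScaledEmbedding(d_model, vocab)
# Weight tying: the prediction layer shares its (vocab, d_model) weight with the embedding.
out_proj = nn.Linear(d_model, vocab, bias=False)
out_proj.weight = emb.emb.weight

tokens = torch.randint(0, vocab, (8, 50))
feats = emb(tokens)       # (8, 50, 512)
logits = out_proj(feats)  # (8, 50, 1000)
```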

Parameters
  • d_model (int) – The number of expected features in the encoder/decoder inputs.

  • vocab (int) – The vocab size.

Example

>>> import torch
>>> emb = NormalizedEmbedding(512, 1000)
>>> trg = torch.randint(0, 999, (8, 50))
>>> emb_fea = emb(trg)
>>> emb_fea.shape
torch.Size([8, 50, 512])
training: bool
forward(x)

Embeds the token indices in x and multiplies the result by sqrt(d_model).

speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask(padded_input, pad_idx)

Creates a binary mask to prevent attention to padded locations.

Parameters
  • padded_input (torch.Tensor) – Padded input tensor.

  • pad_idx (int) – Index of the padding element.
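For a 2-D tensor of token indices, the mask reduces to an element-wise comparison with `pad_idx`. The sketch below covers only that 2-D case (the helper name `key_padding_mask_2d` is hypothetical) and reproduces the example output below.

```python
import torch


def key_padding_mask_2d(padded_input: torch.Tensor, pad_idx: int) -> torch.Tensor:
    """True where padded_input equals pad_idx, i.e. where attention is blocked."""
    return padded_input.eq(pad_idx)


a = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
print(key_padding_mask_2d(a, pad_idx=0))
```

Note that any real token whose index equals `pad_idx` is masked as well, so the padding index should be reserved for padding.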

Example

>>> import torch
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_key_padding_mask(a, pad_idx=0)
tensor([[False, False,  True],
        [False, False,  True],
        [False, False,  True]])
speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask(padded_input)

Creates a mask for each sequence that masks future frames. The returned float mask holds 0 where attention is allowed and -inf at future positions.

Parameters

padded_input (torch.Tensor) – Padded input tensor.
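Because the mask is additive, adding it to the attention scores before the softmax drives the weights of future frames to zero. A sketch that produces the same tensor as the example below with `torch.triu`, assuming only the time dimension `padded_input.shape[1]` matters (the helper name `lookahead_mask_sketch` is hypothetical):

```python
import torch


def lookahead_mask_sketch(padded_input: torch.Tensor) -> torch.Tensor:
    seq_len = padded_input.shape[1]
    # -inf strictly above the diagonal (future frames), 0 on and below it.
    return torch.triu(
        torch.full((seq_len, seq_len), float("-inf"), device=padded_input.device),
        diagonal=1,
    )


a = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
mask = lookahead_mask_sketch(a)
# With uniform scores, each frame spreads its weight evenly over itself and the past.
weights = torch.softmax(torch.zeros(3, 3) + mask, dim=-1)
print(weights)
```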

Example

>>> import torch
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_lookahead_mask(a)
tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])