speechbrain.lobes.models.transformer.TransformerASR module

Transformer for ASR in the SpeechBrain style.

Authors
  • Jianyuan Zhong 2020
  • Titouan Parcollet 2024
  • Luca Della Libera 2024

Summary

Classes:

EncoderWrapper

This is a wrapper for any ASR Transformer encoder.

TransformerASR

This is an implementation of the Transformer model for ASR.

TransformerASRStreamingContext

Streaming metadata and state for a TransformerASR instance.

Functions:

make_transformer_src_mask

Prepare the source Transformer mask that restricts which frames can attend to which other frames, based on strict causality or another simple restricted attention method (chunkwise attention).

make_transformer_src_tgt_masks

This function generates masks for training the Transformer model. It is opinionated for an ASR context, producing encoder masks and, optionally, decoder masks (if tgt is specified).

Reference

class speechbrain.lobes.models.transformer.TransformerASR.TransformerASRStreamingContext(dynchunktrain_config: DynChunkTrainConfig, encoder_context: Any)

Bases: object

Streaming metadata and state for a TransformerASR instance.

dynchunktrain_config: DynChunkTrainConfig

Dynamic Chunk Training configuration holding chunk size and context size information.

encoder_context: Any

Opaque encoder context information. It is constructed by the encoder’s make_streaming_context method and is passed to the encoder when using encode_streaming.

speechbrain.lobes.models.transformer.TransformerASR.make_transformer_src_mask(src: Tensor, causal: bool = False, dynchunktrain_config: DynChunkTrainConfig | None = None) → Tensor | None

Prepare the source Transformer mask that restricts which frames can attend to which other frames, based on strict causality or another simple restricted attention method (chunkwise attention).

Parameters:
  • src (torch.Tensor) – The source tensor to build a mask from, of shape (batch, timesteps, ...). Its contents are currently not used; only its shape and other metadata (e.g. device) are.

  • causal (bool) – Whether strict causality shall be used. Frames will not be able to attend to any future frame.

  • dynchunktrain_config (DynChunkTrainConfig, optional) – Dynamic Chunk Training configuration. This implements a simple form of chunkwise attention. Incompatible with causal=True.

Returns:

A mask Tensor of shape (timesteps, timesteps), or None if neither causal nor dynchunktrain_config is set.

Return type:

torch.Tensor or None
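
To make the two masking modes concrete, here is a minimal pure-Python sketch of the attention patterns described above. It is not SpeechBrain's tensor-based implementation: the helper names causal_mask and chunkwise_mask are hypothetical, left_chunks is an assumed stand-in for the left context size (in chunks) of a Dynamic Chunk Training configuration, and the convention that True marks a blocked (query, key) pair is an assumption chosen for readability.

```python
from typing import List, Optional

# Assumed convention: mask[i][j] is True when query frame i may NOT attend to key frame j.
Mask = List[List[bool]]


def causal_mask(timesteps: int) -> Mask:
    """Strict causality: frame i may only attend to frames j <= i."""
    return [[j > i for j in range(timesteps)] for i in range(timesteps)]


def chunkwise_mask(timesteps: int, chunk_size: int, left_chunks: Optional[int] = None) -> Mask:
    """Chunkwise attention in the spirit of Dynamic Chunk Training (hypothetical helper).

    Frame i may attend to every frame of its own chunk (including later frames
    of that chunk) and to the `left_chunks` chunks immediately before it.
    `left_chunks=None` stands for an unlimited left context.
    """
    mask = []
    for i in range(timesteps):
        chunk_i = i // chunk_size
        row = []
        for j in range(timesteps):
            chunk_j = j // chunk_size
            in_future_chunk = chunk_j > chunk_i
            beyond_left_context = left_chunks is not None and chunk_j < chunk_i - left_chunks
            row.append(in_future_chunk or beyond_left_context)
        mask.append(row)
    return mask


if __name__ == "__main__":
    for row in chunkwise_mask(6, chunk_size=2, left_chunks=1):
        print("".join("x" if blocked else "." for blocked in row))
    # prints:
    # ..xxxx
    # ..xxxx
    # ....xx
    # ....xx
    # xx....
    # xx....
```

Unlike strict causality, chunkwise attention lets a frame look ahead to the end of its own chunk, so the lookahead (and hence latency) is bounded by the chunk size rather than being zero.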

speechbrain.lobes.models.transformer.TransformerASR.make_transformer_src_tgt_masks(src, tgt=None, wav_len=None, pad_idx=0, causal: bool = False, dynchunktrain_config: DynChunkTrainConfig | None = None)

This function generates masks for training the Transformer model. It is opinionated for an ASR context, producing encoder masks and, optionally, decoder masks (if tgt is specified).

Parameters:
  • src (torch.Tensor) – The sequence to the encoder (required).

  • tgt (torch.Tensor, optional) – The sequence to the decoder.

  • wav_len (torch.Tensor, optional) – Tensor of shape (batch, ) containing each example's length relative to the padded length.

  • pad_idx (int) – The index for <pad> token (default=0).

  • causal (bool) – Whether strict causality shall be used. See make_transformer_src_mask.

  • dynchunktrain_config (DynChunkTrainConfig, optional) – Dynamic Chunk Training configuration. See make_transformer_src_mask.
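
The padding masks can be illustrated the same way. The sketch below assumes, as the forward documentation states, that wav_len holds each example's length relative to the padded length, and converts it to an absolute frame count by rounding. The helper names src_padding_mask and tgt_padding_mask are hypothetical, and True again marks positions that attention should ignore; this is an illustration of the idea, not the library's code.

```python
from typing import List, Sequence

# Assumed convention: True marks a padded position that attention should ignore.
Mask = List[List[bool]]


def src_padding_mask(wav_len: Sequence[float], max_frames: int) -> Mask:
    """Mask frames past each example's true length (hypothetical helper).

    wav_len[b] is the relative length of example b, e.g. 0.5 means that only
    the first half of the padded frames contain real audio.
    """
    # Python's round() rounds halves to the nearest even integer.
    abs_len = [round(rel * max_frames) for rel in wav_len]
    return [[t >= n for t in range(max_frames)] for n in abs_len]


def tgt_padding_mask(tgt: Sequence[Sequence[int]], pad_idx: int = 0) -> Mask:
    """Mask target tokens equal to the <pad> index (hypothetical helper)."""
    return [[token == pad_idx for token in row] for row in tgt]


if __name__ == "__main__":
    print(src_padding_mask([1.0, 0.5], max_frames=4))
    # [[False, False, False, False], [False, False, True, True]]
    print(tgt_padding_mask([[5, 7, 0, 0]], pad_idx=0))
    # [[False, False, True, True]]
```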

class speechbrain.lobes.models.transformer.TransformerASR.TransformerASR(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, kernel_size: int | None = 31, bias: bool | None = True, encoder_module: str | None = 'transformer', conformer_activation: ~torch.nn.modules.module.Module | None = <class 'speechbrain.nnet.activations.Swish'>, branchformer_activation: ~torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.GELU'>, attention_type: str | None = 'regularMHA', max_length: int | None = 2500, causal: bool | None = True, csgu_linear_units: int | None = 3072, gate_activation: ~torch.nn.modules.module.Module | None = <class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv: bool | None = False)

Bases: TransformerInterface

This is an implementation of the Transformer model for ASR.

The architecture is based on the paper “Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf

Parameters:
  • tgt_vocab (int) – Size of vocabulary.

  • input_size (int) – Input feature size.

  • d_model (int, optional) – Embedding dimension size (default=512).

  • nhead (int, optional) – The number of heads in the multi-head attention models (default=8).

  • num_encoder_layers (int, optional) – The number of sub-encoder-layers in the encoder (default=6).

  • num_decoder_layers (int, optional) – The number of sub-decoder-layers in the decoder (default=6).

  • d_ffn (int, optional) – The dimension of the feedforward network model (default=2048).

  • dropout (float, optional) – The dropout value (default=0.1).

  • activation (torch.nn.Module, optional) – The activation function of FFN layers. Recommended: relu or gelu (default=torch.nn.ReLU).

  • positional_encoding (str, optional) – Type of positional encoding used, e.g. ‘fixed_abs_sine’ for fixed sinusoidal absolute positional encodings.

  • normalize_before (bool, optional) – Whether normalization should be applied before (True) or after (False) MHA or FFN in Transformer layers. The signature default is False, but True is recommended, as it was shown to lead to better performance and training stability.

  • kernel_size (int, optional) – Kernel size in convolutional layers when Conformer is used.

  • bias (bool, optional) – Whether to use bias in Conformer convolutional layers.

  • encoder_module (str, optional) – Encoder type: ‘transformer’ (default), ‘conformer’, or ‘branchformer’. The decoder is always a Transformer.

  • conformer_activation (torch.nn.Module, optional) – Activation module used after the Conformer convolutional layers, e.g. Swish or ReLU. It must be a torch Module.

  • branchformer_activation (torch.nn.Module, optional) – Activation module used within the Branchformer encoder, e.g. Swish or ReLU. It must be a torch Module.

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers, e.g. ‘regularMHA’ or ‘RelPosMHAXL’.

  • max_length (int, optional) – Maximum length of the source and target input sequences. Used for positional encodings.

  • causal (bool, optional) – Whether the encoder should be causal (the decoder is always causal). If True, the Conformer convolutional layers are causal as well.

  • csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU module (Branchformer only).

  • gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module (Branchformer only).

  • use_linear_after_conv (bool, optional) – If True, applies a linear transformation of size input_size//2 (Branchformer only).
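
For reference, the ‘fixed_abs_sine’ option corresponds to the fixed sinusoidal absolute positional encodings of the cited paper, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), precomputed up to a maximum length. The pure-Python sketch below builds that table to illustrate the formula only; it is not SpeechBrain's own implementation, the helper name sinusoidal_positional_encoding is hypothetical, and it assumes an even d_model.

```python
from typing import List


def sinusoidal_positional_encoding(max_length: int, d_model: int) -> List[List[float]]:
    """Fixed sinusoidal absolute positional encodings ("Attention Is All You Need").

    Row `pos` holds the vector added to the input embedding at time step `pos`:
    even dimensions use sine, odd dimensions use cosine, and the wavelengths
    form a geometric progression controlled by the base 10000.
    Assumes d_model is even (hypothetical helper, for illustration only).
    """
    table = []
    for pos in range(max_length):
        row = [0.0] * d_model
        for i in range(0, d_model, 2):  # i plays the role of 2k in the formula
            angle = pos / (10000 ** (i / d_model))
            row[i] = math.sin(angle)
            row[i + 1] = math.cos(angle)
        table.append(row)
    return table


import math  # noqa: E402  (kept next to its only use for readability)

if __name__ == "__main__":
    pe = sinusoidal_positional_encoding(max_length=4, d_model=8)
    print([round(v, 3) for v in pe[1]])
```

Because the table depends only on max_length and d_model, it can be computed once and sliced to the actual sequence length, which is why max_length matters for positional encodings.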

Example

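>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> src = torch.rand([8, 120, 512])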
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])

forward(src, tgt, wav_len=None, pad_idx=0)

Encodes src and decodes tgt, returning the encoder output and the decoder output.

Parameters:
  • src (torch.Tensor) – The sequence to the encoder.

  • tgt (torch.Tensor) – The sequence to the decoder.

  • wav_len (torch.Tensor, optional) – Tensor of shape (batch, ) containing each example's length relative to the padded length.

  • pad_idx (int, optional) – The index for <pad> token (default=0).

decode(tgt, encoder_out, enc_len=None)

This method implements a decoding step for the transformer model.

Parameters:
  • tgt (torch.Tensor) – The sequence to the decoder.

  • encoder_out (torch.Tensor) – Hidden output of the encoder.

  • enc_len (torch.LongTensor, optional) – The actual length of the encoder states.
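
A decoding step like this is typically called repeatedly at inference time, feeding the tokens generated so far back in as tgt. The sketch below shows a greedy loop of that kind in pure Python. step_fn is a hypothetical stand-in for a call to decode followed by a projection to vocabulary scores, and bos_idx, eos_idx, and the toy scorer are assumptions for illustration; practical ASR decoding usually relies on beam search rather than this greedy loop.

```python
from typing import Callable, List, Sequence

# step_fn(prefix) -> one score per vocabulary entry for the next token.
# Hypothetical stand-in for decode(tgt, encoder_out, enc_len) + an output layer.
StepFn = Callable[[Sequence[int]], Sequence[float]]


def greedy_decode(step_fn: StepFn, bos_idx: int, eos_idx: int, max_len: int) -> List[int]:
    """Autoregressive greedy decoding: always append the highest-scoring token."""
    prefix = [bos_idx]
    for _ in range(max_len):
        scores = step_fn(prefix)
        next_token = max(range(len(scores)), key=lambda k: scores[k])
        if next_token == eos_idx:
            break
        prefix.append(next_token)
    return prefix[1:]  # drop the BOS token


if __name__ == "__main__":
    # Toy scorer over a 5-token vocabulary: emits 2, 3, 4, then EOS (index 1).
    script = [2, 3, 4, 1]

    def toy_step(prefix: Sequence[int]) -> List[float]:
        target = script[len(prefix) - 1]
        return [1.0 if k == target else 0.0 for k in range(5)]

    print(greedy_decode(toy_step, bos_idx=0, eos_idx=1, max_len=10))  # [2, 3, 4]
```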

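encode(src, wav_len=None, pad_idx=0, dynchunktrain_config: DynChunkTrainConfig | None = None)

Encoder forward pass.

Parameters:
  • src (torch.Tensor) – The sequence to the encoder.

  • wav_len (torch.Tensor, optional) – Tensor of shape (batch, ) containing each example's length relative to the padded length.

  • pad_idx (int, optional) – The index for <pad> token (default=0).

  • dynchunktrain_config (DynChunkTrainConfig, optional) – Dynamic Chunk Training configuration. See make_transformer_src_mask.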

encode_streaming(src, context: TransformerASRStreamingContext)

Streaming encoder forward pass.

Parameters:
  • src (torch.Tensor) – The sequence (chunk) to the encoder.

  • context (TransformerASRStreamingContext) – Mutable reference to the streaming context. This holds the state that must persist across chunk inferences and can be built using make_streaming_context. This function mutates it.

Returns:

Encoder output for this chunk.

Example

>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
>>> net = TransformerASR(
...     tgt_vocab=100,
...     input_size=64,
...     d_model=64,
...     nhead=8,
...     num_encoder_layers=1,
...     num_decoder_layers=0,
...     d_ffn=128,
...     attention_type="RelPosMHAXL",
...     positional_encoding=None,
...     encoder_module="conformer",
...     normalize_before=True,
...     causal=False,
... )
>>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
>>> src1 = torch.rand([8, 16, 64])
>>> src2 = torch.rand([8, 16, 64])
>>> out1 = net.encode_streaming(src1, ctx)
>>> out1.shape
torch.Size([8, 16, 64])
>>> ctx.encoder_context.layers[0].mha_left_context.shape
torch.Size([8, 16, 64])
>>> out2 = net.encode_streaming(src2, ctx)
>>> out2.shape
torch.Size([8, 16, 64])
>>> ctx.encoder_context.layers[0].mha_left_context.shape
torch.Size([8, 16, 64])
>>> combined_out = torch.concat((out1, out2), dim=1)
>>> combined_out.shape
torch.Size([8, 32, 64])
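
The example above relies on the encoder caching a bounded amount of left context between calls (the mha_left_context tensors). The pure-Python sketch below illustrates that principle with a causal toy operation, a sum over each frame and its previous left_frames frames, rather than attention: a mutable state object carries the tail of the previously seen input, so processing chunk by chunk yields the same output as processing the whole sequence at once. StreamingState, windowed_sum, and encode_chunk are hypothetical names, not SpeechBrain APIs.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StreamingState:
    """Mutable context carried across chunks (hypothetical)."""

    left_frames: int  # how many past frames each output frame may see
    cache: List[float] = field(default_factory=list)  # tail of previously seen input


def windowed_sum(frames: List[float], left_frames: int) -> List[float]:
    """Offline reference: out[t] = sum of frames[t - left_frames .. t]."""
    return [sum(frames[max(0, t - left_frames): t + 1]) for t in range(len(frames))]


def encode_chunk(chunk: List[float], state: StreamingState) -> List[float]:
    """Process one chunk, reading and then updating the cached left context."""
    extended = state.cache + chunk
    offset = len(state.cache)
    out = [
        sum(extended[max(0, t - state.left_frames): t + 1])
        for t in range(offset, len(extended))
    ]
    # Keep only what the next chunk can still see (mutates the state in place).
    state.cache = extended[-state.left_frames:] if state.left_frames > 0 else []
    return out


if __name__ == "__main__":
    signal = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    state = StreamingState(left_frames=2)
    streamed: List[float] = []
    for start in range(0, len(signal), 3):  # chunks of 3 frames
        streamed += encode_chunk(signal[start:start + 3], state)
    assert streamed == windowed_sum(signal, left_frames=2)
    print(streamed)  # [1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0]
```

Bounding the cache to the left context keeps memory constant no matter how long the stream runs, which is the same trade-off the chunk and left context sizes control in the real encoder.
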

make_streaming_context(dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={})

Creates a blank streaming context for this transformer and its encoder.

Parameters:
  • dynchunktrain_config (DynChunkTrainConfig) – Runtime chunkwise attention configuration.

  • encoder_kwargs (dict) – Parameters to be forwarded to the encoder’s make_streaming_context. The metadata required by the encoder may differ depending on the encoder.

class speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper(transformer, *args, **kwargs)

Bases: Module

This is a wrapper for any ASR Transformer encoder. By default, the TransformerASR .forward() function both encodes and decodes. With this wrapper, the .forward() function performs .encode() only.

Important: The wrapped transformer must provide an .encode() method.

Parameters:

transformer (sb.lobes.models.TransformerInterface) – A Transformer instance that contains a .encode() function.

Example

>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import EncoderWrapper, TransformerASR
>>> src = torch.rand([8, 120, 512])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> encoder = EncoderWrapper(net)
>>> enc_out = encoder(src)
>>> enc_out.shape
torch.Size([8, 120, 512])
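
Conceptually, the wrapper is a thin adapter that redirects the call interface to the wrapped model's .encode(). The sketch below shows that delegation pattern in plain Python with a toy model. ToyTransformer and EncodeOnly are hypothetical names, and the toy encode/decode arithmetic is arbitrary; the real EncoderWrapper is a torch.nn.Module whose forward also passes pad_idx and extra keyword arguments through.

```python
class ToyTransformer:
    """Stand-in model whose default call both encodes and decodes (hypothetical)."""

    def encode(self, src, wav_len=None):
        return [2 * x for x in src]

    def decode(self, tgt, encoder_out):
        return [t + e for t, e in zip(tgt, encoder_out)]

    def __call__(self, src, tgt):
        enc = self.encode(src)
        return enc, self.decode(tgt, enc)


class EncodeOnly:
    """Adapter: calling the wrapper runs only the wrapped model's encode()."""

    def __init__(self, transformer):
        self.transformer = transformer

    def __call__(self, x, wav_lens=None, **kwargs):
        return self.transformer.encode(x, wav_lens, **kwargs)


if __name__ == "__main__":
    net = ToyTransformer()
    print(net([1, 2], [10, 20]))   # ([2, 4], [12, 24])
    print(EncodeOnly(net)([1, 2]))  # [2, 4]
```

Delegating rather than subclassing lets code that expects a single-output encoder reuse a full encoder-decoder model, and its trained weights, without modifying the model class.
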

forward(x, wav_lens=None, pad_idx=0, **kwargs)

Encodes the input tensor x with the wrapped transformer's .encode() and returns the encoder output.

forward_streaming(x, context)

Processes the input audio chunk tensor x, using and updating the mutable encoder context.

make_streaming_context(*args, **kwargs)

Initializes a streaming context. Forwards all arguments to the underlying transformer. See speechbrain.lobes.models.transformer.TransformerASR.make_streaming_context().