speechbrain.lobes.models.transformer.TransformerASR module
Transformer for ASR in the SpeechBrain style.
Authors
* Jianyuan Zhong 2020
* Titouan Parcollet 2024
* Luca Della Libera 2024
Summary
Classes:
EncoderWrapper – This is a wrapper of any ASR transformer encoder.
TransformerASR – This is an implementation of a transformer model for ASR.
TransformerASRStreamingContext – Streaming metadata and state for a TransformerASR instance.
Functions:
make_transformer_src_mask – Prepare the source transformer mask that restricts which frames can attend to which frames depending on causal or other simple restricted attention methods.
make_transformer_src_tgt_masks – This function generates masks for training the transformer model, opinionated for an ASR context with encoding masks and, optionally, decoding masks (if specifying tgt).
Reference
- class speechbrain.lobes.models.transformer.TransformerASR.TransformerASRStreamingContext(dynchunktrain_config: DynChunkTrainConfig, encoder_context: Any)[source]
Bases: object
Streaming metadata and state for a TransformerASR instance.
- dynchunktrain_config: DynChunkTrainConfig
Dynamic Chunk Training configuration holding chunk size and context size information.
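For illustration, a minimal sketch of building such a context directly from its two fields, assuming it behaves as a plain dataclass (in practice it is usually created for you by TransformerASR.make_streaming_context; the None encoder state below is only a placeholder):
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASRStreamingContext
>>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
>>> ctx = TransformerASRStreamingContext(
...     dynchunktrain_config=DynChunkTrainConfig(16, 1),
...     encoder_context=None,  # placeholder; the real encoder state is encoder-specific
... )
>>> ctx.dynchunktrain_config.chunk_size
16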
- speechbrain.lobes.models.transformer.TransformerASR.make_transformer_src_mask(src: Tensor, causal: bool = False, dynchunktrain_config: DynChunkTrainConfig | None = None) Tensor | None [source]
Prepare the source transformer mask that restricts which frames can attend to which frames depending on causal or other simple restricted attention methods.
- Parameters:
src (torch.Tensor) – The source tensor to build a mask from. The contents of the tensor are not actually used currently; only its shape and other metadata (e.g. device) are.
causal (bool) – Whether strict causality shall be used. Frames will not be able to attend to any future frame.
dynchunktrain_config (DynChunkTrainConfig, optional) – Dynamic Chunk Training configuration. This implements a simple form of chunkwise attention. Incompatible with causal.
- Returns:
A boolean mask Tensor of shape (timesteps, timesteps).
- Return type:
torch.Tensor
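For illustration, a minimal sketch of building a chunkwise source mask (the tensor values are placeholders; only the shape and device of src matter):
>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import make_transformer_src_mask
>>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
>>> src = torch.rand([8, 120, 512])
>>> src_mask = make_transformer_src_mask(src, dynchunktrain_config=DynChunkTrainConfig(16, 1))
>>> src_mask.shape
torch.Size([120, 120])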
- speechbrain.lobes.models.transformer.TransformerASR.make_transformer_src_tgt_masks(src, tgt=None, wav_len=None, pad_idx=0, causal: bool = False, dynchunktrain_config: DynChunkTrainConfig | None = None)[source]
This function generates masks for training the transformer model, opinionated for an ASR context with encoding masks and, optionally, decoding masks (if specifying tgt).
- Parameters:
src (torch.Tensor) – The sequence to the encoder (required).
tgt (torch.Tensor) – The sequence to the decoder.
wav_len (torch.Tensor) – The lengths of the inputs.
pad_idx (int) – The index for <pad> token (default=0).
causal (bool) – Whether strict causality shall be used. See make_transformer_src_mask.
dynchunktrain_config (DynChunkTrainConfig, optional) – Dynamic Chunk Training configuration. See make_transformer_src_mask.
- Returns:
src_key_padding_mask (torch.Tensor) – Key padding mask for ignoring padding.
tgt_key_padding_mask (torch.Tensor) – Key padding mask for ignoring padding.
src_mask (torch.Tensor) – Mask for ignoring invalid (e.g. future) timesteps.
tgt_mask (torch.Tensor) – Mask for ignoring invalid (e.g. future) timesteps.
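A minimal usage sketch (the values are illustrative; wav_len holds lengths relative to the padded length, so 1.0 means no padding):
>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import make_transformer_src_tgt_masks
>>> src = torch.rand([2, 50, 80])
>>> tgt = torch.randint(1, 100, [2, 20])
>>> wav_len = torch.tensor([1.0, 0.5])
>>> src_kpm, tgt_kpm, src_mask, tgt_mask = make_transformer_src_tgt_masks(
...     src, tgt, wav_len, causal=True
... )
>>> src_kpm.shape
torch.Size([2, 50])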
- class speechbrain.lobes.models.transformer.TransformerASR.TransformerASR(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, kernel_size: int | None = 31, bias: bool | None = True, encoder_module: str | None = 'transformer', conformer_activation: ~torch.nn.modules.module.Module | None = <class 'speechbrain.nnet.activations.Swish'>, branchformer_activation: ~torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.GELU'>, attention_type: str | None = 'regularMHA', max_length: int | None = 2500, causal: bool | None = None, csgu_linear_units: int | None = 3072, gate_activation: ~torch.nn.modules.module.Module | None = <class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv: bool | None = False, output_hidden_states=False, layerdrop_prob=0.0)[source]
Bases: TransformerInterface
This is an implementation of a transformer model for ASR.
The architecture is based on the paper "Attention Is All You Need": https://arxiv.org/pdf/1706.03762.pdf
- Parameters:
tgt_vocab (int) – Size of vocabulary.
input_size (int) – Input feature size.
d_model (int, optional) – Embedding dimension size (default=512).
nhead (int, optional) – The number of heads in the multi-head attention models (default=8).
num_encoder_layers (int, optional) – The number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers (int, optional) – The number of sub-decoder-layers in the decoder (default=6).
d_ffn (int, optional) – The dimension of the feedforward network model (default=2048).
dropout (float, optional) – The dropout value (default=0.1).
activation (torch.nn.Module, optional) – The activation function of FFN layers. Recommended: relu or gelu (default=relu).
positional_encoding (str, optional) – Type of positional encoding used, e.g. "fixed_abs_sine" for fixed absolute positional encodings.
normalize_before (bool, optional) – Whether normalization should be applied before or after MHA or FFN in Transformer layers. Pre-normalization (True) has been shown to lead to better performance and training stability; note that the constructor default is False.
kernel_size (int, optional) – Kernel size in convolutional layers when the Conformer is used.
bias (bool, optional) – Whether to use bias in Conformer convolutional layers.
encoder_module (str, optional) – Choose between Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.
conformer_activation (torch.nn.Module, optional) – Activation module used after Conformer convolutional layers, e.g. Swish or ReLU. It has to be a torch Module.
branchformer_activation (torch.nn.Module, optional) – Activation module used within the Branchformer encoder, e.g. Swish or ReLU. It has to be a torch Module.
attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers, e.g. regularMHA or RelPosMHA.
max_length (int, optional) – Max length for the target and source sequences in input. Used for positional encodings.
causal (bool, optional) – Whether the encoder should be causal or not (the decoder is always causal). If causal, the Conformer convolutional layer is causal.
csgu_linear_units (int, optional) – Number of neurons in the hidden linear units of the CSGU module (Branchformer only).
gate_activation (torch.nn.Module, optional) – Activation function used at the gate of the CSGU module (Branchformer only).
use_linear_after_conv (bool, optional) – If True, will apply a linear transformation of size input_size // 2 (Branchformer only).
output_hidden_states (bool, optional) – Whether the model should output the hidden states as a list of tensors.
layerdrop_prob (float) – The probability to drop an entire layer.
Example
>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])
- forward(src, tgt, wav_len=None, pad_idx=0)[source]
- Parameters:
src (torch.Tensor) – The sequence to the encoder.
tgt (torch.Tensor) – The sequence to the decoder.
wav_len (torch.Tensor, optional) – Torch Tensor of shape (batch,) containing the relative length to padded length for each example.
pad_idx (int, optional) – The index for <pad> token (default=0).
- Returns:
encoder_out (torch.Tensor) – The output of the encoder.
decoder_out (torch.Tensor) – The output of the decoder.
hidden_state_lst (list, optional) – The output of the hidden layers of the encoder. Only returned if output_hidden_states is set to True.
- decode(tgt, encoder_out, enc_len=None)[source]
This method implements a decoding step for the transformer model.
- Parameters:
tgt (torch.Tensor) – The sequence to the decoder.
encoder_out (torch.Tensor) – Hidden output of the encoder.
enc_len (torch.LongTensor) – The actual length of encoder states.
- Return type:
prediction
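A hedged sketch of a single decoding step; net and src are rebuilt as in the class-level example above, and the <bos> index 0 is a hypothetical choice (the exact contents of the returned prediction follow the decoder implementation and are not asserted here):
>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> src = torch.rand([8, 120, 512])
>>> enc_out = net.encode(src)
>>> bos = torch.zeros([8, 1], dtype=torch.long)  # hypothetical <bos> token index
>>> prediction = net.decode(bos, enc_out)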
- encode(src, wav_len=None, pad_idx=0, dynchunktrain_config: DynChunkTrainConfig | None = None)[source]
Encoder forward pass.
- Parameters:
src (torch.Tensor) – The sequence to the encoder.
wav_len (torch.Tensor, optional) – Torch Tensor of shape (batch,) containing the relative length to padded length for each example.
pad_idx (int) – The index used for padding.
dynchunktrain_config (DynChunkTrainConfig) – Dynamic chunking config.
- Returns:
encoder_out
- Return type:
torch.Tensor
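A minimal usage sketch, reusing net and src from the class-level example above (an all-ones wav_len means no example is padded):
>>> wav_len = torch.ones([8])
>>> enc_out = net.encode(src, wav_len)
>>> enc_out.shape
torch.Size([8, 120, 512])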
- encode_streaming(src, context: TransformerASRStreamingContext)[source]
Streaming encoder forward pass.
- Parameters:
src (torch.Tensor) – The sequence (chunk) to the encoder.
context (TransformerASRStreamingContext) – Mutable reference to the streaming context. This holds the state needed to persist across chunk inferences and can be built using make_streaming_context. It will get mutated by this function.
- Returns:
Encoder output for this chunk.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
>>> net = TransformerASR(
...     tgt_vocab=100,
...     input_size=64,
...     d_model=64,
...     nhead=8,
...     num_encoder_layers=1,
...     num_decoder_layers=0,
...     d_ffn=128,
...     attention_type="RelPosMHAXL",
...     positional_encoding=None,
...     encoder_module="conformer",
...     normalize_before=True,
...     causal=False,
... )
>>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
>>> src1 = torch.rand([8, 16, 64])
>>> src2 = torch.rand([8, 16, 64])
>>> out1 = net.encode_streaming(src1, ctx)
>>> out1.shape
torch.Size([8, 16, 64])
>>> ctx.encoder_context.layers[0].mha_left_context.shape
torch.Size([8, 16, 64])
>>> out2 = net.encode_streaming(src2, ctx)
>>> out2.shape
torch.Size([8, 16, 64])
>>> ctx.encoder_context.layers[0].mha_left_context.shape
torch.Size([8, 16, 64])
>>> combined_out = torch.concat((out1, out2), dim=1)
>>> combined_out.shape
torch.Size([8, 32, 64])
- make_streaming_context(dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={})[source]
Creates a blank streaming context for this transformer and its encoder.
- Parameters:
dynchunktrain_config (DynChunkTrainConfig) – Runtime chunkwise attention configuration.
encoder_kwargs (dict) – Parameters to be forwarded to the encoder's make_streaming_context. Metadata required for the encoder could differ depending on the encoder.
- Return type:
TransformerASRStreamingContext
- class speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper(transformer, *args, **kwargs)[source]
Bases: Module
This is a wrapper of any ASR transformer encoder. By default, the TransformerASR .forward() function encodes and decodes. With this wrapper, the .forward() function becomes .encode() only.
Important: The TransformerASR class must contain a .encode() function.
- Parameters:
transformer (TransformerASR) – The TransformerASR instance to wrap; it must expose an .encode() method.
Example
>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> encoder = EncoderWrapper(net)
>>> enc_out = encoder(src)
>>> enc_out.shape
torch.Size([8, 120, 512])
- forward(x, wav_lens=None, pad_idx=0, **kwargs)[source]
Processes the input tensor x and returns an output tensor.