speechbrain.lobes.models.transformer.TransformerST module

Transformer for speech translation (ST) in the SpeechBrain style.

Authors

  • YAO FEI, CHENG 2021






class speechbrain.lobes.models.transformer.TransformerST.TransformerST(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, kernel_size: ~typing.Optional[int] = 31, bias: ~typing.Optional[bool] = True, encoder_module: ~typing.Optional[str] = 'transformer', conformer_activation: ~typing.Optional[~torch.nn.modules.module.Module] = <class 'speechbrain.nnet.activations.Swish'>, attention_type: ~typing.Optional[str] = 'regularMHA', max_length: ~typing.Optional[int] = 2500, causal: ~typing.Optional[bool] = True, ctc_weight: float = 0.0, asr_weight: float = 0.0, mt_weight: float = 0.0, asr_tgt_vocab: int = 0, mt_src_vocab: int = 0)

Bases: TransformerASR

This is an implementation of a Transformer model for speech translation (ST).

The architecture is based on the paper “Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf

  • tgt_vocab (int) – Size of the target (translation) vocabulary.

  • input_size (int) – Input feature size.

  • d_model (int, optional) – Embedding dimension size (default=512).

  • nhead (int, optional) – The number of heads in the multi-head attention models (default=8).

  • num_encoder_layers (int, optional) – The number of sub-encoder-layers in the encoder (default=6).

  • num_decoder_layers (int, optional) – The number of sub-decoder-layers in the decoder (default=6).

  • d_ffn (int, optional) – The dimension of the feedforward network model (default=2048).

  • dropout (float, optional) – The dropout value (default=0.1).

  • activation (torch.nn.Module, optional) – The activation function of FFN layers. Recommended: relu or gelu (default=relu).

  • positional_encoding (str, optional) – Type of positional encoding used, e.g. ‘fixed_abs_sine’ for fixed absolute positional encodings (default=‘fixed_abs_sine’).

  • normalize_before (bool, optional) – Whether normalization should be applied before (True) or after (False) MHA or FFN in Transformer layers (default=False). Applying it before has been shown to lead to better performance and training stability.

  • kernel_size (int, optional) – Kernel size in convolutional layers when Conformer is used.

  • bias (bool, optional) – Whether to use bias in Conformer convolutional layers.

  • encoder_module (str, optional) – Whether the encoder is a Transformer or a Conformer (default=‘transformer’). The decoder is always a Transformer.

  • conformer_activation (torch.nn.Module, optional) – Activation module used after Conformer convolutional layers, e.g. Swish or ReLU; it must be a torch Module (default=Swish).

  • attention_type (str, optional) – Type of attention layer used in all Transformer or Conformer layers, e.g. regularMHA or RelPosMHA (default=‘regularMHA’).

  • max_length (int, optional) – Maximum length of the source and target input sequences, used for positional encodings (default=2500).

  • causal (bool, optional) – Whether the encoder should be causal (the decoder is always causal). If True, the Conformer convolutional layers are also causal (default=True).

  • ctc_weight (float) – The weight of the CTC loss for the ASR task (default=0.0).

  • asr_weight (float) – The weight of the ASR task when computing the loss (default=0.0).

  • mt_weight (float) – The weight of the MT task when computing the loss (default=0.0).

  • asr_tgt_vocab (int) – The vocabulary size of the ASR target (transcription) language (default=0).

  • mt_src_vocab (int) – The vocabulary size of the MT source language (default=0).


>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerST import TransformerST
>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerST(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU,
...     ctc_weight=1, asr_weight=0.3,
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])
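
The weight arguments in the example above (ctc_weight, asr_weight, mt_weight) control how much the auxiliary ASR and MT tasks contribute to training. As the example shows, forward returns model outputs rather than a loss, so the combination happens in the training code. The following is only a minimal sketch of one plausible combination, under two assumptions: the ST loss receives whatever weight the ASR and MT shares leave over, and ctc_weight interpolates between a CTC loss and an attention-decoder loss inside the ASR term. The function combine_st_losses is hypothetical.

```python
def combine_st_losses(
    st_loss: float,
    asr_att_loss: float,
    asr_ctc_loss: float,
    mt_loss: float,
    ctc_weight: float = 0.0,
    asr_weight: float = 0.0,
    mt_weight: float = 0.0,
) -> float:
    """Hypothetical multitask objective for joint ST/ASR/MT training.

    Assumption: the ST loss keeps the weight not used by the auxiliary
    tasks, and ctc_weight mixes CTC and attention losses for ASR.
    """
    if not 0.0 <= asr_weight + mt_weight <= 1.0:
        raise ValueError("asr_weight + mt_weight must lie in [0, 1]")
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    # ASR term: interpolate between the CTC and attention-decoder losses.
    asr_loss = ctc_weight * asr_ctc_loss + (1.0 - ctc_weight) * asr_att_loss
    # The main ST task receives the remaining weight.
    st_weight = 1.0 - asr_weight - mt_weight
    return st_weight * st_loss + asr_weight * asr_loss + mt_weight * mt_loss


total = combine_st_losses(
    st_loss=2.0, asr_att_loss=3.0, asr_ctc_loss=1.0, mt_loss=4.0,
    ctc_weight=0.5, asr_weight=0.25, mt_weight=0.25,
)
print(total)  # 2.5
```

With all weights left at their default of 0.0, this sketch reduces to the plain ST loss, which matches the defaults in the signature.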
forward_asr(encoder_out, src, tgt, wav_len, pad_idx=0)

This method implements a decoding step for the ASR task.

  • encoder_out (tensor) – The representation of the encoder (required).

  • tgt (tensor) – The transcription sequence fed to the decoder (required).

  • pad_idx (int) – The index of the <pad> token (default=0).
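
The forward_asr signature also takes src and wav_len. In SpeechBrain, lengths are conventionally passed as relative values in [0, 1], the fraction of the padded time axis that holds real frames. Assuming that convention for wav_len, a minimal sketch of turning relative lengths into a source key-padding mask looks like this; the helper relative_lengths_to_padding_mask is hypothetical, and plain Python lists stand in for tensors.

```python
from typing import List


def relative_lengths_to_padding_mask(
    wav_len: List[float], max_time: int
) -> List[List[bool]]:
    """Hypothetical helper: True marks padded frames to be ignored.

    Assumes each entry of wav_len is a relative length in [0, 1],
    i.e. the fraction of the padded time axis holding real frames.
    """
    mask = []
    for rel in wav_len:
        # Convert the relative length into an absolute frame count.
        n_valid = round(rel * max_time)
        mask.append([t >= n_valid for t in range(max_time)])
    return mask


# Batch of two utterances padded to 4 frames: full length and half length.
for row in relative_lengths_to_padding_mask([1.0, 0.5], max_time=4):
    print(row)
# [False, False, False, False]
# [False, False, True, True]
```

Relative lengths keep the batch description independent of how much padding was added, which is why an absolute frame count has to be recovered before masking.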

forward_mt(src, tgt, pad_idx=0)

This method implements a forward step for the MT task.

  • src (tensor) – The transcription sequence fed to the encoder (required).

  • tgt (tensor) – The translation sequence fed to the decoder (required).

  • pad_idx (int) – The index of the <pad> token (default=0).

decode_asr(tgt, encoder_out)

This method implements a decoding step for the ASR task of the transformer model.
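
A decoding step such as decode_asr takes the tokens decoded so far (tgt) and the encoder output, and yields the prediction for the next position; a search procedure calls it repeatedly. The sketch below shows a minimal greedy loop around a hypothetical step function toy_step, which stands in for a call to decode_asr plus whatever output projection turns decoder states into token scores. This is not SpeechBrain's own search procedure, and bos_id, eos_id and toy_step are assumptions for illustration.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    step_fn: Callable[[List[int], Sequence[float]], List[float]],
    encoder_out: Sequence[float],
    bos_id: int,
    eos_id: int,
    max_steps: int,
) -> List[int]:
    """Greedy autoregressive decoding around a single-step decoder.

    step_fn(prefix, encoder_out) must return one score per vocabulary
    entry for the next token (assumed interface for illustration).
    """
    prefix = [bos_id]
    for _ in range(max_steps):
        scores = step_fn(prefix, encoder_out)
        # Pick the highest-scoring token (argmax) as the next output.
        next_id = max(range(len(scores)), key=scores.__getitem__)
        prefix.append(next_id)
        if next_id == eos_id:
            break
    return prefix[1:]  # drop the BOS token


def toy_step(prefix: List[int], encoder_out: Sequence[float]) -> List[float]:
    """Hypothetical step function: replays encoder_out as tokens, then EOS (0)."""
    vocab_size = 5
    pos = len(prefix) - 1  # number of tokens generated so far
    target = int(encoder_out[pos]) if pos < len(encoder_out) else 0
    scores = [0.0] * vocab_size
    scores[target] = 1.0
    return scores


print(greedy_decode(toy_step, [3.0, 1.0, 4.0], bos_id=2, eos_id=0, max_steps=10))
# [3, 1, 4, 0]
```

Greedy search keeps only the single best token at each step; beam search generalizes the same loop by keeping several prefixes alive.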

make_masks_for_mt(src, tgt, pad_idx=0)

This method generates the masks for training the MT task of the transformer model.

  • src (tensor) – The sequence fed to the encoder (required).

  • tgt (tensor) – The sequence fed to the decoder (required).

  • pad_idx (int) – The index of the <pad> token (default=0).
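
A sequence-to-sequence Transformer typically needs two kinds of masks: key-padding masks that hide <pad> positions (identified by pad_idx) in the source and target, and a causal (look-ahead) mask that stops each target position from attending to later ones, since the decoder is always causal. The sketch below builds both with plain Python lists to illustrate the idea; the function names, the True-means-masked convention, and the returned tuple are assumptions, not the actual return value of make_masks_for_mt.

```python
from typing import List, Tuple


def key_padding_mask(tokens: List[List[int]], pad_idx: int = 0) -> List[List[bool]]:
    """True where a position holds the <pad> token and must be ignored."""
    return [[tok == pad_idx for tok in seq] for seq in tokens]


def lookahead_mask(length: int) -> List[List[bool]]:
    """True above the diagonal: position i may not attend to j > i."""
    return [[j > i for j in range(length)] for i in range(length)]


def make_masks_sketch(
    src: List[List[int]], tgt: List[List[int]], pad_idx: int = 0
) -> Tuple[List[List[bool]], List[List[bool]], List[List[bool]]]:
    """Hypothetical mask builder for MT training (True = masked)."""
    src_key_padding_mask = key_padding_mask(src, pad_idx)
    tgt_key_padding_mask = key_padding_mask(tgt, pad_idx)
    # The decoder is causal, so the target gets a look-ahead mask;
    # this sketch applies no causal mask to the source.
    tgt_mask = lookahead_mask(len(tgt[0]))
    return src_key_padding_mask, tgt_key_padding_mask, tgt_mask


src = [[5, 6, 7, 0]]  # one source sentence, last token is padding
tgt = [[8, 9, 0]]     # one target sentence, last token is padding
src_pad, tgt_pad, causal = make_masks_sketch(src, tgt, pad_idx=0)
print(src_pad)  # [[False, False, False, True]]
print(tgt_pad)  # [[False, False, True]]
print(causal)   # [[False, True, True], [False, False, True], [False, False, False]]
```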
