Source code for speechbrain.lobes.models.transformer.TransformerST

"""Transformer for ST in the SpeechBrain style.

Authors
* YAO FEI, CHENG 2021
"""

from typing import Optional

import torch  # noqa 42
from torch import nn

from speechbrain.lobes.models.transformer.Conformer import ConformerEncoder
from speechbrain.lobes.models.transformer.Transformer import (
    NormalizedEmbedding,
    TransformerDecoder,
    TransformerEncoder,
    get_key_padding_mask,
    get_lookahead_mask,
)
from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.containers import ModuleList
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


class TransformerST(TransformerASR):
    """This is an implementation of the transformer model for ST.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    tgt_vocab: int
        Size of vocabulary.
    input_size: int
        Input feature size.
    d_model : int, optional
        Embedding dimension size (default=512).
    nhead : int, optional
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers : int, optional
        The number of sub-encoder-layers in the encoder (default=6).
    num_decoder_layers : int, optional
        The number of sub-decoder-layers in the decoder (default=6).
    d_ffn : int, optional
        The dimension of the feedforward network model (default=2048).
    dropout : float, optional
        The dropout value (default=0.1).
    activation : torch.nn.Module, optional
        The activation function of FFN layers.
        Recommended: relu or gelu (default=relu).
    positional_encoding: str, optional
        Type of positional encoding used, e.g. 'fixed_abs_sine' for fixed
        absolute positional encodings.
    normalize_before: bool, optional
        Whether normalization should be applied before or after MHA or FFN
        in Transformer layers. Defaults to True as this was shown to lead
        to better performance and training stability.
    kernel_size: int, optional
        Kernel size in convolutional layers when Conformer is used.
    bias: bool, optional
        Whether to use bias in Conformer convolutional layers.
    encoder_module: str, optional
        Choose between Conformer and Transformer for the encoder.
        The decoder is fixed to be a Transformer.
    conformer_activation: torch.nn.Module, optional
        Activation module used after Conformer convolutional layers,
        e.g. Swish, ReLU. It has to be a torch Module.
    attention_type: str, optional
        Type of attention layer used in all Transformer or Conformer layers,
        e.g. regularMHA or RelPosMHA.
    max_length: int, optional
        Max length for the target and source sequence in input.
        Used for positional encodings.
    causal: bool, optional
        Whether the encoder should be causal or not (the decoder is always causal).
        If causal, the Conformer convolutional layer is causal.
    ctc_weight: float
        The weight of ctc for the asr task.
    asr_weight: float
        The weight of the asr task for calculating the loss.
    mt_weight: float
        The weight of the mt task for calculating the loss.
    asr_tgt_vocab: int
        The size of the asr target language vocabulary.
    mt_src_vocab: int
        The size of the mt source language vocabulary.

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerST(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU,
    ...     ctc_weight=1, asr_weight=0.3,
    ... )
    >>> enc_out, dec_out = net.forward(src, tgt)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    >>> dec_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(
        self,
        tgt_vocab,
        input_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ffn=2048,
        dropout=0.1,
        activation=nn.ReLU,
        positional_encoding="fixed_abs_sine",
        normalize_before=False,
        kernel_size: Optional[int] = 31,
        bias: Optional[bool] = True,
        encoder_module: Optional[str] = "transformer",
        conformer_activation: Optional[nn.Module] = Swish,
        attention_type: Optional[str] = "regularMHA",
        max_length: Optional[int] = 2500,
        causal: Optional[bool] = True,
        ctc_weight: float = 0.0,
        asr_weight: float = 0.0,
        mt_weight: float = 0.0,
        asr_tgt_vocab: int = 0,
        mt_src_vocab: int = 0,
    ):
        super().__init__(
            tgt_vocab=tgt_vocab,
            input_size=input_size,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ffn=d_ffn,
            dropout=dropout,
            activation=activation,
            positional_encoding=positional_encoding,
            normalize_before=normalize_before,
            kernel_size=kernel_size,
            bias=bias,
            encoder_module=encoder_module,
            conformer_activation=conformer_activation,
            attention_type=attention_type,
            max_length=max_length,
            causal=causal,
        )

        if ctc_weight < 1 and asr_weight > 0:
            self.asr_decoder = TransformerDecoder(
                num_layers=num_decoder_layers,
                nhead=nhead,
                d_ffn=d_ffn,
                d_model=d_model,
                dropout=dropout,
                activation=activation,
                normalize_before=normalize_before,
                causal=True,
                attention_type="regularMHA",  # always use regular attention in decoder
            )

            self.custom_asr_tgt_module = ModuleList(
                NormalizedEmbedding(d_model, asr_tgt_vocab)
            )

        if mt_weight > 0:
            self.custom_mt_src_module = ModuleList(
                NormalizedEmbedding(d_model, mt_src_vocab)
            )
            if encoder_module == "transformer":
                self.mt_encoder = TransformerEncoder(
                    nhead=nhead,
                    num_layers=num_encoder_layers,
                    d_ffn=d_ffn,
                    d_model=d_model,
                    dropout=dropout,
                    activation=activation,
                    normalize_before=normalize_before,
                    causal=self.causal,
                    attention_type=self.attention_type,
                )
            elif encoder_module == "conformer":
                self.mt_encoder = ConformerEncoder(
                    nhead=nhead,
                    num_layers=num_encoder_layers,
                    d_ffn=d_ffn,
                    d_model=d_model,
                    dropout=dropout,
                    activation=conformer_activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=self.causal,
                    attention_type=self.attention_type,
                )
                assert (
                    normalize_before
                ), "normalize_before must be True for Conformer"
                assert (
                    conformer_activation is not None
                ), "conformer_activation must not be None"

        # reset parameters using xavier_normal_
        self._init_params()
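    # Construction sketch (illustrative, not from the original source): the
    # auxiliary asr and mt branches above are only built for certain weight
    # settings; the concrete numbers below are assumptions, not recommended
    # values.
    #
    #     net = TransformerST(
    #         tgt_vocab=720, input_size=512, d_model=512, nhead=8,
    #         num_encoder_layers=1, num_decoder_layers=1, d_ffn=1024,
    #         ctc_weight=0.3, asr_weight=0.3, mt_weight=0.2,
    #         asr_tgt_vocab=500, mt_src_vocab=500,
    #     )
    #     # net.asr_decoder and net.custom_asr_tgt_module exist because
    #     # ctc_weight < 1 and asr_weight > 0; net.mt_encoder and
    #     # net.custom_mt_src_module exist because mt_weight > 0.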
    def forward_asr(self, encoder_out, src, tgt, wav_len, pad_idx=0):
        """This method implements a decoding step for the asr task.

        Arguments
        ---------
        encoder_out : torch.Tensor
            The representation of the encoder (required).
        src : torch.Tensor
            Input sequence (required).
        tgt : torch.Tensor
            The sequence to the decoder (transcription) (required).
        wav_len : torch.Tensor
            Length of input tensors (required).
        pad_idx : int
            The index for <pad> token (default=0).

        Returns
        -------
        asr_decoder_out : torch.Tensor
            One step of the asr decoder.
        """
        # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
        if src.dim() == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        (
            src_key_padding_mask,
            tgt_key_padding_mask,
            src_mask,
            tgt_mask,
        ) = self.make_masks(src, tgt, wav_len, pad_idx=pad_idx)

        transcription = self.custom_asr_tgt_module(tgt)

        if self.attention_type == "RelPosMHAXL":
            transcription = transcription + self.positional_encoding_decoder(
                transcription
            )
        elif self.positional_encoding_type == "fixed_abs_sine":
            transcription = transcription + self.positional_encoding(
                transcription
            )

        asr_decoder_out, _, _ = self.asr_decoder(
            tgt=transcription,
            memory=encoder_out,
            memory_mask=src_mask,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        return asr_decoder_out
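    # Usage sketch (illustrative): assuming a `net` configured as in the
    # construction sketch above (asr_weight > 0, ctc_weight < 1) so that the
    # asr branch exists; the shapes and token counts are assumptions.
    #
    #     src = torch.rand([8, 120, 512])               # encoder input features
    #     st_tokens = torch.randint(0, 720, [8, 25])    # translation tokens
    #     asr_tokens = torch.randint(0, 500, [8, 30])   # transcription tokens
    #     wav_len = torch.ones([8])                     # relative input lengths
    #     enc_out, _ = net.forward(src, st_tokens)
    #     asr_dec_out = net.forward_asr(enc_out, src, asr_tokens, wav_len)
    #     # asr_dec_out: [8, 30, d_model], to be fed to a linear + log-softmax head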
    def forward_mt(self, src, tgt, pad_idx=0):
        """This method implements a forward step for the mt task.

        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder (transcription) (required).
        tgt : torch.Tensor
            The sequence to the decoder (translation) (required).
        pad_idx : int
            The index for <pad> token (default=0).

        Returns
        -------
        encoder_out : torch.Tensor
            Output of the encoder.
        decoder_out : torch.Tensor
            Output of the decoder.
        """
        (
            src_key_padding_mask,
            tgt_key_padding_mask,
            src_mask,
            tgt_mask,
        ) = self.make_masks_for_mt(src, tgt, pad_idx=pad_idx)

        src = self.custom_mt_src_module(src)
        if self.attention_type == "RelPosMHAXL":
            pos_embs_encoder = self.positional_encoding(src)
        elif self.positional_encoding_type == "fixed_abs_sine":
            src = src + self.positional_encoding(src)
            pos_embs_encoder = None

        encoder_out, _ = self.mt_encoder(
            src=src,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs_encoder,
        )

        tgt = self.custom_tgt_module(tgt)
        if self.attention_type == "RelPosMHAXL":
            # use standard sinusoidal pos encoding in decoder
            tgt = tgt + self.positional_encoding_decoder(tgt)
            src = src + self.positional_encoding_decoder(src)
        elif self.positional_encoding_type == "fixed_abs_sine":
            tgt = tgt + self.positional_encoding(tgt)

        decoder_out, _, _ = self.decoder(
            tgt=tgt,
            memory=encoder_out,
            memory_mask=src_mask,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        return encoder_out, decoder_out
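    # Usage sketch (illustrative): forward_mt consumes token ids on both
    # sides and requires mt_weight > 0 so that custom_mt_src_module and
    # mt_encoder exist; the vocab sizes and lengths below are assumptions.
    #
    #     mt_src = torch.randint(0, 500, [8, 30])   # source transcription tokens
    #     mt_tgt = torch.randint(0, 720, [8, 25])   # target translation tokens
    #     enc_out, dec_out = net.forward_mt(mt_src, mt_tgt, pad_idx=0)
    #     # enc_out: [8, 30, d_model], dec_out: [8, 25, d_model]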
    def forward_mt_decoder_only(self, src, tgt, pad_idx=0):
        """This method implements a forward step for the mt task using a wav2vec
        encoder (same as above, but without the encoder stack).

        Arguments
        ---------
        src (transcription): torch.Tensor
            Output features from the w2v2 encoder.
        tgt (translation): torch.Tensor
            The sequence to the decoder (required).
        pad_idx : int
            The index for <pad> token (default=0).

        Returns
        -------
        decoder_out : torch.Tensor
            Output of the decoder.
        """
        (
            src_key_padding_mask,
            tgt_key_padding_mask,
            src_mask,
            tgt_mask,
        ) = self.make_masks_for_mt(src, tgt, pad_idx=pad_idx)

        tgt = self.custom_tgt_module(tgt)

        if self.attention_type == "RelPosMHAXL":
            # use standard sinusoidal pos encoding in decoder
            tgt = tgt + self.positional_encoding_decoder(tgt)
        elif self.positional_encoding_type == "fixed_abs_sine":
            tgt = tgt + self.positional_encoding(tgt)

        decoder_out, _, multihead = self.decoder(
            tgt=tgt,
            memory=src,
            memory_mask=src_mask,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        return decoder_out
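    # Usage sketch (illustrative): here `src` is already a sequence of
    # d_model-sized vectors (e.g. wav2vec 2.0 features projected to d_model),
    # so only the decoder stack is run; shapes are assumptions.
    #
    #     w2v_feats = torch.rand([8, 120, 512])       # external encoder output
    #     st_tokens = torch.randint(0, 720, [8, 25])  # translation tokens
    #     dec_out = net.forward_mt_decoder_only(w2v_feats, st_tokens)
    #     # dec_out: [8, 25, d_model]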
    def decode_asr(self, tgt, encoder_out):
        """This method implements a decoding step for the transformer model.

        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder.
        encoder_out : torch.Tensor
            Hidden output of the encoder.

        Returns
        -------
        prediction : torch.Tensor
            The predicted outputs.
        multihead_attns : torch.Tensor
            The last step of attention.
        """
        tgt_mask = get_lookahead_mask(tgt)
        tgt = self.custom_tgt_module(tgt)
        if self.attention_type == "RelPosMHAXL":
            # we use fixed positional encodings in the decoder
            tgt = tgt + self.positional_encoding_decoder(tgt)
            encoder_out = encoder_out + self.positional_encoding_decoder(
                encoder_out
            )
        elif self.positional_encoding_type == "fixed_abs_sine":
            tgt = tgt + self.positional_encoding(tgt)  # add the encodings here

        prediction, _, multihead_attns = self.asr_decoder(
            tgt, encoder_out, tgt_mask=tgt_mask
        )

        return prediction, multihead_attns[-1]
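    # Usage sketch (illustrative): decode_asr is meant for autoregressive
    # search over the asr branch, one token prefix at a time; the prefix and
    # `enc_out` below are hypothetical.
    #
    #     prefix = torch.randint(0, 720, [8, 5])    # tokens decoded so far
    #     hidden, attn = net.decode_asr(prefix, enc_out)
    #     # hidden: [8, 5, d_model]; attn is the last decoder layer's
    #     # multi-head attention over enc_out.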
    def make_masks_for_mt(self, src, tgt, pad_idx=0):
        """This method generates the masks for training the transformer model.

        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder (required).
        tgt : torch.Tensor
            The sequence to the decoder (required).
        pad_idx : int
            The index for <pad> token (default=0).

        Returns
        -------
        src_key_padding_mask : torch.Tensor
            Timesteps to mask due to padding.
        tgt_key_padding_mask : torch.Tensor
            Timesteps to mask due to padding.
        src_mask : torch.Tensor
            Timesteps to mask for causality.
        tgt_mask : torch.Tensor
            Timesteps to mask for causality.
        """
        src_key_padding_mask = None
        if self.training:
            src_key_padding_mask = get_key_padding_mask(src, pad_idx=pad_idx)
        tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

        src_mask = None
        tgt_mask = get_lookahead_mask(tgt)

        return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
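    # Mask sketch (illustrative): for batch size B, source length S and target
    # length T, the returned values would roughly look as commented below;
    # the tensors and `net` are assumptions for illustration.
    #
    #     mt_src = torch.randint(1, 500, [4, 30])   # B=4, S=30, no pad tokens
    #     mt_tgt = torch.randint(1, 720, [4, 25])   # B=4, T=25
    #     src_kpm, tgt_kpm, src_mask, tgt_mask = net.make_masks_for_mt(mt_src, mt_tgt)
    #     # src_kpm: [4, 30] padding mask in training mode, None otherwise;
    #     # tgt_kpm: [4, 25] padding mask; src_mask: None;
    #     # tgt_mask: [25, 25] look-ahead mask blocking future positions.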