Source code for speechbrain.lobes.models.transformer.TransformerASR

"""Transformer for ASR in the SpeechBrain sytle.

Authors
* Jianyuan Zhong 2020
"""

import torch  # noqa 42
from torch import nn
from typing import Optional
from speechbrain.nnet.linear import Linear
from speechbrain.nnet.containers import ModuleList
from speechbrain.lobes.models.transformer.Transformer import (
    TransformerInterface,
    get_lookahead_mask,
    get_key_padding_mask,
    NormalizedEmbedding,
)
from speechbrain.nnet.activations import Swish
from speechbrain.dataio.dataio import length_to_mask


class TransformerASR(TransformerInterface):
    """This is an implementation of the transformer model for ASR.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    tgt_vocab: int
        Size of vocabulary.
    input_size: int
        Input feature size.
    d_model : int, optional
        Embedding dimension size (default=512).
    nhead : int, optional
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers : int, optional
        The number of sub-encoder-layers in the encoder (default=6).
    num_decoder_layers : int, optional
        The number of sub-decoder-layers in the decoder (default=6).
    d_ffn : int, optional
        The dimension of the feedforward network model (default=2048).
    dropout : float, optional
        The dropout value (default=0.1).
    activation : torch.nn.Module, optional
        The activation function of FFN layers.
        Recommended: relu or gelu (default=relu).
    positional_encoding: str, optional
        Type of positional encoding used, e.g. 'fixed_abs_sine' for fixed
        absolute positional encodings.
    normalize_before: bool, optional
        Whether normalization should be applied before or after MHA or FFN
        in Transformer layers. Defaults to True as this was shown to lead
        to better performance and training stability.
    kernel_size: int, optional
        Kernel size in convolutional layers when Conformer is used.
    bias: bool, optional
        Whether to use bias in Conformer convolutional layers.
    encoder_module: str, optional
        Choose between Conformer and Transformer for the encoder.
        The decoder is fixed to be a Transformer.
    conformer_activation: torch.nn.Module, optional
        Activation module used after Conformer convolutional layers,
        e.g. Swish, ReLU etc. It has to be a torch Module.
    attention_type: str, optional
        Type of attention layer used in all Transformer or Conformer layers,
        e.g. regularMHA or RelPosMHAXL.
    max_length: int, optional
        Max length for the target and source sequence in input.
        Used for positional encodings.
    causal: bool, optional
        Whether the encoder should be causal or not (the decoder is always
        causal). If causal, the Conformer convolutional layer is causal.

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerASR(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
    ... )
    >>> enc_out, dec_out = net.forward(src, tgt)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    >>> dec_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(
        self,
        tgt_vocab,
        input_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ffn=2048,
        dropout=0.1,
        activation=nn.ReLU,
        positional_encoding="fixed_abs_sine",
        normalize_before=False,
        kernel_size: Optional[int] = 31,
        bias: Optional[bool] = True,
        encoder_module: Optional[str] = "transformer",
        conformer_activation: Optional[nn.Module] = Swish,
        attention_type: Optional[str] = "regularMHA",
        max_length: Optional[int] = 2500,
        causal: Optional[bool] = True,
    ):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ffn=d_ffn,
            dropout=dropout,
            activation=activation,
            positional_encoding=positional_encoding,
            normalize_before=normalize_before,
            kernel_size=kernel_size,
            bias=bias,
            encoder_module=encoder_module,
            conformer_activation=conformer_activation,
            attention_type=attention_type,
            max_length=max_length,
            causal=causal,
        )

        self.custom_src_module = ModuleList(
            Linear(
                input_size=input_size,
                n_neurons=d_model,
                bias=True,
                combine_dims=False,
            ),
            torch.nn.Dropout(dropout),
        )
        self.custom_tgt_module = ModuleList(
            NormalizedEmbedding(d_model, tgt_vocab)
        )

        # reset parameters using xavier_normal_
        self._init_params()
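    # ------------------------------------------------------------------
    # Illustrative sketch (editor's note, not part of the original
    # source): building a Conformer-encoder variant of the same model.
    # It assumes, as in the standard SpeechBrain Conformer recipes, that
    # the parent TransformerInterface accepts encoder_module="conformer"
    # together with attention_type="RelPosMHAXL".
    #
    #     >>> conformer_net = TransformerASR(
    #     ...     tgt_vocab=720,
    #     ...     input_size=512,
    #     ...     d_model=512,
    #     ...     nhead=8,
    #     ...     num_encoder_layers=1,
    #     ...     num_decoder_layers=1,
    #     ...     d_ffn=1024,
    #     ...     encoder_module="conformer",
    #     ...     attention_type="RelPosMHAXL",
    #     ...     kernel_size=31,
    #     ... )
    # ------------------------------------------------------------------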
    def forward(self, src, tgt, wav_len=None, pad_idx=0):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder.
        tgt : torch.Tensor
            The sequence to the decoder.
        wav_len: torch.Tensor, optional
            Torch Tensor of shape (batch, ) containing the relative length
            to padded length for each example.
        pad_idx : int, optional
            The index for <pad> token (default=0).
        """

        # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
        if src.ndim == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        (
            src_key_padding_mask,
            tgt_key_padding_mask,
            src_mask,
            tgt_mask,
        ) = self.make_masks(src, tgt, wav_len, pad_idx=pad_idx)

        src = self.custom_src_module(src)
        # add positional encoding to the input if fixed sinusoidal encodings
        # are used, otherwise compute relative positional embeddings for the
        # attention layers
        if self.attention_type == "RelPosMHAXL":
            pos_embs_encoder = self.positional_encoding(src)
        elif self.positional_encoding_type == "fixed_abs_sine":
            src = src + self.positional_encoding(src)  # add the encodings here
            pos_embs_encoder = None

        encoder_out, _ = self.encoder(
            src=src,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs_encoder,
        )

        tgt = self.custom_tgt_module(tgt)

        if self.attention_type == "RelPosMHAXL":
            # use standard sinusoidal pos encoding in decoder
            tgt = tgt + self.positional_encoding_decoder(tgt)
            # FIXME we use pos embs also on enc output
            encoder_out = encoder_out + self.positional_encoding_decoder(
                encoder_out
            )
            pos_embs_encoder = None  # self.positional_encoding(src)
            pos_embs_target = None
        elif self.positional_encoding_type == "fixed_abs_sine":
            tgt = tgt + self.positional_encoding(tgt)
            pos_embs_target = None
            pos_embs_encoder = None

        decoder_out, _, _ = self.decoder(
            tgt=tgt,
            memory=encoder_out,
            memory_mask=src_mask,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
            pos_embs_tgt=pos_embs_target,
            pos_embs_src=pos_embs_encoder,
        )

        return encoder_out, decoder_out
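    # ------------------------------------------------------------------
    # Illustrative sketch (editor's note, not part of the original
    # source): calling forward() with relative lengths so that padded
    # frames are excluded from encoder attention. Shapes follow the
    # class docstring example; wav_len holds values in (0, 1].
    #
    #     >>> src = torch.rand([8, 120, 512])
    #     >>> tgt = torch.randint(0, 720, [8, 120])
    #     >>> wav_len = torch.ones([8])          # all examples fully used
    #     >>> net = TransformerASR(720, 512, 512, 8, 1, 1, 1024)
    #     >>> enc_out, dec_out = net(src, tgt, wav_len=wav_len, pad_idx=0)
    # ------------------------------------------------------------------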
    def make_masks(self, src, tgt, wav_len=None, pad_idx=0):
        """This method generates the masks for training the transformer model.

        Arguments
        ---------
        src : tensor
            The sequence to the encoder (required).
        tgt : tensor
            The sequence to the decoder (required).
        wav_len : tensor, optional
            Relative lengths of the source sequences, used to build the
            source key padding mask.
        pad_idx : int
            The index for <pad> token (default=0).
        """
        src_key_padding_mask = None
        if wav_len is not None:
            abs_len = torch.round(wav_len * src.shape[1])
            src_key_padding_mask = ~length_to_mask(abs_len).bool()

        tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

        src_mask = None
        tgt_mask = get_lookahead_mask(tgt)
        return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
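    # ------------------------------------------------------------------
    # Illustrative sketch (editor's note, not part of the original
    # source), reusing net from the forward() sketch above: the masks
    # returned by make_masks(). src_key_padding_mask is (batch, time)
    # with True at padded frames, tgt_key_padding_mask is (batch, tgt_len)
    # with True at <pad> tokens, src_mask is None, and tgt_mask is the
    # (tgt_len, tgt_len) look-ahead mask.
    #
    #     >>> src = torch.rand([8, 120, 512])
    #     >>> tgt = torch.randint(1, 720, [8, 25])
    #     >>> wav_len = torch.ones([8])
    #     >>> masks = net.make_masks(src, tgt, wav_len, pad_idx=0)
    #     >>> [None if m is None else tuple(m.shape) for m in masks]
    #     [(8, 120), (8, 25), None, (25, 25)]
    # ------------------------------------------------------------------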
    @torch.no_grad()
    def decode(self, tgt, encoder_out, enc_len=None):
        """This method implements a decoding step for the transformer model.

        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder.
        encoder_out : torch.Tensor
            Hidden output of the encoder.
        enc_len : torch.LongTensor
            The actual length of encoder states.
        """
        tgt_mask = get_lookahead_mask(tgt)
        src_key_padding_mask = None
        if enc_len is not None:
            src_key_padding_mask = (1 - length_to_mask(enc_len)).bool()

        tgt = self.custom_tgt_module(tgt)
        if self.attention_type == "RelPosMHAXL":
            # we use fixed positional encodings in the decoder
            tgt = tgt + self.positional_encoding_decoder(tgt)
            encoder_out = encoder_out + self.positional_encoding_decoder(
                encoder_out
            )
            # pos_embs_target = self.positional_encoding(tgt)
            pos_embs_encoder = None  # self.positional_encoding(src)
            pos_embs_target = None
        elif self.positional_encoding_type == "fixed_abs_sine":
            tgt = tgt + self.positional_encoding(tgt)  # add the encodings here
            pos_embs_target = None
            pos_embs_encoder = None

        prediction, self_attns, multihead_attns = self.decoder(
            tgt,
            encoder_out,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=src_key_padding_mask,
            pos_embs_tgt=pos_embs_target,
            pos_embs_src=pos_embs_encoder,
        )
        return prediction, multihead_attns[-1]
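    # ------------------------------------------------------------------
    # Illustrative sketch (editor's note, not part of the original
    # source), reusing net, src, and wav_len from the sketches above: a
    # minimal greedy search loop built on encode()/decode(). The
    # bos_index and the output projection seq_lin are hypothetical and
    # would normally be provided by the training recipe, not this class.
    #
    #     >>> bos_index = 0                         # hypothetical
    #     >>> seq_lin = torch.nn.Linear(512, 720)   # hypothetical
    #     >>> enc_out = net.encode(src, wav_len)
    #     >>> hyps = torch.full((src.size(0), 1), bos_index, dtype=torch.long)
    #     >>> for _ in range(10):
    #     ...     dec_out, _ = net.decode(hyps, enc_out)
    #     ...     logits = seq_lin(dec_out[:, -1])
    #     ...     next_token = logits.argmax(dim=-1, keepdim=True)
    #     ...     hyps = torch.cat([hyps, next_token], dim=1)
    # ------------------------------------------------------------------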
    def encode(self, src, wav_len=None):
        """
        Encoder forward pass.

        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder.
        wav_len: torch.Tensor, optional
            Torch Tensor of shape (batch, ) containing the relative length
            to padded length for each example.
        """
        # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
        if src.dim() == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        src_key_padding_mask = None
        if wav_len is not None:
            abs_len = torch.floor(wav_len * src.shape[1])
            src_key_padding_mask = (
                torch.arange(src.shape[1])[None, :].to(abs_len)
                > abs_len[:, None]
            )

        src = self.custom_src_module(src)
        if self.attention_type == "RelPosMHAXL":
            pos_embs_source = self.positional_encoding(src)
        elif self.positional_encoding_type == "fixed_abs_sine":
            src = src + self.positional_encoding(src)
            pos_embs_source = None

        encoder_out, _ = self.encoder(
            src=src,
            src_key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs_source,
        )
        return encoder_out
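    # ------------------------------------------------------------------
    # Illustrative sketch (editor's note, not part of the original
    # source), reusing net, src, and wav_len from the sketches above:
    # running only the encoder, e.g. to feed a CTC head. The ctc_lin
    # layer is hypothetical and would normally live in the recipe.
    #
    #     >>> enc_out = net.encode(src, wav_len)          # (8, 120, 512)
    #     >>> ctc_lin = torch.nn.Linear(512, 720)         # hypothetical
    #     >>> ctc_logits = ctc_lin(enc_out)               # (8, 120, 720)
    # ------------------------------------------------------------------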
    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_normal_(p)
class EncoderWrapper(nn.Module):
    """This is a wrapper of any ASR transformer encoder. By default, the
    TransformerASR .forward() function encodes and decodes. With this wrapper
    the .forward() function becomes .encode() only.

    Important: The TransformerASR class must contain a .encode() function.

    Arguments
    ---------
    transformer : sb.lobes.models.TransformerInterface
        A Transformer instance that contains a .encode() function.

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerASR(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
    ... )
    >>> encoder = EncoderWrapper(net)
    >>> enc_out = encoder(src)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(self, transformer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transformer = transformer
    def forward(self, x, wav_lens=None):
        """Processes the input tensor x and returns an output tensor."""
        x = self.transformer.encode(x, wav_lens)
        return x
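# ----------------------------------------------------------------------
# Illustrative sketch (editor's note, not part of the original source),
# continuing the EncoderWrapper docstring example above: the wrapper also
# forwards relative lengths, so padded frames can be masked exactly as in
# TransformerASR.encode().
#
#     >>> wav_lens = torch.ones([8])
#     >>> enc_out = encoder(src, wav_lens=wav_lens)
#     >>> enc_out.shape
#     torch.Size([8, 120, 512])
# ----------------------------------------------------------------------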