
"""Transformer for ASR in the SpeechBrain sytle.

Authors
* Jianyuan Zhong 2020
"""

import torch  # noqa 42
from torch import nn
from typing import Optional

from speechbrain.nnet.linear import Linear
from speechbrain.nnet.containers import ModuleList
from speechbrain.lobes.models.transformer.Transformer import (
    TransformerInterface,
    get_lookahead_mask,
    get_key_padding_mask,
    NormalizedEmbedding,
)
from speechbrain.nnet.activations import Swish

from speechbrain.dataio.dataio import length_to_mask


class TransformerASR(TransformerInterface):
    """This is an implementation of the transformer model for ASR.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    tgt_vocab : int
        Size of the target vocabulary.
    input_size : int
        Input feature size.
    d_model : int
        The number of expected features in the encoder/decoder inputs
        (default=512).
    nhead : int
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers : int
        The number of sub-encoder-layers in the encoder (default=6).
    num_decoder_layers : int
        The number of sub-decoder-layers in the decoder (default=6).
    d_ffn : int
        The dimension of the feedforward network model (default=2048).
    dropout : float
        The dropout value (default=0.1).
    activation : torch class
        The activation function of the encoder/decoder intermediate layer.
        Recommended: relu or gelu (default=relu).

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerASR(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
    ... )
    >>> enc_out, dec_out = net.forward(src, tgt)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    >>> dec_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(
        self,
        tgt_vocab,
        input_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ffn=2048,
        dropout=0.1,
        activation=nn.ReLU,
        positional_encoding=True,
        normalize_before=False,
        kernel_size: Optional[int] = 31,
        bias: Optional[bool] = True,
        encoder_module: Optional[str] = "transformer",
        conformer_activation: Optional[nn.Module] = Swish,
    ):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ffn=d_ffn,
            dropout=dropout,
            activation=activation,
            positional_encoding=positional_encoding,
            normalize_before=normalize_before,
            kernel_size=kernel_size,
            bias=bias,
            encoder_module=encoder_module,
            conformer_activation=conformer_activation,
        )

        self.custom_src_module = ModuleList(
            Linear(
                input_size=input_size,
                n_neurons=d_model,
                bias=True,
                combine_dims=False,
            ),
            torch.nn.Dropout(dropout),
        )
        self.custom_tgt_module = ModuleList(
            NormalizedEmbedding(d_model, tgt_vocab)
        )

        # reset parameters using xavier_normal_
        self._init_params()
    def forward(self, src, tgt, wav_len=None, pad_idx=0):
        """
        Arguments
        ---------
        src : tensor
            The sequence to the encoder (required).
        tgt : tensor
            The sequence to the decoder (required).
        wav_len : tensor
            Tensor of relative lengths (the fraction of the padded
            time axis that contains valid frames for each example).
        pad_idx : int
            The index for <pad> token (default=0).
        """
        # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
        if src.dim() == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        (
            src_key_padding_mask,
            tgt_key_padding_mask,
            src_mask,
            tgt_mask,
        ) = self.make_masks(src, tgt, wav_len, pad_idx=pad_idx)

        src = self.custom_src_module(src)
        src = src + self.positional_encoding(src)
        encoder_out, _ = self.encoder(
            src=src,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

        tgt = self.custom_tgt_module(tgt)
        tgt = tgt + self.positional_encoding(tgt)
        decoder_out, _, _ = self.decoder(
            tgt=tgt,
            memory=encoder_out,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )

        return encoder_out, decoder_out
    def make_masks(self, src, tgt, wav_len=None, pad_idx=0):
        """This method generates the masks for training the transformer model.

        Arguments
        ---------
        src : tensor
            The sequence to the encoder (required).
        tgt : tensor
            The sequence to the decoder (required).
        wav_len : tensor
            Tensor of relative lengths, used to build the source key
            padding mask (only applied during training).
        pad_idx : int
            The index for <pad> token (default=0).
        """
        src_key_padding_mask = None
        if wav_len is not None and self.training:
            abs_len = torch.round(wav_len * src.shape[1])
            src_key_padding_mask = (1 - length_to_mask(abs_len)).bool()
        tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

        src_mask = None
        tgt_mask = get_lookahead_mask(tgt)
        return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
    def decode(self, tgt, encoder_out):
        """This method implements a decoding step for the transformer model.

        Arguments
        ---------
        tgt : tensor
            The sequence to the decoder (required).
        encoder_out : tensor
            Hidden output of the encoder (required).
        """
        tgt_mask = get_lookahead_mask(tgt)
        tgt = self.custom_tgt_module(tgt)
        tgt = tgt + self.positional_encoding(tgt)
        prediction, self_attns, multihead_attns = self.decoder(
            tgt, encoder_out, tgt_mask=tgt_mask
        )
        return prediction, multihead_attns[-1]
    def encode(self, src, wav_len=None):
        """Forward pass of the encoder with the source input.

        Arguments
        ---------
        src : tensor
            The sequence to the encoder (required).
        wav_len : tensor
            Tensor of relative lengths, used to build the source key
            padding mask (only applied during training).
        """
        # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
        if src.dim() == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        src_key_padding_mask = None
        if wav_len is not None and self.training:
            abs_len = torch.round(wav_len * src.shape[1])
            src_key_padding_mask = (1 - length_to_mask(abs_len)).bool()

        src = self.custom_src_module(src)
        src = src + self.positional_encoding(src)
        encoder_out, _ = self.encoder(
            src=src, src_key_padding_mask=src_key_padding_mask
        )
        return encoder_out
    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_normal_(p)
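
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of greedy autoregressive inference built on
# TransformerASR.encode() and TransformerASR.decode(). The names
# `seq_lin` (a projection from d_model to the vocabulary, usually defined
# in the recipe rather than in this module), `bos_index`, `eos_index`,
# and `max_decode_steps` are assumptions for illustration only.
def _greedy_decode_sketch(
    model, seq_lin, src, bos_index=1, eos_index=2, max_decode_steps=50
):
    """Greedy decoding sketch: feed back the argmax token at each step."""
    encoder_out = model.encode(src)  # (batch, time, d_model)
    batch_size = src.shape[0]
    # Start every hypothesis with the <bos> token.
    hyps = torch.full(
        (batch_size, 1), bos_index, dtype=torch.long, device=src.device
    )
    for _ in range(max_decode_steps):
        prediction, _ = model.decode(hyps, encoder_out)
        # Project decoder states to vocabulary logits (assumed projection)
        # and append the most likely token of the last step.
        logits = seq_lin(prediction)  # (batch, step, tgt_vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        hyps = torch.cat([hyps, next_token], dim=1)
        if (next_token == eos_index).all():
            break
    return hyps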


class EncoderWrapper(nn.Module):
    """This is a wrapper of any ASR transformer encoder. By default, the
    TransformerASR .forward() function encodes and decodes. With this wrapper
    the .forward() function becomes .encode() only.

    Important: The TransformerASR class must contain a .encode() function.

    Arguments
    ---------
    transformer : sb.lobes.models.TransformerInterface
        A Transformer instance that contains a .encode() function.

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerASR(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
    ... )
    >>> encoder = EncoderWrapper(net)
    >>> enc_out = encoder(src)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(self, transformer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transformer = transformer
    def forward(self, x, wav_lens=None):
        """Processes the input tensor x with the wrapped encoder and returns
        the encoded output."""
        x = self.transformer.encode(x, wav_lens)
        return x
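
# --- Illustrative mask construction (not part of the original module) ---
# A minimal sketch of the encoder padding-mask logic used by make_masks()
# and encode(): `wav_len` stores relative lengths (the fraction of the
# padded time axis that is valid), which are scaled to absolute frame
# counts and converted into a boolean mask where True marks padded frames.
# The tensor values below are illustrative assumptions.
def _padding_mask_sketch():
    src = torch.rand([2, 10, 64])  # (batch, time, features)
    wav_len = torch.tensor([1.0, 0.5])  # second utterance is half padding
    abs_len = torch.round(wav_len * src.shape[1])  # tensor([10., 5.])
    # length_to_mask yields 1 for valid frames; invert it to flag padding.
    src_key_padding_mask = (1 - length_to_mask(abs_len)).bool()
    # Row 0: all False (no padding); row 1: the last 5 entries are True.
    return src_key_padding_mask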