speechbrain.lobes.models.transformer.TransformerASR module

Transformer for ASR in the SpeechBrain style.

Authors: Jianyuan Zhong, 2020

Summary

Classes:

EncoderWrapper

This is a wrapper of any ASR transformer encoder.

TransformerASR

This is an implementation of transformer model for ASR.

Reference

class speechbrain.lobes.models.transformer.TransformerASR.TransformerASR(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding=True, normalize_before=False, kernel_size: Optional[int] = 31, bias: Optional[bool] = True, encoder_module: Optional[str] = 'transformer', conformer_activation: Optional[torch.nn.modules.module.Module] = <class 'speechbrain.nnet.activations.Swish'>)[source]

Bases: speechbrain.lobes.models.transformer.Transformer.TransformerInterface

This is an implementation of transformer model for ASR.

The architecture is based on the paper “Attention Is All You Need”: https://arxiv.org/pdf/1706.03762.pdf

Parameters
  • d_model (int) – The number of expected features in the encoder/decoder inputs (default=512).

  • nhead (int) – The number of heads in the multi-head attention models (default=8).

  • num_encoder_layers (int) – The number of sub-encoder-layers in the encoder (default=6).

  • num_decoder_layers (int) – The number of sub-decoder-layers in the decoder (default=6).

  • d_ffn (int) – The dimension of the feedforward network model (default=2048).

  • dropout (float) – The dropout value (default=0.1).

  • activation (torch class) – The activation function of the encoder/decoder intermediate layers. Recommended: relu or gelu (default=relu).

Example

>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])
forward(src, tgt, wav_len=None, pad_idx=0)[source]
Parameters
  • src (tensor) – The sequence to the encoder (required).

  • tgt (tensor) – The sequence to the decoder (required).

  • wav_len (tensor) – Tensor of relative lengths of the sequences in src, used to build padding masks (optional).

  • pad_idx (int) – The index for <pad> token (default=0).

make_masks(src, tgt, wav_len=None, pad_idx=0)[source]

This method generates the masks for training the transformer model.

Parameters
  • src (tensor) – The sequence to the encoder (required).

  • tgt (tensor) – The sequence to the decoder (required).

  • wav_len (tensor) – Tensor of relative lengths of the sequences in src, used to build padding masks (optional).

  • pad_idx (int) – The index for <pad> token (default=0).
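The exact tensors returned by make_masks depend on the SpeechBrain implementation, but the standard masks for transformer ASR training can be sketched with plain torch: a key-padding mask over the target derived from pad_idx, a source padding mask derived from wav_len (relative lengths in [0, 1], per the SpeechBrain convention), and a causal look-ahead mask for the decoder. The helper names below are illustrative, not the library's API:

```python
import torch

def make_tgt_padding_mask(tgt, pad_idx=0):
    # True where the target token is padding; shape [batch, tgt_len].
    return tgt == pad_idx

def make_src_padding_mask(src, wav_len):
    # wav_len holds relative lengths in [0, 1]; convert to absolute frames.
    abs_len = torch.round(wav_len * src.shape[1]).long()
    positions = torch.arange(src.shape[1]).unsqueeze(0)  # [1, src_len]
    return positions >= abs_len.unsqueeze(1)             # True on padded frames

def make_lookahead_mask(tgt_len):
    # Upper-triangular causal mask: True above the diagonal forbids
    # attending to future positions; shape [tgt_len, tgt_len].
    return torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

src = torch.rand(2, 10, 512)
wav_len = torch.tensor([1.0, 0.5])      # second utterance is half-length
tgt = torch.tensor([[5, 7, 0], [3, 0, 0]])

src_mask = make_src_padding_mask(src, wav_len)
tgt_mask = make_tgt_padding_mask(tgt)
causal = make_lookahead_mask(tgt.shape[1])
```

The padding masks are batch-dependent, while the look-ahead mask depends only on the target length and is shared across the batch.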

decode(tgt, encoder_out)[source]

This method implements a decoding step for the transformer model.

Parameters
  • tgt (tensor) – The sequence to the decoder (required).

  • encoder_out (tensor) – Hidden output of the encoder (required).
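A typical use of encode/decode is step-by-step greedy search: run the encoder once, then repeatedly feed the growing hypothesis to decode and append the argmax token. The sketch below uses a toy stand-in model with the same .encode()/.decode() interface so the snippet runs standalone; the real TransformerASR's decode return values may differ, and the bos/eos indices are illustrative:

```python
import torch

class ToyASRModel(torch.nn.Module):
    # Minimal stand-in exposing the .encode()/.decode() interface described
    # above; it is NOT the real TransformerASR, just enough to drive the loop.
    def __init__(self, vocab_size=32, d_model=16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, d_model)
        self.out = torch.nn.Linear(d_model, vocab_size)

    def encode(self, src, wav_len=None):
        return src  # pretend encoder output

    def decode(self, tgt, encoder_out):
        # Per-position decoder states; a real decoder attends to encoder_out,
        # which we fold in via its mean for simplicity.
        return self.emb(tgt) + encoder_out.mean(dim=1, keepdim=True)

def greedy_decode(model, src, bos_idx=1, eos_idx=2, max_len=10):
    encoder_out = model.encode(src)                  # run the encoder once
    hyp = torch.full((src.shape[0], 1), bos_idx, dtype=torch.long)
    for _ in range(max_len):
        dec_out = model.decode(hyp, encoder_out)     # decode the full prefix
        logits = model.out(dec_out[:, -1])           # last-step logits
        next_tok = logits.argmax(dim=-1, keepdim=True)
        hyp = torch.cat([hyp, next_tok], dim=1)
        if (next_tok == eos_idx).all():
            break
    return hyp

torch.manual_seed(0)
model = ToyASRModel(d_model=16)
src = torch.rand(2, 5, 16)
hyp = greedy_decode(model, src)
```

Because the encoder output is computed once and reused at every step, this loop costs one encoder pass plus max_len decoder passes.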

encode(src, wav_len=None)[source]

Forwards the encoder with the source input.

Parameters

src (tensor) – The sequence to the encoder (required).

training: bool
class speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper(transformer, *args, **kwargs)[source]

Bases: torch.nn.modules.module.Module

This is a wrapper for any ASR transformer encoder. By default, the TransformerASR .forward() function both encodes and decodes. With this wrapper, the .forward() function performs .encode() only.

Important: The TransformerASR class must contain a .encode() function.

Parameters

transformer (sb.lobes.models.TransformerInterface) – A Transformer instance that contains a .encode() function.

Example

>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> encoder = EncoderWrapper(net)
>>> enc_out = encoder(src)
>>> enc_out.shape
torch.Size([8, 120, 512])
forward(x, wav_lens=None)[source]
training: bool