"""Transformer for ST in the SpeechBrain sytle.
Authors
* YAO FEI, CHENG 2021
"""
import torch # noqa 42
import logging
from torch import nn
from typing import Optional
from speechbrain.nnet.containers import ModuleList
from speechbrain.lobes.models.transformer.Transformer import (
get_lookahead_mask,
get_key_padding_mask,
NormalizedEmbedding,
TransformerDecoder,
TransformerEncoder,
)
from speechbrain.lobes.models.transformer.Conformer import ConformerEncoder
from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
from speechbrain.nnet.activations import Swish
logger = logging.getLogger(__name__)
class TransformerST(TransformerASR):
"""This is an implementation of transformer model for ST.
The architecture is based on the paper "Attention Is All You Need":
https://arxiv.org/pdf/1706.03762.pdf
Arguments
----------
tgt_vocab: int
Size of vocabulary.
input_size: int
Input feature size.
d_model : int, optional
Embedding dimension size.
(default=512).
nhead : int, optional
The number of heads in the multi-head attention models (default=8).
num_encoder_layers : int, optional
The number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers : int, optional
The number of sub-decoder-layers in the decoder (default=6).
d_ffn : int, optional
The dimension of the feedforward network model (default=2048).
dropout : float, optional
The dropout value (default=0.1).
activation : torch.nn.Module, optional
The activation function of FFN layers.
Recommended: relu or gelu (default=relu).
positional_encoding: str, optional
Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
normalize_before: bool, optional
Whether normalization should be applied before or after MHA or FFN in Transformer layers (default=False).
Pre-normalization (True) has been reported to improve performance and training stability.
kernel_size: int, optional
Kernel size in convolutional layers when Conformer is used.
bias: bool, optional
Whether to use bias in Conformer convolutional layers.
encoder_module: str, optional
Choose between Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.
conformer_activation: torch.nn.Module, optional
Activation module used after Conformer convolutional layers. E.g. Swish, ReLU, etc. It has to be a torch Module.
attention_type: str, optional
Type of attention layer used in all Transformer or Conformer layers.
e.g. regularMHA or RelPosMHAXL.
max_length: int, optional
Max length for the target and source sequence in input.
Used for positional encodings.
causal: bool, optional
Whether the encoder should be causal or not (the decoder is always causal).
If causal the Conformer convolutional layer is causal.
ctc_weight: float
The weight of the CTC loss for the auxiliary ASR task.
asr_weight: float
The weight of the auxiliary ASR loss in the overall loss.
mt_weight: float
The weight of the auxiliary MT loss in the overall loss.
asr_tgt_vocab: int
The size of the ASR target vocabulary.
mt_src_vocab: int
The size of the MT source vocabulary.
Example
-------
>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerST(
... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU,
... ctc_weight=1, asr_weight=0.3,
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])
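>>> # An additional, hedged sketch (not part of the original example): the
>>> # weights below enable the auxiliary ASR decoder and the MT branch; the
>>> # vocabulary sizes are arbitrary.
>>> multitask_net = TransformerST(
...     720, 512, 512, 8, 1, 1, 1024,
...     ctc_weight=0.3, asr_weight=0.3, asr_tgt_vocab=600,
...     mt_weight=0.2, mt_src_vocab=620,
... )
>>> hasattr(multitask_net, "asr_decoder"), hasattr(multitask_net, "mt_encoder")
(True, True)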
"""
def __init__(
self,
tgt_vocab,
input_size,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ffn=2048,
dropout=0.1,
activation=nn.ReLU,
positional_encoding="fixed_abs_sine",
normalize_before=False,
kernel_size: Optional[int] = 31,
bias: Optional[bool] = True,
encoder_module: Optional[str] = "transformer",
conformer_activation: Optional[nn.Module] = Swish,
attention_type: Optional[str] = "regularMHA",
max_length: Optional[int] = 2500,
causal: Optional[bool] = True,
ctc_weight: float = 0.0,
asr_weight: float = 0.0,
mt_weight: float = 0.0,
asr_tgt_vocab: int = 0,
mt_src_vocab: int = 0,
):
super().__init__(
tgt_vocab=tgt_vocab,
input_size=input_size,
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
d_ffn=d_ffn,
dropout=dropout,
activation=activation,
positional_encoding=positional_encoding,
normalize_before=normalize_before,
kernel_size=kernel_size,
bias=bias,
encoder_module=encoder_module,
conformer_activation=conformer_activation,
attention_type=attention_type,
max_length=max_length,
causal=causal,
)
if ctc_weight < 1 and asr_weight > 0:
self.asr_decoder = TransformerDecoder(
num_layers=num_decoder_layers,
nhead=nhead,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
causal=True,
attention_type="regularMHA", # always use regular attention in decoder
)
self.custom_asr_tgt_module = ModuleList(
NormalizedEmbedding(d_model, asr_tgt_vocab)
)
if mt_weight > 0:
self.custom_mt_src_module = ModuleList(
NormalizedEmbedding(d_model, mt_src_vocab)
)
if encoder_module == "transformer":
self.mt_encoder = TransformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
causal=self.causal,
attention_type=self.attention_type,
)
elif encoder_module == "conformer":
self.mt_encoder = ConformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=conformer_activation,
kernel_size=kernel_size,
bias=bias,
causal=self.causal,
attention_type=self.attention_type,
)
assert (
normalize_before
), "normalize_before must be True for Conformer"
assert (
conformer_activation is not None
), "conformer_activation must not be None"
# reset parameters using xavier_normal_
self._init_params()
def forward_asr(self, encoder_out, src, tgt, wav_len, pad_idx=0):
"""This method implements a decoding step for asr task
Arguments
----------
encoder_out : tensor
The representation of the encoder (required).
tgt (transcription): tensor
The sequence to the decoder (required).
pad_idx : int
The index for <pad> token (default=0).
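Example
-------
A hedged usage sketch (not an upstream doctest); the vocabulary sizes,
sequence lengths, and task weights are illustrative assumptions. The
auxiliary ASR decoder only exists when ctc_weight < 1 and asr_weight > 0.
>>> net = TransformerST(
...     720, 512, 512, 8, 1, 1, 1024,
...     ctc_weight=0.3, asr_weight=0.3, asr_tgt_vocab=600,
... )
>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> transcription = torch.randint(0, 600, [8, 50])
>>> wav_len = torch.ones(8)
>>> enc_out, _ = net.forward(src, tgt)
>>> asr_dec_out = net.forward_asr(enc_out, src, transcription, wav_len)
>>> asr_dec_out.shape
torch.Size([8, 50, 512])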
"""
# reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
if src.dim() == 4:
bz, t, ch1, ch2 = src.shape
src = src.reshape(bz, t, ch1 * ch2)
(
src_key_padding_mask,
tgt_key_padding_mask,
src_mask,
tgt_mask,
) = self.make_masks(src, tgt, wav_len, pad_idx=pad_idx)
transcription = self.custom_asr_tgt_module(tgt)
if self.attention_type == "RelPosMHAXL":
transcription = transcription + self.positional_encoding_decoder(
transcription
)
elif self.positional_encoding_type == "fixed_abs_sine":
transcription = transcription + self.positional_encoding(
transcription
)
asr_decoder_out, _, _ = self.asr_decoder(
tgt=transcription,
memory=encoder_out,
memory_mask=src_mask,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
)
return asr_decoder_out
def forward_mt(self, src, tgt, pad_idx=0):
"""This method implements a forward step for mt task
Arguments
----------
src (transcription): tensor
The sequence to the encoder (required).
tgt (translation): tensor
The sequence to the decoder (required).
pad_idx : int
The index for <pad> token (default=0).
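Example
-------
A hedged usage sketch (not an upstream doctest); the vocabulary sizes and
sequence lengths are illustrative, and the MT branch only exists when
mt_weight > 0.
>>> net = TransformerST(
...     720, 512, 512, 8, 1, 1, 1024,
...     mt_weight=0.2, mt_src_vocab=620,
... )
>>> transcription_tokens = torch.randint(0, 620, [8, 30])
>>> translation_tokens = torch.randint(0, 720, [8, 32])
>>> enc_out, dec_out = net.forward_mt(transcription_tokens, translation_tokens)
>>> enc_out.shape
torch.Size([8, 30, 512])
>>> dec_out.shape
torch.Size([8, 32, 512])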
"""
(
src_key_padding_mask,
tgt_key_padding_mask,
src_mask,
tgt_mask,
) = self.make_masks_for_mt(src, tgt, pad_idx=pad_idx)
src = self.custom_mt_src_module(src)
if self.attention_type == "RelPosMHAXL":
pos_embs_encoder = self.positional_encoding(src)
elif self.positional_encoding_type == "fixed_abs_sine":
src = src + self.positional_encoding(src)
pos_embs_encoder = None
encoder_out, _ = self.mt_encoder(
src=src,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs_encoder,
)
tgt = self.custom_tgt_module(tgt)
if self.attention_type == "RelPosMHAXL":
# use standard sinusoidal pos encoding in decoder
tgt = tgt + self.positional_encoding_decoder(tgt)
src = src + self.positional_encoding_decoder(src)
elif self.positional_encoding_type == "fixed_abs_sine":
tgt = tgt + self.positional_encoding(tgt)
decoder_out, _, _ = self.decoder(
tgt=tgt,
memory=encoder_out,
memory_mask=src_mask,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
)
return encoder_out, decoder_out
def forward_mt_decoder_only(self, src, tgt, pad_idx=0):
"""This method implements a forward step for mt task using a wav2vec encoder
(same than above, but without the encoder stack)
Arguments
----------
src (transcription): tensor
output features from the w2v2 encoder
tgt (translation): tensor
The sequence to the decoder (required).
pad_idx : int
The index for <pad> token (default=0).
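Example
-------
A hedged usage sketch (not an upstream doctest); w2v_out simply stands in
for wav2vec 2.0 features with dimension d_model.
>>> net = TransformerST(720, 512, 512, 8, 1, 1, 1024)
>>> w2v_out = torch.rand([8, 60, 512])
>>> translation_tokens = torch.randint(0, 720, [8, 25])
>>> dec_out = net.forward_mt_decoder_only(w2v_out, translation_tokens)
>>> dec_out.shape
torch.Size([8, 25, 512])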
"""
(
src_key_padding_mask,
tgt_key_padding_mask,
src_mask,
tgt_mask,
) = self.make_masks_for_mt(src, tgt, pad_idx=pad_idx)
tgt = self.custom_tgt_module(tgt)
if self.attention_type == "RelPosMHAXL":
# use standard sinusoidal pos encoding in decoder
tgt = tgt + self.positional_encoding_decoder(tgt)
elif self.positional_encoding_type == "fixed_abs_sine":
tgt = tgt + self.positional_encoding(tgt)
decoder_out, _, multihead = self.decoder(
tgt=tgt,
memory=src,
memory_mask=src_mask,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
)
return decoder_out
def decode_asr(self, tgt, encoder_out):
"""This method implements a decoding step for the transformer model.
Arguments
---------
tgt : torch.Tensor
The sequence to the decoder.
encoder_out : torch.Tensor
Hidden output of the encoder.
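Example
-------
A hedged usage sketch (not an upstream doctest); assumes the auxiliary ASR
decoder was enabled via ctc_weight < 1 and asr_weight > 0.
>>> net = TransformerST(
...     720, 512, 512, 8, 1, 1, 1024,
...     ctc_weight=0.3, asr_weight=0.3, asr_tgt_vocab=600,
... )
>>> enc_out = torch.rand([8, 120, 512])
>>> tgt_tokens = torch.randint(0, 720, [8, 10])
>>> prediction, attn = net.decode_asr(tgt_tokens, enc_out)
>>> prediction.shape
torch.Size([8, 10, 512])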
"""
tgt_mask = get_lookahead_mask(tgt)
tgt = self.custom_tgt_module(tgt)
if self.attention_type == "RelPosMHAXL":
# we use fixed positional encodings in the decoder
tgt = tgt + self.positional_encoding_decoder(tgt)
encoder_out = encoder_out + self.positional_encoding_decoder(
encoder_out
)
elif self.positional_encoding_type == "fixed_abs_sine":
tgt = tgt + self.positional_encoding(tgt) # add the encodings here
prediction, _, multihead_attns = self.asr_decoder(
tgt, encoder_out, tgt_mask=tgt_mask,
)
return prediction, multihead_attns[-1]
def make_masks_for_mt(self, src, tgt, pad_idx=0):
"""This method generates the masks for training the transformer model.
Arguments
---------
src : tensor
The sequence to the encoder (required).
tgt : tensor
The sequence to the decoder (required).
pad_idx : int
The index for <pad> token (default=0).
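Example
-------
A hedged sketch (not an upstream doctest) illustrating the returned mask
shapes; token values are arbitrary non-pad indices.
>>> net = TransformerST(720, 512, 512, 8, 1, 1, 1024)
>>> src_tokens = torch.randint(1, 600, [2, 7])
>>> tgt_tokens = torch.randint(1, 720, [2, 5])
>>> masks = net.make_masks_for_mt(src_tokens, tgt_tokens)
>>> masks[0].shape  # src key padding mask (only built in training mode)
torch.Size([2, 7])
>>> masks[3].shape  # look-ahead mask over the target sequence
torch.Size([5, 5])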
"""
src_key_padding_mask = None
if self.training:
src_key_padding_mask = get_key_padding_mask(src, pad_idx=pad_idx)
tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
src_mask = None
tgt_mask = get_lookahead_mask(tgt)
return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask