"""Transformer implementation in the SpeechBrain style.
Authors
* Jianyuan Zhong 2020
* Samuele Cornell 2021
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import RelPosEncXL
from speechbrain.nnet.CNN import Conv1d
from speechbrain.utils.checkpoints import map_old_state_dict_weights
from .Branchformer import BranchformerEncoder
from .Conformer import ConformerEncoder
class TransformerInterface(nn.Module):
"""This is an interface for transformer model.
Users can modify the attributes and define the forward function as
needed according to their own tasks.
The architecture is based on the paper "Attention Is All You Need":
https://arxiv.org/pdf/1706.03762.pdf
Arguments
---------
d_model: int
The number of expected features in the encoder/decoder inputs (default=512).
nhead: int
The number of heads in the multi-head attention models (default=8).
num_encoder_layers: int, optional
The number of encoder layers in the encoder.
num_decoder_layers: int, optional
The number of decoder layers in the decoder.
d_ffn: int, optional
The dimension of the feedforward network model hidden layer.
dropout: float, optional
The dropout value.
activation: torch.nn.Module, optional
The activation function for Feed-Forward Network layer,
e.g., relu or gelu or swish.
custom_src_module: torch.nn.Module, optional
Module that processes the src features to expected feature dim.
custom_tgt_module: torch.nn.Module, optional
Module that processes the tgt features to expected feature dim.
positional_encoding: str, optional
Type of positional encoding used, e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
normalize_before: bool, optional
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
Defaults to True as this was shown to lead to better performance and training stability.
kernel_size: int, optional
Kernel size in convolutional layers when Conformer is used.
bias: bool, optional
Whether to use bias in Conformer convolutional layers.
encoder_module: str, optional
Choose between Branchformer, Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.
conformer_activation: torch.nn.Module, optional
Activation module used after Conformer convolutional layers, e.g. Swish or ReLU. It must be a torch Module.
branchformer_activation: torch.nn.Module, optional
Activation module used within the Branchformer Encoder, e.g. Swish or ReLU. It must be a torch Module.
attention_type: str, optional
Type of attention layer used in all Transformer or Conformer layers,
e.g. regularMHA or RelPosMHAXL.
max_length: int, optional
Max length for the target and source sequence in input.
Used for positional encodings.
causal: bool, optional
Whether the encoder should be causal or not (the decoder is always causal).
If causal the Conformer convolutional layer is causal.
encoder_kdim: int, optional
Dimension of the key for the encoder.
encoder_vdim: int, optional
Dimension of the value for the encoder.
decoder_kdim: int, optional
Dimension of the key for the decoder.
decoder_vdim: int, optional
Dimension of the value for the decoder.
csgu_linear_units: int, optional
Number of neurons in the hidden linear units of the CSGU Module.
-> Branchformer
gate_activation: torch.nn.Module, optional
Activation function used at the gate of the CSGU module.
-> Branchformer
use_linear_after_conv: bool, optional
If True, will apply a linear transformation of size input_size//2.
-> Branchformer
output_hidden_states: bool, optional
Whether the model should output the hidden states as a list of tensors.
layerdrop_prob: float
The probability to drop an entire layer.
"""
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ffn=2048,
dropout=0.1,
activation=nn.ReLU,
custom_src_module=None,
custom_tgt_module=None,
positional_encoding="fixed_abs_sine",
normalize_before=True,
kernel_size: Optional[int] = 31,
bias: Optional[bool] = True,
encoder_module: Optional[str] = "transformer",
conformer_activation: Optional[nn.Module] = Swish,
branchformer_activation: Optional[nn.Module] = nn.GELU,
attention_type: Optional[str] = "regularMHA",
max_length: Optional[int] = 2500,
causal: Optional[bool] = False,
encoder_kdim: Optional[int] = None,
encoder_vdim: Optional[int] = None,
decoder_kdim: Optional[int] = None,
decoder_vdim: Optional[int] = None,
csgu_linear_units: Optional[int] = 3072,
gate_activation: Optional[nn.Module] = nn.Identity,
use_linear_after_conv: Optional[bool] = False,
output_hidden_states=False,
layerdrop_prob=0.0,
):
super().__init__()
self.causal = causal
self.attention_type = attention_type
self.positional_encoding_type = positional_encoding
self.encoder_kdim = encoder_kdim
self.encoder_vdim = encoder_vdim
self.decoder_kdim = decoder_kdim
self.decoder_vdim = decoder_vdim
self.output_hidden_states = output_hidden_states
self.layerdrop_prob = layerdrop_prob
assert attention_type in ["regularMHA", "RelPosMHAXL", "hypermixing"]
assert positional_encoding in ["fixed_abs_sine", None]
assert (
num_encoder_layers + num_decoder_layers > 0
), "number of encoder layers and number of decoder layers cannot both be 0!"
if positional_encoding == "fixed_abs_sine":
self.positional_encoding = PositionalEncoding(d_model, max_length)
elif positional_encoding is None:
pass
# no positional encodings
# overrides any other pos_embedding
if attention_type == "RelPosMHAXL":
self.positional_encoding = RelPosEncXL(d_model)
self.positional_encoding_decoder = PositionalEncoding(
d_model, max_length
)
# initialize the encoder
if num_encoder_layers > 0:
if custom_src_module is not None:
self.custom_src_module = custom_src_module(d_model)
if encoder_module == "transformer":
self.encoder = TransformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
causal=self.causal,
attention_type=self.attention_type,
kdim=self.encoder_kdim,
vdim=self.encoder_vdim,
output_hidden_states=self.output_hidden_states,
layerdrop_prob=self.layerdrop_prob,
)
elif encoder_module == "conformer":
self.encoder = ConformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=conformer_activation,
kernel_size=kernel_size,
bias=bias,
causal=self.causal,
attention_type=self.attention_type,
output_hidden_states=self.output_hidden_states,
layerdrop_prob=self.layerdrop_prob,
)
assert (
normalize_before
), "normalize_before must be True for Conformer"
assert (
conformer_activation is not None
), "conformer_activation must not be None"
elif encoder_module == "branchformer":
self.encoder = BranchformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_model=d_model,
dropout=dropout,
activation=branchformer_activation,
kernel_size=kernel_size,
attention_type=self.attention_type,
csgu_linear_units=csgu_linear_units,
gate_activation=gate_activation,
use_linear_after_conv=use_linear_after_conv,
output_hidden_states=self.output_hidden_states,
layerdrop_prob=self.layerdrop_prob,
)
# initialize the decoder
if num_decoder_layers > 0:
if custom_tgt_module is not None:
self.custom_tgt_module = custom_tgt_module(d_model)
self.decoder = TransformerDecoder(
num_layers=num_decoder_layers,
nhead=nhead,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
causal=True,
attention_type="regularMHA", # always use regular attention in decoder
kdim=self.decoder_kdim,
vdim=self.decoder_vdim,
)
def forward(self, **kwargs):
"""Users should modify this function according to their own tasks."""
raise NotImplementedError
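# --- Illustrative sketch (not part of the library source) ----------------
# Concrete SpeechBrain models subclass TransformerInterface and define their
# own forward. The toy subclass below is an assumption made for illustration
# only: it projects input features to d_model with a custom_src_module, adds
# the fixed sinusoidal positional encoding, and runs the encoder.
def _example_interface_subclass():
    """Builds a toy encoder-only subclass and runs one forward pass."""

    class ToyTransformer(TransformerInterface):
        def __init__(self, input_size=80, d_model=256, nhead=4):
            super().__init__(
                d_model=d_model,
                nhead=nhead,
                num_encoder_layers=2,
                num_decoder_layers=0,
                d_ffn=512,
                # custom_src_module is called with d_model inside __init__
                custom_src_module=lambda dm: nn.Linear(input_size, dm),
            )

        def forward(self, src):
            src = self.custom_src_module(src)
            src = src + self.positional_encoding(src)
            output, _ = self.encoder(src)
            return output

    x = torch.rand(4, 50, 80)  # (batch, time, features)
    return ToyTransformer()(x)  # -> (4, 50, 256)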
class PositionalEncoding(nn.Module):
"""This class implements the absolute sinusoidal positional encoding function.
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
Arguments
---------
input_size: int
Embedding dimension.
max_len : int, optional
Max length of the input sequences (default 2500).
Example
-------
>>> a = torch.rand((8, 120, 512))
>>> enc = PositionalEncoding(input_size=a.shape[-1])
>>> b = enc(a)
>>> b.shape
torch.Size([1, 120, 512])
"""
def __init__(self, input_size, max_len=2500):
super().__init__()
if input_size % 2 != 0:
raise ValueError(
f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
)
self.max_len = max_len
pe = torch.zeros(self.max_len, input_size, requires_grad=False)
positions = torch.arange(0, self.max_len).unsqueeze(1).float()
denominator = torch.exp(
torch.arange(0, input_size, 2).float()
* -(math.log(10000.0) / input_size)
)
pe[:, 0::2] = torch.sin(positions * denominator)
pe[:, 1::2] = torch.cos(positions * denominator)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
"""
Arguments
---------
x : torch.Tensor
Input feature shape (batch, time, fea)
Returns
-------
The positional encoding.
"""
return self.pe[:, : x.size(1)].clone().detach()
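# --- Illustrative sketch (not part of the library source) ----------------
# The forward above returns only the (1, time, d_model) encoding table, so
# callers add it to their features themselves, relying on broadcasting over
# the batch dimension:
def _example_positional_encoding():
    """Adds fixed sinusoidal encodings to a random feature batch."""
    feats = torch.rand(8, 120, 512)               # (batch, time, d_model)
    pos_enc = PositionalEncoding(input_size=512)  # table up to max_len=2500
    return feats + pos_enc(feats)                 # broadcasts over the batch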
class TransformerEncoderLayer(nn.Module):
"""This is an implementation of self-attention encoder layer.
Arguments
---------
d_ffn: int, optional
The dimension of the feedforward network model hidden layer.
nhead: int
The number of heads in the multi-head attention models (default=8).
d_model: int
The number of expected features in the encoder/decoder inputs (default=512).
kdim: int, optional
Dimension of the key.
vdim: int, optional
Dimension of the value.
dropout: float, optional
The dropout value.
activation: torch.nn.Module, optional
The activation function for Feed-Forward Network layer,
e.g., relu or gelu or swish.
normalize_before: bool, optional
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
Defaults to True as this was shown to lead to better performance and training stability.
attention_type: str, optional
Type of attention layer used in all Transformer or Conformer layers,
e.g. regularMHA or RelPosMHAXL.
ffn_type: str
Type of FFN: "regularFFN" or "1dcnn".
ffn_cnn_kernel_size_list: list of int
Kernel sizes of the two 1d convolutions used when ffn_type is "1dcnn".
causal: bool, optional
Whether the encoder should be causal or not (the decoder is always causal).
If causal the Conformer convolutional layer is causal.
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoderLayer(512, 8, d_model=512)
>>> output = net(x)
>>> output[0].shape
torch.Size([8, 60, 512])
"""
def __init__(
self,
d_ffn,
nhead,
d_model,
kdim=None,
vdim=None,
dropout=0.0,
activation=nn.ReLU,
normalize_before=False,
attention_type="regularMHA",
ffn_type="regularFFN",
ffn_cnn_kernel_size_list=[3, 3],
causal=False,
):
super().__init__()
if attention_type == "regularMHA":
self.self_att = sb.nnet.attention.MultiheadAttention(
nhead=nhead,
d_model=d_model,
dropout=dropout,
kdim=kdim,
vdim=vdim,
)
elif attention_type == "RelPosMHAXL":
self.self_att = sb.nnet.attention.RelPosMHAXL(
d_model, nhead, dropout, mask_pos_future=causal
)
elif attention_type == "hypermixing":
self.self_att = sb.nnet.hypermixing.HyperMixing(
input_output_dim=d_model,
hypernet_size=d_ffn,
tied=False,
num_heads=nhead,
fix_tm_hidden_size=False,
)
if ffn_type == "regularFFN":
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
)
elif ffn_type == "1dcnn":
self.pos_ffn = nn.Sequential(
Conv1d(
in_channels=d_model,
out_channels=d_ffn,
kernel_size=ffn_cnn_kernel_size_list[0],
padding="causal" if causal else "same",
),
nn.ReLU(),
Conv1d(
in_channels=d_ffn,
out_channels=d_model,
kernel_size=ffn_cnn_kernel_size_list[1],
padding="causal" if causal else "same",
),
)
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.normalize_before = normalize_before
self.pos_ffn_type = ffn_type
def forward(
self,
src,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
):
"""
Arguments
---------
src : torch.Tensor
The sequence to the encoder layer.
src_mask : torch.Tensor
The mask for the src query for each example in the batch.
src_key_padding_mask : torch.Tensor, optional
The mask for the src keys for each example in the batch.
pos_embs: torch.Tensor, optional
The positional embeddings tensor.
Returns
-------
output : torch.Tensor
The output of the transformer encoder layer.
"""
if self.normalize_before:
src1 = self.norm1(src)
else:
src1 = src
output, self_attn = self.self_att(
src1,
src1,
src1,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs,
)
# add & norm
src = src + self.dropout1(output)
if not self.normalize_before:
src = self.norm1(src)
if self.normalize_before:
src1 = self.norm2(src)
else:
src1 = src
output = self.pos_ffn(src1)
# add & norm
output = src + self.dropout2(output)
if not self.normalize_before:
output = self.norm2(output)
return output, self_attn
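# --- Illustrative sketch (not part of the library source) ----------------
# Passing an additive look-ahead mask as src_mask turns a single pre-norm
# encoder layer into a causal self-attention block. get_lookahead_mask is
# defined later in this module; this helper is for illustration only.
def _example_causal_encoder_layer():
    """Runs one pre-norm encoder layer with future frames masked out."""
    x = torch.rand(2, 16, 64)  # (batch, time, d_model)
    layer = TransformerEncoderLayer(
        d_ffn=128, nhead=4, d_model=64, normalize_before=True
    )
    causal_mask = get_lookahead_mask(x)  # (16, 16) additive -inf mask
    output, attn = layer(x, src_mask=causal_mask)
    return output  # -> (2, 16, 64)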
class TransformerEncoder(nn.Module):
"""This class implements the transformer encoder.
Arguments
---------
num_layers : int
Number of transformer layers to include.
nhead : int
Number of attention heads.
d_ffn : int
Hidden size of self-attention Feed Forward layer.
input_shape : tuple
Expected shape of the input.
d_model : int
The dimension of the input embedding.
kdim : int
Dimension for key (Optional).
vdim : int
Dimension for value (Optional).
dropout : float
Dropout for the encoder (Optional).
activation: torch.nn.Module, optional
The activation function for Feed-Forward Network layer,
e.g., relu or gelu or swish.
normalize_before: bool, optional
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
Defaults to True as this was shown to lead to better performance and training stability.
causal: bool, optional
Whether the encoder should be causal or not (the decoder is always causal).
If causal the Conformer convolutional layer is causal.
layerdrop_prob: float
The probability to drop an entire layer.
attention_type: str, optional
Type of attention layer used in all Transformer or Conformer layers,
e.g. regularMHA or RelPosMHAXL.
ffn_type: str
Type of FFN: "regularFFN" or "1dcnn".
ffn_cnn_kernel_size_list: list of int
Kernel sizes of the two 1d convolutions used when ffn_type is "1dcnn".
output_hidden_states: bool, optional
Whether the model should output the hidden states as a list of tensors.
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder(1, 8, 512, d_model=512)
>>> output, _ = net(x)
>>> output.shape
torch.Size([8, 60, 512])
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = TransformerEncoder(1, 8, 512, d_model=512, output_hidden_states=True)
>>> output, attn_list, hidden_list = net(x)
>>> hidden_list[0].shape
torch.Size([8, 60, 512])
>>> len(hidden_list)
2
"""
def __init__(
self,
num_layers,
nhead,
d_ffn,
input_shape=None,
d_model=None,
kdim=None,
vdim=None,
dropout=0.0,
activation=nn.ReLU,
normalize_before=False,
causal=False,
layerdrop_prob=0.0,
attention_type="regularMHA",
ffn_type="regularFFN",
ffn_cnn_kernel_size_list=[3, 3],
output_hidden_states=False,
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
TransformerEncoderLayer(
d_ffn=d_ffn,
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
causal=causal,
attention_type=attention_type,
ffn_type=ffn_type,
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
)
for i in range(num_layers)
]
)
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.layerdrop_prob = layerdrop_prob
self.output_hidden_states = output_hidden_states
def forward(
self,
src,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
dynchunktrain_config=None,
):
"""
Arguments
---------
src : torch.Tensor
The sequence to the encoder layer (required).
src_mask : torch.Tensor
The mask for the src sequence (optional).
src_key_padding_mask : torch.Tensor
The mask for the src keys per batch (optional).
pos_embs : torch.Tensor
The positional embedding tensor.
dynchunktrain_config : config
Not supported for this encoder.
Returns
-------
output : torch.Tensor
The output of the transformer.
attention_lst : list
The attention values.
hidden_state_lst : list, optional
The output of the hidden layers of the encoder.
Only returned if output_hidden_states is set to True.
"""
assert (
dynchunktrain_config is None
), "Dynamic Chunk Training unsupported for this encoder"
output = src
if self.layerdrop_prob > 0.0:
keep_probs = torch.rand(len(self.layers))
attention_lst = []
if self.output_hidden_states:
hidden_state_lst = [output]
for i, enc_layer in enumerate(self.layers):
if (
not self.training
or self.layerdrop_prob == 0.0
or keep_probs[i] > self.layerdrop_prob
):
output, attention = enc_layer(
output,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs,
)
attention_lst.append(attention)
if self.output_hidden_states:
hidden_state_lst.append(output)
output = self.norm(output)
if self.output_hidden_states:
return output, attention_lst, hidden_state_lst
return output, attention_lst
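# --- Illustrative sketch (not part of the library source) ----------------
# Typical use of the encoder on a padded batch: build a key padding mask
# from the per-utterance lengths (get_mask_from_lengths is defined at the
# end of this module) so attention ignores the padded frames.
def _example_encoder_with_padding_mask():
    """Encodes a padded batch while masking the padded positions."""
    lengths = torch.tensor([40, 60, 25])
    x = torch.rand(3, 60, 128)  # frames beyond each length are padding
    encoder = TransformerEncoder(
        num_layers=2, nhead=4, d_ffn=256, d_model=128, normalize_before=True
    )
    pad_mask = get_mask_from_lengths(lengths)  # True at padded positions
    output, attn_lst = encoder(x, src_key_padding_mask=pad_mask)
    return output  # -> (3, 60, 128)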
class TransformerDecoderLayer(nn.Module):
"""This class implements the self-attention decoder layer.
Arguments
---------
d_ffn : int
Hidden size of self-attention Feed Forward layer.
nhead : int
Number of attention heads.
d_model : int
Dimension of the model.
kdim : int
Dimension for key (optional).
vdim : int
Dimension for value (optional).
dropout : float
Dropout for the decoder (optional).
activation : Callable
Function to use between layers, default nn.ReLU
normalize_before : bool
Whether to normalize before layers.
attention_type : str
Type of attention to use, "regularMHA" or "RelPosMHAXL"
causal : bool
Whether to mask future positions.
Example
-------
>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoderLayer(1024, 8, d_model=512)
>>> output, self_attn, multihead_attn = net(src, tgt)
>>> output.shape
torch.Size([8, 60, 512])
"""
def __init__(
self,
d_ffn,
nhead,
d_model,
kdim=None,
vdim=None,
dropout=0.0,
activation=nn.ReLU,
normalize_before=False,
attention_type="regularMHA",
causal=None,
):
super().__init__()
self.nhead = nhead
if attention_type == "regularMHA":
self.self_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
)
self.multihead_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
)
elif attention_type == "RelPosMHAXL":
self.self_attn = sb.nnet.attention.RelPosMHAXL(
d_model, nhead, dropout, mask_pos_future=causal
)
self.multihead_attn = sb.nnet.attention.RelPosMHAXL(
d_model, nhead, dropout, mask_pos_future=causal
)
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
)
# normalization layers
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.dropout3 = torch.nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
pos_embs_tgt=None,
pos_embs_src=None,
):
"""
Arguments
---------
tgt: torch.Tensor
The sequence to the decoder layer (required).
memory: torch.Tensor
The sequence from the last layer of the encoder (required).
tgt_mask: torch.Tensor
The mask for the tgt sequence (optional).
memory_mask: torch.Tensor
The mask for the memory sequence (optional).
tgt_key_padding_mask: torch.Tensor
The mask for the tgt keys per batch (optional).
memory_key_padding_mask: torch.Tensor
The mask for the memory keys per batch (optional).
pos_embs_tgt: torch.Tensor
The positional embeddings for the target (optional).
pos_embs_src: torch.Tensor
The positional embeddings for the source (optional).
Returns
-------
tgt : torch.Tensor
The output of the decoder layer.
self_attn : torch.Tensor
The self-attention weights.
multihead_attention : torch.Tensor
The attention weights over the encoder states.
"""
if self.normalize_before:
tgt1 = self.norm1(tgt)
else:
tgt1 = tgt
# self-attention over the target sequence
tgt2, self_attn = self.self_attn(
query=tgt1,
key=tgt1,
value=tgt1,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
pos_embs=pos_embs_tgt,
)
# add & norm
tgt = tgt + self.dropout1(tgt2)
if not self.normalize_before:
tgt = self.norm1(tgt)
if self.normalize_before:
tgt1 = self.norm2(tgt)
else:
tgt1 = tgt
# multi-head attention over the target sequence and encoder states
tgt2, multihead_attention = self.multihead_attn(
query=tgt1,
key=memory,
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
pos_embs=pos_embs_src,
)
# add & norm
tgt = tgt + self.dropout2(tgt2)
if not self.normalize_before:
tgt = self.norm2(tgt)
if self.normalize_before:
tgt1 = self.norm3(tgt)
else:
tgt1 = tgt
tgt2 = self.pos_ffn(tgt1)
# add & norm
tgt = tgt + self.dropout3(tgt2)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt, self_attn, multihead_attention
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Load the model from a state_dict and map the old keys to the new keys."""
mapping = {"mutihead_attention": "multihead_attention"}
state_dict = map_old_state_dict_weights(state_dict, mapping)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class TransformerDecoder(nn.Module):
"""This class implements the Transformer decoder.
Arguments
---------
num_layers : int
Number of transformer layers for the decoder.
nhead : int
Number of attention heads.
d_ffn : int
Hidden size of self-attention Feed Forward layer.
d_model : int
Dimension of the model.
kdim : int, optional
Dimension for key (Optional).
vdim : int, optional
Dimension for value (Optional).
dropout : float, optional
Dropout for the decoder (Optional).
activation : Callable
The function to apply between layers, default nn.ReLU
normalize_before : bool
Whether to normalize before layers.
causal : bool
Whether to mask future positions in decoding.
attention_type : str
Type of attention to use, "regularMHA" or "RelPosMHAXL"
Example
-------
>>> src = torch.rand((8, 60, 512))
>>> tgt = torch.rand((8, 60, 512))
>>> net = TransformerDecoder(1, 8, 1024, d_model=512)
>>> output, _, _ = net(src, tgt)
>>> output.shape
torch.Size([8, 60, 512])
"""
def __init__(
self,
num_layers,
nhead,
d_ffn,
d_model,
kdim=None,
vdim=None,
dropout=0.0,
activation=nn.ReLU,
normalize_before=False,
causal=False,
attention_type="regularMHA",
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
TransformerDecoderLayer(
d_ffn=d_ffn,
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
causal=causal,
attention_type=attention_type,
)
for _ in range(num_layers)
]
)
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
pos_embs_tgt=None,
pos_embs_src=None,
):
"""
Arguments
---------
tgt : torch.Tensor
The sequence to the decoder layer (required).
memory : torch.Tensor
The sequence from the last layer of the encoder (required).
tgt_mask : torch.Tensor
The mask for the tgt sequence (optional).
memory_mask : torch.Tensor
The mask for the memory sequence (optional).
tgt_key_padding_mask : torch.Tensor
The mask for the tgt keys per batch (optional).
memory_key_padding_mask : torch.Tensor
The mask for the memory keys per batch (optional).
pos_embs_tgt : torch.Tensor
The positional embeddings for the target (optional).
pos_embs_src : torch.Tensor
The positional embeddings for the source (optional).
Returns
-------
output : torch.Tensor
The output of the decoder.
self_attns : list
The self-attention weights of each layer.
multihead_attns : list
The attention weights over the encoder states for each layer.
"""
output = tgt
self_attns, multihead_attns = [], []
for dec_layer in self.layers:
output, self_attn, multihead_attn = dec_layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos_embs_tgt=pos_embs_tgt,
pos_embs_src=pos_embs_src,
)
self_attns.append(self_attn)
multihead_attns.append(multihead_attn)
output = self.norm(output)
return output, self_attns, multihead_attns
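# --- Illustrative sketch (not part of the library source) ----------------
# A teacher-forcing style decoder pass: the target attends to itself under a
# look-ahead mask and to the encoder output (memory) via cross-attention.
# Shapes and layer sizes below are arbitrary assumptions for illustration.
def _example_decoder_forward():
    """Runs the decoder over target embeddings attending to encoder states."""
    memory = torch.rand(4, 50, 256)  # encoder output
    tgt = torch.rand(4, 20, 256)     # embedded target tokens
    decoder = TransformerDecoder(
        num_layers=2, nhead=4, d_ffn=512, d_model=256, normalize_before=True
    )
    tgt_mask = get_lookahead_mask(tgt)  # prevents peeking at future tokens
    output, self_attns, cross_attns = decoder(tgt, memory, tgt_mask=tgt_mask)
    return output  # -> (4, 20, 256)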
class NormalizedEmbedding(nn.Module):
"""This class implements the normalized embedding layer for the transformer.
Since the dot product of the self-attention is always normalized by sqrt(d_model)
and the final linear projection for prediction shares weight with the embedding layer,
we multiply the output of the embedding by sqrt(d_model).
Arguments
---------
d_model: int
The number of expected features in the encoder/decoder inputs (default=512).
vocab: int
The vocab size.
Example
-------
>>> emb = NormalizedEmbedding(512, 1000)
>>> trg = torch.randint(0, 999, (8, 50))
>>> emb_fea = emb(trg)
"""
def __init__(self, d_model, vocab):
super().__init__()
self.emb = sb.nnet.embedding.Embedding(
num_embeddings=vocab, embedding_dim=d_model, blank_id=0
)
self.d_model = d_model
def forward(self, x):
"""Processes the input tensor x and returns an output tensor."""
return self.emb(x) * math.sqrt(self.d_model)
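# --- Illustrative sketch (not part of the library source) ----------------
# A quick numerical check of the scaling described in the class docstring:
# the output equals the raw embedding multiplied by sqrt(d_model), which
# keeps its magnitude comparable to the sinusoidal positional encodings.
def _example_normalized_embedding():
    """Verifies that NormalizedEmbedding scales embeddings by sqrt(d_model)."""
    emb = NormalizedEmbedding(d_model=512, vocab=1000)
    tokens = torch.randint(0, 1000, (8, 50))
    scaled = emb(tokens)
    raw = emb.emb(tokens)  # unscaled embedding lookup
    assert torch.allclose(scaled, raw * math.sqrt(512))
    return scaled  # -> (8, 50, 512)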
def get_key_padding_mask(padded_input, pad_idx):
"""Creates a binary mask to prevent attention to padded locations.
We suggest using ``get_mask_from_lengths`` instead of this function.
Arguments
---------
padded_input: torch.Tensor
Padded input.
pad_idx: int
idx for padding element.
Returns
-------
key_padded_mask: torch.Tensor
Binary mask to prevent attention to padding.
Example
-------
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_key_padding_mask(a, pad_idx=0)
tensor([[False, False, True],
[False, False, True],
[False, False, True]])
"""
if len(padded_input.shape) == 4:
bz, time, ch1, ch2 = padded_input.shape
padded_input = padded_input.reshape(bz, time, ch1 * ch2)
key_padded_mask = padded_input.eq(pad_idx).to(padded_input.device)
# if the input is more than 2d, mask the locations where they are silence
# across all channels
if len(padded_input.shape) > 2:
key_padded_mask = key_padded_mask.float().prod(dim=-1).bool()
return key_padded_mask.detach()
return key_padded_mask.detach()
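# --- Illustrative sketch (not part of the library source) ----------------
# A common place for this mask is the decoder's tgt_key_padding_mask, built
# from the padded token ids before they are embedded. Using pad_idx=0 here
# is an assumption about the tokenizer, for illustration only.
def _example_target_padding_mask():
    """Builds a key padding mask from padded token ids."""
    tokens = torch.LongTensor([[5, 7, 9, 0, 0], [3, 2, 4, 8, 0]])
    # True where the padding id sits, one row per sentence
    return get_key_padding_mask(tokens, pad_idx=0)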
def get_lookahead_mask(padded_input):
"""Creates a binary mask for each sequence which masks future frames.
Arguments
---------
padded_input: torch.Tensor
Padded input tensor.
Returns
-------
mask : torch.Tensor
Binary mask for masking future frames.
Example
-------
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_lookahead_mask(a)
tensor([[0., -inf, -inf],
[0., 0., -inf],
[0., 0., 0.]])
"""
seq_len = padded_input.shape[1]
mask = (
torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device))
== 1
).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask.detach().to(padded_input.device)
def get_mask_from_lengths(lengths, max_len=None):
"""Creates a binary mask from sequence lengths
Arguments
---------
lengths: torch.Tensor
A tensor of sequence lengths
max_len: int (Optional)
Maximum sequence length, defaults to None.
Returns
-------
mask: torch.Tensor
the mask where padded elements are set to True.
Then one can use tensor.masked_fill_(mask, 0) for the masking.
Example
-------
>>> lengths = torch.tensor([3, 2, 4])
>>> get_mask_from_lengths(lengths)
tensor([[False, False, False, True],
[False, False, True, True],
[False, False, False, False]])
"""
if max_len is None:
max_len = torch.max(lengths).item()
seq_range = torch.arange(
max_len, device=lengths.device, dtype=lengths.dtype
)
return ~(seq_range.unsqueeze(0) < lengths.unsqueeze(1))
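# --- Illustrative sketch (not part of the library source) ----------------
# As suggested in the docstring above, the returned mask can serve both as a
# key padding mask for the attention layers and to zero out padded frames in
# a (batch, time, fea) tensor (note the extra unsqueeze for the feature dim):
def _example_zero_out_padding():
    """Zeroes the padded frames of a feature tensor in place."""
    lengths = torch.tensor([3, 2, 4])
    feats = torch.rand(3, 4, 10)                 # (batch, time, fea)
    mask = get_mask_from_lengths(lengths)        # (batch, time), True = pad
    feats.masked_fill_(mask.unsqueeze(-1), 0.0)  # broadcast over features
    return feats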