"""Transformer implementation in the SpeechBrain sytle.
Authors
* Jianyuan Zhong 2020
"""
import math
import torch
import torch.nn as nn
import speechbrain as sb
from typing import Optional
from .conformer import ConformerEncoder
from speechbrain.nnet.activations import Swish
class TransformerInterface(nn.Module):
    """This is an interface for the transformer model.

    Users can modify the attributes and define the forward function as
    needed according to their own tasks.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    d_model : int
        The number of expected features in the encoder/decoder inputs (default=512).
    nhead : int
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers : int
        The number of sub-encoder-layers in the encoder (default=6).
    num_decoder_layers : int
        The number of sub-decoder-layers in the decoder (default=6).
    d_ffn : int
        The dimension of the feed-forward network model (default=2048).
    dropout : float
        The dropout value (default=0.1).
    activation : torch class
        The activation function of the encoder/decoder intermediate layers,
        e.g., relu or gelu (default=relu).
    custom_src_module : torch class
        Module that processes the src features to the expected feature dimension.
    custom_tgt_module : torch class
        Module that processes the tgt features to the expected feature dimension.
    positional_encoding : bool
        Whether to add fixed sinusoidal positional encodings (default=True).
    normalize_before : bool
        Whether normalization is applied before (pre-norm) or after (post-norm)
        each attention and feed-forward block (default=False).
    kernel_size : int
        Kernel size of the convolution module, used only when encoder_module
        is "conformer" (default=31).
    bias : bool
        Whether to use a bias in the Conformer convolution module (default=True).
    encoder_module : str
        Type of encoder to build, either "transformer" or "conformer"
        (default="transformer").
    conformer_activation : torch class
        Activation function of the Conformer convolution module,
        e.g., Swish (default=Swish).
    """
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ffn=2048,
dropout=0.1,
activation=nn.ReLU,
custom_src_module=None,
custom_tgt_module=None,
positional_encoding=True,
normalize_before=False,
kernel_size: Optional[int] = 31,
bias: Optional[bool] = True,
encoder_module: Optional[str] = "transformer",
conformer_activation: Optional[nn.Module] = Swish,
):
super().__init__()
assert (
num_encoder_layers + num_decoder_layers > 0
), "number of encoder layers and number of decoder layers cannot both be 0!"
if positional_encoding:
self.positional_encoding = PositionalEncoding(d_model)
# initialize the encoder
if num_encoder_layers > 0:
if custom_src_module is not None:
self.custom_src_module = custom_src_module(d_model)
if encoder_module == "transformer":
self.encoder = TransformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
)
elif encoder_module == "conformer":
self.encoder = ConformerEncoder(
nhead=nhead,
num_layers=num_encoder_layers,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=conformer_activation,
kernel_size=kernel_size,
bias=bias,
)
assert (
normalize_before
), "normalize_before must be True for Conformer"
assert (
conformer_activation is not None
), "conformer_activation must not be None"
# initialize the decoder
if num_decoder_layers > 0:
if custom_tgt_module is not None:
self.custom_tgt_module = custom_tgt_module(d_model)
self.decoder = TransformerDecoder(
num_layers=num_decoder_layers,
nhead=nhead,
d_ffn=d_ffn,
d_model=d_model,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
)
    def forward(self, **kwargs):
        """Users should modify this function according to their own tasks."""
        raise NotImplementedError
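
# A minimal sketch of how a task-specific model might subclass
# TransformerInterface. This is illustration only, not part of the SpeechBrain
# API: the class name `_ExampleTransformerModel` and its forward signature are
# hypothetical.
class _ExampleTransformerModel(TransformerInterface):
    """Toy subclass wiring the inherited encoder and decoder together."""

    def forward(self, src, tgt):
        # add fixed positional information to both streams
        src = src + self.positional_encoding(src)
        tgt = tgt + self.positional_encoding(tgt)
        enc_out, _ = self.encoder(src)
        # causal mask so each target step only attends to previous steps
        dec_out, _, _ = self.decoder(
            tgt, enc_out, tgt_mask=get_lookahead_mask(tgt)
        )
        return enc_out, dec_out
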
class PositionalEncoding(nn.Module):
    """This class implements the absolute sinusoidal positional encoding function.

    PE(pos, 2i)   = sin(pos / (10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos / (10000^(2i/dmodel)))

    Arguments
    ---------
    input_size : int
        Embedding dimension.
    max_len : int
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """
def __init__(self, input_size, max_len=2500):
super().__init__()
self.max_len = max_len
pe = torch.zeros(self.max_len, input_size, requires_grad=False)
positions = torch.arange(0, self.max_len).unsqueeze(1).float()
denominator = torch.exp(
torch.arange(0, input_size, 2).float()
* -(math.log(10000.0) / input_size)
)
pe[:, 0::2] = torch.sin(positions * denominator)
pe[:, 1::2] = torch.cos(positions * denominator)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
    def forward(self, x):
        """
        Arguments
        ---------
        x : tensor
            Input feature tensor of shape (batch, time, fea); only the time
            dimension is used to slice the precomputed encoding.
        """
        return self.pe[:, : x.size(1)].clone().detach()
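
# Illustrative helper (hypothetical, not part of this module): the encoding is
# returned with a batch dimension of 1 and sliced to the input's time length,
# so in practice it is *added* to the input via broadcasting.
def _demo_positional_encoding():
    x = torch.rand(8, 120, 512)
    enc = PositionalEncoding(input_size=512)
    return x + enc(x)  # (1, 120, 512) broadcasts over the batch dimension
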
class TransformerEncoderLayer(nn.Module):
    """This is an implementation of a self-attention encoder layer.

    Arguments
    ---------
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    nhead : int
        Number of attention heads.
    d_model : int
        The expected size of the input embedding.
    kdim : int
        Dimension of the key (optional).
    vdim : int
        Dimension of the value (optional).
    dropout : float
        Dropout rate for the encoder (optional).
    activation : torch class
        Activation function of the feed-forward layer, e.g., relu or gelu (optional).
    normalize_before : bool
        Whether to apply layer normalization before (pre-norm) or after
        (post-norm) each sub-block (optional).

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = TransformerEncoderLayer(512, 8, d_model=512)
    >>> output = net(x)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """
def __init__(
self,
d_ffn,
nhead,
d_model=None,
kdim=None,
vdim=None,
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
):
super().__init__()
self.self_att = sb.nnet.attention.MultiheadAttention(
nhead=nhead, d_model=d_model, dropout=dropout, kdim=kdim, vdim=vdim,
)
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
)
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.normalize_before = normalize_before
    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Arguments
        ---------
        src : tensor
            The sequence to the encoder layer (required).
        src_mask : tensor
            The mask for the src sequence (optional).
        src_key_padding_mask : tensor
            The mask for the src keys per batch (optional).
        """
if self.normalize_before:
src1 = self.norm1(src)
else:
src1 = src
output, self_attn = self.self_att(
src1,
src1,
src1,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
# add & norm
src = src + self.dropout1(output)
if not self.normalize_before:
src = self.norm1(src)
if self.normalize_before:
src1 = self.norm2(src)
else:
src1 = src
output = self.pos_ffn(src1)
# add & norm
output = src + self.dropout2(output)
if not self.normalize_before:
output = self.norm2(output)
return output, self_attn
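
# Hypothetical sketch (illustration only) contrasting the default post-norm
# layer with a pre-norm one (`normalize_before=True`); both return a tuple of
# (output, attention_weights) with the same shapes.
def _demo_encoder_layer_norm_styles():
    x = torch.rand(8, 60, 512)
    post_norm = TransformerEncoderLayer(2048, 8, d_model=512)
    pre_norm = TransformerEncoderLayer(
        2048, 8, d_model=512, normalize_before=True
    )
    y_post, _ = post_norm(x)
    y_pre, _ = pre_norm(x)
    return y_post.shape, y_pre.shape  # both torch.Size([8, 60, 512])
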
class TransformerEncoder(nn.Module):
    """This class implements the transformer encoder.

    Arguments
    ---------
    num_layers : int
        Number of transformer layers to include.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    input_shape : tuple
        Expected shape of an example input, used to infer d_model when
        d_model is not given (optional).
    d_model : int
        The dimension of the input embedding.
    kdim : int
        Dimension for key (optional).
    vdim : int
        Dimension for value (optional).
    dropout : float
        Dropout rate for the encoder (optional).
    activation : torch class
        Activation function of the feed-forward layer, e.g., relu or gelu (optional).
    normalize_before : bool
        Whether to apply layer normalization before (pre-norm) or after
        (post-norm) each sub-block (optional).

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = TransformerEncoder(1, 8, 512, d_model=512)
    >>> output, _ = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """
def __init__(
self,
num_layers,
nhead,
d_ffn,
input_shape=None,
d_model=None,
kdim=None,
vdim=None,
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
):
super().__init__()
if input_shape is None and d_model is None:
raise ValueError("Expected one of input_shape or d_model")
        if input_shape is not None and d_model is None:
            if len(input_shape) != 3:
                msg = (
                    "Input shape of the Transformer must be (batch, time, fea). "
                    "Please revise the forward function in TransformerInterface "
                    "to handle arbitrary shapes of input."
                )
                raise ValueError(msg)
            d_model = input_shape[-1]
self.layers = torch.nn.ModuleList(
[
TransformerEncoderLayer(
d_ffn=d_ffn,
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
)
for i in range(num_layers)
]
)
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Arguments
        ---------
        src : tensor
            The sequence to the encoder (required).
        src_mask : tensor
            The mask for the src sequence (optional).
        src_key_padding_mask : tensor
            The mask for the src keys per batch (optional).
        """
output = src
attention_lst = []
for enc_layer in self.layers:
output, attention = enc_layer(
output,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
)
attention_lst.append(attention)
output = self.norm(output)
return output, attention_lst
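
# Hypothetical sketch (illustration only): ignoring padded frames with
# get_key_padding_mask (defined at the bottom of this module) while encoding.
def _demo_encoder_with_padding_mask():
    tokens = torch.LongTensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]])
    padding_mask = get_key_padding_mask(tokens, pad_idx=0)  # (2, 5) bool
    x = torch.rand(2, 5, 512)
    net = TransformerEncoder(num_layers=2, nhead=8, d_ffn=2048, d_model=512)
    out, attn_list = net(x, src_key_padding_mask=padding_mask)
    return out.shape  # torch.Size([2, 5, 512])
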
class TransformerDecoderLayer(nn.Module):
    """This class implements the self-attention decoder layer.

    Arguments
    ---------
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    nhead : int
        Number of attention heads.
    d_model : int
        Dimension of the model.
    kdim : int
        Dimension for key (optional).
    vdim : int
        Dimension for value (optional).
    dropout : float
        Dropout rate for the decoder (optional).

    Example
    -------
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = TransformerDecoderLayer(1024, 8, d_model=512)
    >>> output, self_attn, multihead_attn = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """
def __init__(
self,
d_ffn,
nhead,
d_model,
kdim=None,
vdim=None,
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
):
super().__init__()
self.self_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead, d_model=d_model, kdim=kdim, vdim=vdim, dropout=dropout,
)
        self.multihead_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead, d_model=d_model, kdim=kdim, vdim=vdim, dropout=dropout,
)
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
)
# normalization layers
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.dropout3 = torch.nn.Dropout(dropout)
self.normalize_before = normalize_before
    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        """
        Arguments
        ---------
        tgt : tensor
            The sequence to the decoder layer (required).
        memory : tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask : tensor
            The mask for the tgt sequence (optional).
        memory_mask : tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask : tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask : tensor
            The mask for the memory keys per batch (optional).
        """
if self.normalize_before:
tgt1 = self.norm1(tgt)
else:
tgt1 = tgt
# self-attention over the target sequence
tgt2, self_attn = self.self_attn(
query=tgt1,
key=tgt1,
value=tgt1,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
)
# add & norm
tgt = tgt + self.dropout1(tgt2)
if not self.normalize_before:
tgt = self.norm1(tgt)
if self.normalize_before:
tgt1 = self.norm2(tgt)
else:
tgt1 = tgt
# multi-head attention over the target sequence and encoder states
        tgt2, multihead_attention = self.multihead_attn(
query=tgt1,
key=memory,
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)
# add & norm
tgt = tgt + self.dropout2(tgt2)
if not self.normalize_before:
tgt = self.norm2(tgt)
if self.normalize_before:
tgt1 = self.norm3(tgt)
else:
tgt1 = tgt
tgt2 = self.pos_ffn(tgt1)
# add & norm
tgt = tgt + self.dropout3(tgt2)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt, self_attn, multihead_attention
class TransformerDecoder(nn.Module):
    """This class implements the Transformer decoder.

    Arguments
    ---------
    num_layers : int
        Number of transformer decoder layers to include.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    d_model : int
        Dimension of the model.
    kdim : int
        Dimension for key (optional).
    vdim : int
        Dimension for value (optional).
    dropout : float
        Dropout rate for the decoder (optional).

    Example
    -------
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = TransformerDecoder(1, 8, 1024, d_model=512)
    >>> output, _, _ = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """
def __init__(
self,
num_layers,
nhead,
d_ffn,
d_model,
kdim=None,
vdim=None,
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
TransformerDecoderLayer(
d_ffn=d_ffn,
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
)
for _ in range(num_layers)
]
)
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        """
        Arguments
        ---------
        tgt : tensor
            The sequence to the decoder (required).
        memory : tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask : tensor
            The mask for the tgt sequence (optional).
        memory_mask : tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask : tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask : tensor
            The mask for the memory keys per batch (optional).
        """
output = tgt
self_attns, multihead_attns = [], []
for dec_layer in self.layers:
output, self_attn, multihead_attn = dec_layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
self_attns.append(self_attn)
multihead_attns.append(multihead_attn)
output = self.norm(output)
return output, self_attns, multihead_attns
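
# Hypothetical sketch (illustration only): a causal decoding pass, combining
# the look-ahead mask (defined below) with encoder states produced elsewhere.
def _demo_decoder_with_lookahead_mask():
    memory = torch.rand(8, 60, 512)  # encoder output
    tgt = torch.rand(8, 50, 512)  # shifted target embeddings
    net = TransformerDecoder(num_layers=2, nhead=8, d_ffn=2048, d_model=512)
    out, self_attns, cross_attns = net(
        tgt, memory, tgt_mask=get_lookahead_mask(tgt)  # (50, 50) causal mask
    )
    return out.shape  # torch.Size([8, 50, 512])
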
class NormalizedEmbedding(nn.Module):
    """This class implements the normalized embedding layer for the transformer.

    Since the dot product in self-attention is scaled by sqrt(d_model)
    and the final linear projection for prediction shares its weight with the
    embedding layer, we multiply the output of the embedding by sqrt(d_model).

    Arguments
    ---------
    d_model : int
        The number of expected features in the encoder/decoder inputs (default=512).
    vocab : int
        The vocab size.

    Example
    -------
    >>> emb = NormalizedEmbedding(512, 1000)
    >>> trg = torch.randint(0, 999, (8, 50))
    >>> emb_fea = emb(trg)
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.emb = sb.nnet.embedding.Embedding(
            num_embeddings=vocab, embedding_dim=d_model, blank_id=0
        )
        self.d_model = d_model

    def forward(self, x):
        """Returns the embedding of x, scaled by sqrt(d_model)."""
        return self.emb(x) * math.sqrt(self.d_model)
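
# Hypothetical check (illustration only): the output of NormalizedEmbedding is
# the plain embedding scaled by sqrt(d_model), as the docstring describes.
def _demo_normalized_embedding_scale():
    emb = NormalizedEmbedding(512, 1000)
    trg = torch.randint(0, 999, (8, 50))
    assert torch.allclose(emb(trg), emb.emb(trg) * math.sqrt(512))
    return emb(trg).shape  # torch.Size([8, 50, 512])
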
def get_key_padding_mask(padded_input, pad_idx):
    """Creates a binary mask to prevent attention to padded locations.

    Arguments
    ---------
    padded_input : tensor
        Padded input tensor.
    pad_idx : int
        Index of the padding element.

    Example
    -------
    >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
    >>> get_key_padding_mask(a, pad_idx=0)
    tensor([[False, False, True],
            [False, False, True],
            [False, False, True]])
    """
    if len(padded_input.shape) == 4:
        bz, time, ch1, ch2 = padded_input.shape
        padded_input = padded_input.reshape(bz, time, ch1 * ch2)

    key_padded_mask = padded_input.eq(pad_idx).to(padded_input.device)

    # if the input is more than 2-d, mask only the locations that are silent
    # across all channels
    if len(padded_input.shape) > 2:
        key_padded_mask = key_padded_mask.float().prod(dim=-1).bool()

    return key_padded_mask.detach()
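
# Hypothetical sketch of the multi-channel case (illustration only): for
# inputs with more than two dimensions, a time step is masked only when
# *every* channel equals pad_idx.
def _demo_key_padding_mask_3d():
    feats = torch.zeros(2, 4, 3)
    feats[:, :2, :] = 1.0  # the first two frames carry signal
    feats[0, 2, 0] = 1.0  # frame 2 of item 0 is only partially padded
    return get_key_padding_mask(feats, pad_idx=0)
    # tensor([[False, False, False, True],
    #         [False, False, True, True]])
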
def get_lookahead_mask(padded_input):
    """Creates a look-ahead (causal) mask so that each position can attend
    only to itself and to previous positions.

    Arguments
    ---------
    padded_input : tensor
        Padded input tensor; only its time dimension is used.

    Example
    -------
    >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
    >>> get_lookahead_mask(a)
    tensor([[0., -inf, -inf],
            [0., 0., -inf],
            [0., 0., 0.]])
    """
    seq_len = padded_input.shape[1]
    mask = (
        torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device))
        == 1
    ).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask.detach().to(padded_input.device)
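
# A minimal end-to-end sketch (hypothetical, illustration only) tying the
# pieces of this module together: normalized embeddings, positional encoding,
# encoder, decoder, and both mask helpers.
def _demo_full_transformer_pass():
    d_model, vocab = 512, 1000
    src = torch.rand(4, 30, d_model)  # e.g. acoustic features
    tgt_tokens = torch.randint(1, vocab, (4, 12))
    tgt_tokens[:, -2:] = 0  # simulate right-padding

    pos_enc = PositionalEncoding(d_model)
    encoder = TransformerEncoder(num_layers=2, nhead=8, d_ffn=2048, d_model=d_model)
    decoder = TransformerDecoder(num_layers=2, nhead=8, d_ffn=2048, d_model=d_model)
    embed = NormalizedEmbedding(d_model, vocab)

    enc_out, _ = encoder(src + pos_enc(src))
    tgt = embed(tgt_tokens)
    dec_out, _, _ = decoder(
        tgt + pos_enc(tgt),
        enc_out,
        tgt_mask=get_lookahead_mask(tgt),
        tgt_key_padding_mask=get_key_padding_mask(tgt_tokens, pad_idx=0),
    )
    return dec_out.shape  # torch.Size([4, 12, 512])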