# Source code for speechbrain.lobes.models.transformer.Branchformer
"""Branchformer implementation.
Ref: "Branchformer: Parallel MLP-Attention Architectures
to Capture Local and Global Context for Speech Recognition and Understanding"
Source: Some parts of the code may be adapted from ESPNet.
Authors
* Titouan Parcollet 2023
"""
import torch
import torch.nn as nn
from typing import Optional
from speechbrain.nnet.attention import RelPosMHAXL, MultiheadAttention
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.lobes.models.convolution import ConvolutionalSpatialGatingUnit
from speechbrain.nnet.hypermixing import HyperMixing
class ConvolutionBranch(nn.Module):
    """Convolution branch used inside a Branchformer layer.

    The default structure is:
    LN -> Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout

    Arguments
    ---------
    input_size : int
        The expected size of the feature (channel) dimension.
    linear_units : int, optional
        Number of neurons in the hidden linear units.
    kernel_size : int, optional
        Kernel size of non-bottleneck convolutional layer.
    activation : torch.nn.Module, optional
        Activation function used after pre projection.
    gate_activation : torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    dropout : float, optional
        Dropout rate.
    use_linear_after_conv : bool, optional
        If True, will apply a linear transformation of size input_size//2.

    Example
    -------
    >>> x = torch.rand((8, 60, 512))
    >>> net = ConvolutionBranch(512, 1024)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_size,
        linear_units=3072,
        kernel_size=31,
        activation=nn.GELU,
        gate_activation=nn.Identity,
        dropout=0.0,
        use_linear_after_conv=False,
    ):
        super().__init__()

        # Module creation order is kept stable so that parameter
        # initialization and state_dict keys remain unchanged.
        self.pre_channel_proj = nn.Linear(input_size, linear_units)
        # The CSGU halves the channel dimension, hence linear_units // 2 here.
        self.post_channel_proj = nn.Linear(linear_units // 2, input_size)
        self.activation = activation()
        self.csgu = ConvolutionalSpatialGatingUnit(
            input_size=linear_units,
            kernel_size=kernel_size,
            dropout=dropout,
            use_linear_after_conv=use_linear_after_conv,
            activation=gate_activation,
        )

    def forward(self, x):
        """Processes the input through the convolution branch.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape (B, T, D).
        """
        hidden = self.activation(self.pre_channel_proj(x))  # (B, T, linear_units)
        gated = self.csgu(hidden)  # (B, T, linear_units // 2)
        return self.post_channel_proj(gated)  # (B, T, D)
class BranchformerEncoderLayer(nn.Module):
    """Single Branchformer encoder layer.

    A Branchformer layer runs two parallel branches over the same input —
    a self-attention branch and a convolutional gating MLP branch — then
    merges them by concatenation followed by a linear projection.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module
        Activation function used in each Conformer layer.
    dropout : int, optional
        Dropout for the encoder.
    attention_type : str, optional
        Type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
    csgu_linear_units : int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
    gate_activation : torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    use_linear_after_conv : bool, optional
        If True, will apply a linear transformation of size input_size//2.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=nn.GELU,
        dropout=0.0,
        attention_type="RelPosMHAXL",
        csgu_linear_units=3072,
        gate_activation=nn.Identity,
        use_linear_after_conv=False,
    ):
        super().__init__()

        # Branch 1: pick the self-attention flavour. Module creation order
        # is kept stable to preserve parameter initialization.
        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # Transformer-XL style relative positional encoding.
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=False,
            )
        elif attention_type == "hypermixing":
            self.mha_layer = HyperMixing(
                input_output_dim=d_model,
                hypernet_size=d_model * 4,
                tied=False,
                num_heads=nhead,
                fix_tm_hidden_size=False,
            )

        # Branch 2: convolutional gating MLP.
        self.convolution_branch = ConvolutionBranch(
            input_size=d_model,
            kernel_size=kernel_size,
            linear_units=csgu_linear_units,
            activation=activation,
            gate_activation=gate_activation,
            dropout=dropout,
            use_linear_after_conv=use_linear_after_conv,
        )

        # Merge by concatenation (2 * d_model) then project back to d_model.
        self.merge_proj = torch.nn.Linear(d_model * 2, d_model)
        self.norm_mhsa = LayerNorm(d_model)
        self.norm_conv = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """
        Arguments
        ----------
        x : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings
        """
        # Attention branch: LN -> MHSA -> dropout.
        attn_branch = self.norm_mhsa(x)
        attn_branch, self_attn = self.mha_layer(
            attn_branch,
            attn_branch,
            attn_branch,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs,
        )
        attn_branch = self.dropout(attn_branch)

        # Convolution branch: LN -> cgMLP -> dropout.
        # NOTE: as in ESPnet, no padding mask is applied inside this branch.
        conv_branch = self.dropout(self.convolution_branch(self.norm_conv(x)))

        # Merge both branches by concatenation (performs better than
        # weighted averaging according to the original Branchformer paper),
        # then add the residual connection.
        merged = self.merge_proj(torch.cat([attn_branch, conv_branch], dim=-1))
        x = x + self.dropout(merged)

        return x, self_attn
class BranchformerEncoder(nn.Module):
    """This class implements the Branchformer encoder: a stack of
    Branchformer layers followed by a final LayerNorm.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Embedding dimension size.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module
        Activation function used in each Conformer layer.
    dropout : int, optional
        Dropout for the encoder.
    attention_type : str, optional
        Type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
    csgu_linear_units : int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
    gate_activation : torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    use_linear_after_conv : bool, optional
        If True, will apply a linear transformation of size input_size//2.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoder(1, 512, 8)
    >>> output, _ = net(x, pos_embs=pos_emb)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=nn.GELU,
        dropout=0.0,
        attention_type="RelPosMHAXL",
        csgu_linear_units=3072,
        gate_activation=nn.Identity,
        use_linear_after_conv=False,
    ):
        super().__init__()

        # All layers share the same hyper-parameters.
        self.layers = torch.nn.ModuleList(
            [
                BranchformerEncoderLayer(
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    attention_type=attention_type,
                    csgu_linear_units=csgu_linear_units,
                    gate_activation=gate_activation,
                    use_linear_after_conv=use_linear_after_conv,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)
        self.attention_type = attention_type

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config=None,
    ):
        """
        Arguments
        ----------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module,
            Module or tensor containing the input sequence positional embeddings
            If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
            where S is the sequence length, and E is the embedding dimension.
        """
        assert (
            dynchunktrain_config is None
        ), "Dynamic Chunk Training unsupported for this encoder"

        # RelPosMHAXL cannot run without positional embeddings.
        if self.attention_type == "RelPosMHAXL" and pos_embs is None:
            raise ValueError(
                "The chosen attention type for the Branchformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
            )

        output = src
        attention_lst = []
        for enc_layer in self.layers:
            output, attention = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
            )
            attention_lst.append(attention)

        return self.norm(output), attention_lst