Source code for speechbrain.lobes.models.transformer.Branchformer
"""Branchformer implementation.
Ref: "Branchformer: Parallel MLP-Attention Architectures
to Capture Local and Global Context for Speech Recognition and Understanding"
Source: Some parts of the code may be adapted from ESPnet.
Authors
* Titouan Parcollet 2023
"""
from typing import Optional
import torch
import torch.nn as nn
from speechbrain.lobes.models.convolution import ConvolutionalSpatialGatingUnit
from speechbrain.nnet.attention import MultiheadAttention, RelPosMHAXL
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm
class ConvolutionBranch(nn.Module):
"""This is an implementation of the convolution branch in Branchformer.
The default structure is:
LN -> Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout
Arguments
---------
input_size : int
The expected size of the feature (channel) dimension.
linear_units: int, optional
Number of neurons in the hidden linear units.
kernel_size: int, optional
Kernel size of non-bottleneck convolutional layer.
activation: torch.nn.Module, optional
Activation function used after the pre-projection.
gate_activation: torch.nn.Module, optional
Activation function used at the gate of the CSGU module.
dropout: float, optional
Dropout rate.
use_linear_after_conv: bool, optional
If True, an additional linear layer is applied after the depthwise convolution inside the CSGU module.
Example
-------
>>> x = torch.rand((8, 60, 512))
>>> net = ConvolutionBranch(512, 1024)
>>> output = net(x)
>>> output.shape
torch.Size([8, 60, 512])
"""
def __init__(
self,
input_size,
linear_units=3072,
kernel_size=31,
activation=nn.GELU,
gate_activation=nn.Identity,
dropout=0.0,
use_linear_after_conv=False,
):
super().__init__()
self.pre_channel_proj = nn.Linear(input_size, linear_units)
self.post_channel_proj = nn.Linear(linear_units // 2, input_size)
self.activation = activation()
self.csgu = ConvolutionalSpatialGatingUnit(
input_size=linear_units,
kernel_size=kernel_size,
dropout=dropout,
use_linear_after_conv=use_linear_after_conv,
activation=gate_activation,
)
def forward(self, x):
"""
Arguments
---------
x : torch.Tensor
Input tensor of shape (B, T, input_size).

Returns
-------
x : torch.Tensor
Output tensor of shape (B, T, input_size).
"""
x = self.activation(self.pre_channel_proj(x))  # (B, T, linear_units)
x = self.csgu(x)  # (B, T, linear_units // 2)
x = self.post_channel_proj(x)  # (B, T, input_size)
return x
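# The following helper is an illustrative sketch added for documentation
# purposes; it is not part of the library API and its name is made up. It
# mirrors ConvolutionBranch.forward step by step to make the channel-size
# bookkeeping explicit: the pre-projection expands the channel dimension to
# linear_units, the CSGU halves it, and the post-projection maps it back to
# input_size. The sizes below are arbitrary example values.
def _convolution_branch_shape_sketch():
    """Illustrative sketch only (not part of the original module)."""
    branch = ConvolutionBranch(input_size=256, linear_units=1024, kernel_size=31)
    feats = torch.rand(4, 100, 256)  # (B, T, input_size)
    hidden = branch.activation(branch.pre_channel_proj(feats))  # (4, 100, 1024)
    gated = branch.csgu(hidden)  # (4, 100, 512): CSGU gates and halves the channels
    out = branch.post_channel_proj(gated)  # (4, 100, 256)
    assert out.shape == feats.shape
    return out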
class BranchformerEncoderLayer(nn.Module):
"""This is an implementation of Branchformer encoder layer.
Arguments
---------
d_model : int
The expected size of the input embedding.
nhead : int
Number of attention heads.
kernel_size : int, optional
Kernel size of convolution model.
kdim : int, optional
Dimension of the key.
vdim : int, optional
Dimension of the value.
activation: torch.nn.Module
Activation function used in each Branchformer layer.
dropout : float, optional
Dropout for the encoder layer.
attention_type: str, optional
Type of attention layer, e.g. regularMHA for regular MultiheadAttention.
csgu_linear_units: int, optional
Number of neurons in the hidden linear units of the CSGU Module.
gate_activation: torch.nn.Module, optional
Activation function used at the gate of the CSGU module.
use_linear_after_conv: bool, optional
If True, an additional linear layer is applied after the depthwise convolution inside the CSGU module.
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_embs = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
>>> output = net(x, pos_embs=pos_embs)
>>> output[0].shape
torch.Size([8, 60, 512])
"""
def __init__(
self,
d_model,
nhead,
kernel_size=31,
kdim=None,
vdim=None,
activation=nn.GELU,
dropout=0.0,
attention_type="RelPosMHAXL",
csgu_linear_units=3072,
gate_activation=nn.Identity,
use_linear_after_conv=False,
):
super().__init__()
if attention_type == "regularMHA":
self.mha_layer = MultiheadAttention(
nhead=nhead,
d_model=d_model,
dropout=dropout,
kdim=kdim,
vdim=vdim,
)
elif attention_type == "RelPosMHAXL":
# transformerXL style positional encoding
self.mha_layer = RelPosMHAXL(
num_heads=nhead,
embed_dim=d_model,
dropout=dropout,
mask_pos_future=False,
)
elif attention_type == "hypermixing":
self.mha_layer = HyperMixing(
input_output_dim=d_model,
hypernet_size=d_model * 4,
tied=False,
num_heads=nhead,
fix_tm_hidden_size=False,
)
self.convolution_branch = ConvolutionBranch(
input_size=d_model,
kernel_size=kernel_size,
linear_units=csgu_linear_units,
activation=activation,
gate_activation=gate_activation,
dropout=dropout,
use_linear_after_conv=use_linear_after_conv,
)
self.merge_proj = torch.nn.Linear(d_model * 2, d_model)
self.norm_mhsa = LayerNorm(d_model)
self.norm_conv = LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
):
"""
Arguments
---------
x : torch.Tensor
The sequence to the encoder layer.
src_mask : torch.Tensor, optional
The mask for the src sequence.
src_key_padding_mask : torch.Tensor, optional
The mask for the src keys per batch.
pos_embs : torch.Tensor, torch.nn.Module, optional
Module or tensor containing the input sequence positional embeddings.

Returns
-------
out : torch.Tensor
The output of the Branchformer encoder layer.
self_attn : torch.Tensor
The self-attention weights.
"""
# Two branches!
x1 = x
x2 = x
# Branch 1: Self-attention
x1 = self.norm_mhsa(x1)
x1, self_attn = self.mha_layer(
x1,
x1,
x1,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs,
)
x1 = self.dropout(x1)
# Branch 2: Convolutional gating MLP
# Note: following ESPnet, no padding mask is applied inside the convolution branch.
x2 = self.norm_conv(x2)
x2 = self.convolution_branch(x2)
x2 = self.dropout(x2)
# Merge the two branches with concatenation followed by a linear projection,
# as the original Branchformer paper reports this merging strategy to perform best.
x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))
return x, self_attn
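# The following helper is an illustrative usage sketch added for documentation
# purposes; it is not part of the library API and its name is made up. It shows
# the encoder layer with attention_type="regularMHA", which does not require
# relative positional embeddings, together with a key padding mask that follows
# the PyTorch convention (True marks padded positions); treat that convention
# as an assumption to verify against your own data pipeline.
def _branchformer_layer_usage_sketch():
    """Illustrative sketch only (not part of the original module)."""
    layer = BranchformerEncoderLayer(
        d_model=256, nhead=4, kernel_size=31, attention_type="regularMHA"
    )
    x = torch.rand(2, 50, 256)  # (B, T, d_model)
    # Mark the last 10 frames of the second example as padding.
    key_padding_mask = torch.zeros(2, 50, dtype=torch.bool)
    key_padding_mask[1, 40:] = True
    out, attn = layer(x, src_key_padding_mask=key_padding_mask)
    assert out.shape == x.shape
    return out, attn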
class BranchformerEncoder(nn.Module):
"""This class implements the Branchformer encoder.
Arguments
---------
num_layers : int
Number of layers.
d_model : int
Embedding dimension size.
nhead : int
Number of attention heads.
kernel_size : int, optional
Kernel size of convolution model.
kdim : int, optional
Dimension of the key.
vdim : int, optional
Dimension of the value.
activation: torch.nn.Module
Activation function used in each Branchformer layer.
dropout : float, optional
Dropout for the encoder.
attention_type: str, optional
Type of attention layer, e.g. regularMHA for regular MultiheadAttention.
csgu_linear_units: int, optional
Number of neurons in the hidden linear units of the CSGU Module.
gate_activation: torch.nn.Module, optional
Activation function used at the gate of the CSGU module.
use_linear_after_conv: bool, optional
If True, an additional linear layer is applied after the depthwise convolution inside the CSGU module.
output_hidden_states: bool, optional
Whether the model should output the hidden states as a list of tensors.
layerdrop_prob: float, optional
The probability of dropping an entire layer during training.
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_emb = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoder(1, 512, 8)
>>> output, _ = net(x, pos_embs=pos_emb)
>>> output.shape
torch.Size([8, 60, 512])
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> pos_emb = torch.rand((1, 2*60-1, 512))
>>> net = BranchformerEncoder(1, 512, 8, output_hidden_states=True)
>>> output, attn_list, hidden_list = net(x, pos_embs=pos_emb)
>>> hidden_list[0].shape
torch.Size([8, 60, 512])
>>> len(hidden_list)
2
"""
def __init__(
self,
num_layers,
d_model,
nhead,
kernel_size=31,
kdim=None,
vdim=None,
activation=nn.GELU,
dropout=0.0,
attention_type="RelPosMHAXL",
csgu_linear_units=3072,
gate_activation=nn.Identity,
use_linear_after_conv=False,
output_hidden_states=False,
layerdrop_prob=0.0,
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
BranchformerEncoderLayer(
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
activation=activation,
kernel_size=kernel_size,
attention_type=attention_type,
csgu_linear_units=csgu_linear_units,
gate_activation=gate_activation,
use_linear_after_conv=use_linear_after_conv,
)
for i in range(num_layers)
]
)
self.norm = LayerNorm(d_model, eps=1e-6)
self.layerdrop_prob = layerdrop_prob
self.attention_type = attention_type
self.output_hidden_states = output_hidden_states
def forward(
self,
src,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
dynchunktrain_config=None,
):
"""
Arguments
---------
src : torch.Tensor
The sequence to the encoder layer.
src_mask : torch.Tensor, optional
The mask for the src sequence.
src_key_padding_mask : torch.Tensor, optional
The mask for the src keys per batch.
pos_embs : torch.Tensor, torch.nn.Module, optional
Module or tensor containing the input sequence positional embeddings.
If custom pos_embs are given, they need to have the shape (1, 2*S-1, E),
where S is the sequence length and E is the embedding dimension.
dynchunktrain_config : None
This configuration is unsupported for this encoder.
Returns
-------
output : torch.Tensor
The output of the Branchformer encoder.
attention_lst : list
The attention values.
hidden_state_lst : list, optional
The outputs of the hidden layers of the encoder.
Only returned if output_hidden_states is set to True.
"""
assert (
dynchunktrain_config is None
), "Dynamic Chunk Training unsupported for this encoder"
if self.attention_type == "RelPosMHAXL":
if pos_embs is None:
raise ValueError(
"The chosen attention type for the Branchformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
)
output = src
if self.layerdrop_prob > 0.0:
keep_probs = torch.rand(len(self.layers))
attention_lst = []
if self.output_hidden_states:
hidden_state_lst = [output]
for i, enc_layer in enumerate(self.layers):
if (
not self.training
or self.layerdrop_prob == 0.0
or keep_probs[i] > self.layerdrop_prob
):
output, attention = enc_layer(
output,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs,
)
attention_lst.append(attention)
if self.output_hidden_states:
hidden_state_lst.append(output)
output = self.norm(output)
if self.output_hidden_states:
return output, attention_lst, hidden_state_lst
return output, attention_lst
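# The following helper is an illustrative sketch added for documentation
# purposes; it is not part of the library API and its name is made up. It shows
# the LayerDrop behaviour implemented above: during training each layer can be
# skipped independently with probability layerdrop_prob, while in eval mode
# every layer always runs, so the number of returned attention maps equals
# num_layers. Relative positional embeddings of shape (1, 2*T-1, d_model) are
# required by the default RelPosMHAXL attention.
def _branchformer_encoder_layerdrop_sketch():
    """Illustrative sketch only (not part of the original module)."""
    encoder = BranchformerEncoder(
        num_layers=4, d_model=256, nhead=4, layerdrop_prob=0.5
    )
    x = torch.rand(2, 30, 256)  # (B, T, d_model)
    pos_embs = torch.rand(1, 2 * 30 - 1, 256)
    encoder.train()
    _, attn_train = encoder(x, pos_embs=pos_embs)  # may hold fewer than 4 entries
    encoder.eval()
    _, attn_eval = encoder(x, pos_embs=pos_embs)  # always 4 entries
    assert len(attn_eval) == 4
    return attn_train, attn_eval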