# Source code for speechbrain.lobes.models.transformer.Branchformer

"""Branchformer implementation.

Ref: "Branchformer: Parallel MLP-Attention Architectures
to Capture Local and Global Context for Speech Recognition and Understanding"

Source: Some parts of the code may be adapted from ESPNet.

Authors
* Titouan Parcollet 2023
"""

import torch
import torch.nn as nn
from typing import Optional

from speechbrain.nnet.attention import RelPosMHAXL, MultiheadAttention
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.lobes.models.convolution import ConvolutionalSpatialGatingUnit


class ConvolutionBranch(nn.Module):
    """Convolutional gating branch used inside a Branchformer layer.

    The computation performed by this module is:
    Channel Proj -> activation (GeLU by default) -> CNN Spatial Gating
    -> Channel Proj

    Note that this module does not apply any layer normalisation to its
    input, nor any dropout to its output; the ``dropout`` value is only
    handed over to the CSGU module.

    Arguments
    ----------
    input_size : int
        The expected size of the feature (channel) dimension.
    linear_units: int, optional
        Number of neurons in the hidden linear units.
    kernel_size: int, optional
        Kernel size of non-bottleneck convolutional layer.
    activation: torch.nn.Module, optional
        Activation class instantiated and applied after the first
        channel projection.
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    dropout: float, optional
        Dropout rate (forwarded to the CSGU module).
    use_linear_after_conv: bool, optional
        Forwarded to the CSGU module. If True, will apply a linear
        transformation of size input_size//2.

    Example
    -------
    >>> x = torch.rand((8, 60, 512))
    >>> net = ConvolutionBranch(512, 1024)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_size,
        linear_units=3072,
        kernel_size=31,
        activation=nn.GELU,
        gate_activation=nn.Identity,
        dropout=0.0,
        use_linear_after_conv=False,
    ):
        super().__init__()

        # NOTE: the creation order of the submodules is kept as is, since it
        # fixes the parameter registration order (and the random init order).
        self.pre_channel_proj = nn.Linear(input_size, linear_units)
        # The second projection consumes half of the hidden width.
        self.post_channel_proj = nn.Linear(linear_units // 2, input_size)
        self.activation = activation()

        csgu_options = {
            "input_size": linear_units,
            "kernel_size": kernel_size,
            "dropout": dropout,
            "use_linear_after_conv": use_linear_after_conv,
            "activation": gate_activation,
        }
        self.csgu = ConvolutionalSpatialGatingUnit(**csgu_options)

    def forward(self, x):
        """Runs the input through the projection / gating / projection stack.

        Arguments
        ----------
        x: torch.Tensor -> (B, T, D)

        Returns
        -------
        torch.Tensor
            Output of the second channel projection, whose last dimension
            is ``input_size``.
        """
        hidden = self.pre_channel_proj(x)  # last dim: linear_units
        hidden = self.activation(hidden)
        gated = self.csgu(hidden)  # last dim expected: linear_units // 2
        return self.post_channel_proj(gated)  # last dim: input_size
class BranchformerEncoderLayer(nn.Module):
    """This is an implementation of Branchformer encoder layer.

    The layer runs two parallel branches on the same input (a self-attention
    branch and a convolutional gating branch), concatenates their outputs,
    projects them back to ``d_model`` and adds the result to the input.

    Arguments
    ----------
    d_model : int
        The expected size of the input embedding.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key (only used with "regularMHA").
    vdim : int, optional
        Dimension of the value (only used with "regularMHA").
    activation: torch.nn.Module
        Activation function used in the convolution branch.
    dropout : float, optional
        Dropout for the encoder.
    attention_type: str, optional
        Type of attention layer: "regularMHA" for regular
        MultiHeadAttention or "RelPosMHAXL" for transformerXL style
        relative positional attention.
    csgu_linear_units: int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2

    Raises
    ------
    ValueError
        If ``attention_type`` is not one of the supported values.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=nn.GELU,
        dropout=0.0,
        attention_type="RelPosMHAXL",
        csgu_linear_units=3072,
        gate_activation=nn.Identity,
        use_linear_after_conv=False,
    ):
        super().__init__()

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=False,
            )
        else:
            # Fail early: without this, `self.mha_layer` would be left
            # undefined and the error would only surface in `forward`.
            raise ValueError(
                f"Unknown attention_type '{attention_type}' for the "
                "Branchformer. Expected 'regularMHA' or 'RelPosMHAXL'."
            )

        self.convolution_branch = ConvolutionBranch(
            input_size=d_model,
            kernel_size=kernel_size,
            linear_units=csgu_linear_units,
            activation=activation,
            gate_activation=gate_activation,
            dropout=dropout,
            use_linear_after_conv=use_linear_after_conv,
        )

        # Maps the concatenation of both branches (2 * d_model) to d_model.
        self.merge_proj = torch.nn.Linear(d_model * 2, d_model)
        self.norm_mhsa = LayerNorm(d_model)
        self.norm_conv = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Applies both branches to ``x`` and merges them residually.

        Arguments
        ----------
        x : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings

        Returns
        -------
        x : torch.Tensor
            The input plus the (dropped-out) merged projection of both
            branches.
        self_attn
            The second value returned by the attention layer.
        """
        # Two branches!
        x1 = x
        x2 = x

        # Branch 1: Self-attention
        x1 = self.norm_mhsa(x1)
        x1, self_attn = self.mha_layer(
            x1,
            x1,
            x1,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs,
        )
        x1 = self.dropout(x1)

        # Branch 2: Convolutional gating MLP
        # In ESPnet, masks are not used?! we do the same but warning!
        x2 = self.norm_conv(x2)
        x2 = self.convolution_branch(x2)
        x2 = self.dropout(x2)

        # Merge both branches, we only do concatenation as it performs better.
        # According to the original Branchformer paper.
        x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))

        return x, self_attn
class BranchformerEncoder(nn.Module):
    """Stack of Branchformer encoder layers followed by a layer norm.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Embedding dimension size.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation: torch.nn.Module
        Activation function used in each Branchformer layer.
    dropout : float, optional
        Dropout for the encoder.
    attention_type: str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
    csgu_linear_units: int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoder(1, 512, 8)
    >>> output, _ = net(x, pos_embs=pos_emb)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=nn.GELU,
        dropout=0.0,
        attention_type="RelPosMHAXL",
        csgu_linear_units=3072,
        gate_activation=nn.Identity,
        use_linear_after_conv=False,
    ):
        super().__init__()

        # Every layer is built with the exact same hyper-parameters.
        stack = []
        for _ in range(num_layers):
            stack.append(
                BranchformerEncoderLayer(
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    attention_type=attention_type,
                    csgu_linear_units=csgu_linear_units,
                    gate_activation=gate_activation,
                    use_linear_after_conv=use_linear_after_conv,
                )
            )
        self.layers = torch.nn.ModuleList(stack)

        self.norm = LayerNorm(d_model, eps=1e-6)
        # Remembered so that forward() can check pos_embs requirements.
        self.attention_type = attention_type

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Feeds ``src`` through every layer, then through the final norm.

        Arguments
        ----------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module,
            Module or tensor containing the input sequence positional embeddings
            If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
            where S is the sequence length, and E is the embedding dimension.

        Returns
        -------
        output : torch.Tensor
            The normalised output of the last layer.
        attention_lst : list
            The attention value returned by each layer, in layer order.

        Raises
        ------
        ValueError
            If the attention type is RelPosMHAXL and no ``pos_embs`` is given.
        """
        needs_pos_embs = self.attention_type == "RelPosMHAXL"
        if needs_pos_embs and pos_embs is None:
            raise ValueError(
                "The chosen attention type for the Branchformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
            )

        hidden = src
        attention_lst = []
        for layer in self.layers:
            hidden, layer_attention = layer(
                hidden,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
            )
            attention_lst.append(layer_attention)

        return self.norm(hidden), attention_lst