Source code for speechbrain.lobes.models.transformer.Branchformer

"""Branchformer implementation.

Ref: "Branchformer: Parallel MLP-Attention Architectures
to Capture Local and Global Context for Speech Recognition and Understanding"

Source: Some parts of the code may be adapted from ESPNet.

Authors
* Titouan Parcollet 2023
"""

from typing import Optional

import torch
import torch.nn as nn

from speechbrain.lobes.models.convolution import ConvolutionalSpatialGatingUnit
from speechbrain.nnet.attention import MultiheadAttention, RelPosMHAXL
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm


class ConvolutionBranch(nn.Module):
    """Convolutional gating MLP branch used inside a Branchformer layer.

    The module itself chains:
    Channel Proj -> activation -> CNN Spatial Gating (CSGU) -> Channel Proj

    In the complete Branchformer layer the structure is
    LN -> Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout;
    the surrounding LayerNorm and Dropout are applied by
    BranchformerEncoderLayer, not by this module.

    Arguments
    ---------
    input_size : int
        The expected size of the feature (channel) dimension.
    linear_units: int, optional
        Number of neurons in the hidden linear units. The CSGU output is
        projected back from ``linear_units // 2`` features.
    kernel_size: int, optional
        Kernel size of non-bottleneck convolutional layer.
    activation: torch.nn.Module, optional
        Activation function (class, instantiated here) used after the pre
        projection.
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    dropout: float, optional
        Dropout rate, forwarded to the CSGU module.
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2
        (forwarded to the CSGU module).

    Example
    -------
    >>> x = torch.rand((8, 60, 512))
    >>> net = ConvolutionBranch(512, 1024)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_size,
        linear_units=3072,
        kernel_size=31,
        activation=nn.GELU,
        gate_activation=nn.Identity,
        dropout=0.0,
        use_linear_after_conv=False,
    ):
        super().__init__()

        # The output projection consumes half of the hidden width, which is
        # what the gating unit is expected to hand back.
        gated_units = linear_units // 2

        # NOTE: sub-modules are created in a fixed order so that parameter
        # ordering and seeded initialisation stay stable.
        self.pre_channel_proj = nn.Linear(input_size, linear_units)
        self.post_channel_proj = nn.Linear(gated_units, input_size)
        self.activation = activation()
        self.csgu = ConvolutionalSpatialGatingUnit(
            input_size=linear_units,
            kernel_size=kernel_size,
            dropout=dropout,
            use_linear_after_conv=use_linear_after_conv,
            activation=gate_activation,
        )

    def forward(self, x):
        """Applies the convolution branch to the input.

        Arguments
        ---------
        x: torch.Tensor
            Input features; the last dimension must be ``input_size``
            (documented upstream as (B, T, D)).

        Returns
        -------
        torch.Tensor
            Tensor whose last dimension is ``input_size`` again.
        """
        # Expand channels: (..., input_size) -> (..., linear_units)
        hidden = self.pre_channel_proj(x)
        hidden = self.activation(hidden)

        # Spatial gating: expected to return (..., linear_units // 2)
        gated = self.csgu(hidden)

        # Project back: (..., linear_units // 2) -> (..., input_size)
        return self.post_channel_proj(gated)
class BranchformerEncoderLayer(nn.Module):
    """This is an implementation of Branchformer encoder layer.

    The layer runs two parallel branches on the same input (a self-attention
    branch and a convolutional gating MLP branch), concatenates their outputs,
    projects the result back to ``d_model`` and adds it to the input
    (residual connection).

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key (only used with "regularMHA").
    vdim : int, optional
        Dimension of the value (only used with "regularMHA").
    activation: torch.nn.Module
        Activation function used in the convolution branch of the layer.
    dropout : float, optional
        Dropout for the encoder.
    attention_type: str, optional
        Type of attention layer. One of "regularMHA" (regular
        MultiHeadAttention), "RelPosMHAXL" (transformerXL style positional
        encoding) or "hypermixing".
    csgu_linear_units: int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2

    Raises
    ------
    ValueError
        If ``attention_type`` is not one of the supported values.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=nn.GELU,
        dropout=0.0,
        attention_type="RelPosMHAXL",
        csgu_linear_units=3072,
        gate_activation=nn.Identity,
        use_linear_after_conv=False,
    ):
        super().__init__()

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=False,
            )
        elif attention_type == "hypermixing":
            self.mha_layer = HyperMixing(
                input_output_dim=d_model,
                hypernet_size=d_model * 4,
                tied=False,
                num_heads=nhead,
                fix_tm_hidden_size=False,
            )
        else:
            # Fail at construction time: otherwise `self.mha_layer` is never
            # set and the error only shows up as an AttributeError in forward.
            raise ValueError(
                f"Unknown attention_type '{attention_type}' for "
                "BranchformerEncoderLayer. Expected one of: "
                "regularMHA, RelPosMHAXL, hypermixing."
            )

        self.convolution_branch = ConvolutionBranch(
            input_size=d_model,
            kernel_size=kernel_size,
            linear_units=csgu_linear_units,
            activation=activation,
            gate_activation=gate_activation,
            dropout=dropout,
            use_linear_after_conv=use_linear_after_conv,
        )

        # Maps the concatenation of both branches (2 * d_model) to d_model.
        self.merge_proj = torch.nn.Linear(d_model * 2, d_model)

        self.norm_mhsa = LayerNorm(d_model)
        self.norm_conv = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Runs both branches and merges them with a residual connection.

        Arguments
        ---------
        x : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings

        Returns
        -------
        x : torch.Tensor
            The layer output (input plus the merged branches).
        self_attn : torch.Tensor
            The second value returned by the attention layer.
        """
        # Two branches!
        x1 = x
        x2 = x

        # Branch 1: Self-attention
        x1 = self.norm_mhsa(x1)
        x1, self_attn = self.mha_layer(
            x1,
            x1,
            x1,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs,
        )
        x1 = self.dropout(x1)

        # Branch 2: Convolutional gating MLP
        # In ESPnet, masks are not used?! we do the same but warning!
        x2 = self.norm_conv(x2)
        x2 = self.convolution_branch(x2)
        x2 = self.dropout(x2)

        # Merge both branches, we only do concatenation as it performs better.
        # According to the original Branchformer paper.
        x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))

        return x, self_attn
class BranchformerEncoder(nn.Module):
    """This class implements the Branchformer encoder.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Embedding dimension size.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation: torch.nn.Module
        Activation function used in each Branchformer layer.
    dropout : float, optional
        Dropout for the encoder.
    attention_type: str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
    csgu_linear_units: int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2.
    output_hidden_states: bool, optional
        Whether the model should output the hidden states as a list of tensor.
    layerdrop_prob: float
        The probability to drop an entire layer (only applied in training
        mode). Any value <= 0 disables layer dropping.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoder(1, 512, 8)
    >>> output, _ = net(x, pos_embs=pos_emb)
    >>> output.shape
    torch.Size([8, 60, 512])

    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = BranchformerEncoder(1, 512, 8, output_hidden_states=True)
    >>> output, attn_list, hidden_list = net(x, pos_embs=pos_emb)
    >>> hidden_list[0].shape
    torch.Size([8, 60, 512])
    >>> len(hidden_list)
    2
    """

    def __init__(
        self,
        num_layers,
        d_model,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=nn.GELU,
        dropout=0.0,
        attention_type="RelPosMHAXL",
        csgu_linear_units=3072,
        gate_activation=nn.Identity,
        use_linear_after_conv=False,
        output_hidden_states=False,
        layerdrop_prob=0.0,
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [
                BranchformerEncoderLayer(
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    attention_type=attention_type,
                    csgu_linear_units=csgu_linear_units,
                    gate_activation=gate_activation,
                    use_linear_after_conv=use_linear_after_conv,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)
        self.layerdrop_prob = layerdrop_prob
        self.attention_type = attention_type
        self.output_hidden_states = output_hidden_states

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config=None,
    ):
        """Runs the stack of Branchformer layers followed by a final norm.

        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module,
            Module or tensor containing the input sequence positional embeddings
            If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
            where S is the sequence length, and E is the embedding dimension.
        dynchunktrain_config : None
            This configuration is unsupported for this encoder.

        Returns
        -------
        output : torch.Tensor
            The output of the Branchformer (after the final normalization).
        attention_lst : list
            The attention values, one entry per executed layer (layers
            skipped by LayerDrop contribute no entry).
        hidden_state_lst : list, optional
            The input followed by the output of every executed layer, taken
            before the final normalization.
            Only returned if output_hidden_states is set to true.

        Raises
        ------
        AssertionError
            If ``dynchunktrain_config`` is not None.
        ValueError
            If the attention type is RelPosMHAXL and ``pos_embs`` is None.
        """
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert (
            dynchunktrain_config is None
        ), "Dynamic Chunk Training unsupported for this encoder"

        if self.attention_type == "RelPosMHAXL" and pos_embs is None:
            raise ValueError(
                "The chosen attention type for the Branchformer is "
                "RelPosMHAXL. For this attention type, the positional "
                "embeddings are mandatory"
            )

        output = src

        # LayerDrop is only meaningful for a strictly positive probability.
        # Treating every other value (0, negative, NaN) as "disabled" avoids
        # reading `keep_probs` before assignment in training mode.
        layerdrop_enabled = self.layerdrop_prob > 0.0
        # One random draw per layer; sampled under the same condition as
        # before so that the RNG stream is unchanged.
        keep_probs = (
            torch.rand(len(self.layers)) if layerdrop_enabled else None
        )

        attention_lst = []
        hidden_state_lst = [output] if self.output_hidden_states else None

        for i, enc_layer in enumerate(self.layers):
            # Skip the layer only in training mode, and only when its random
            # draw does not exceed the drop probability.
            if (
                self.training
                and layerdrop_enabled
                and not keep_probs[i] > self.layerdrop_prob
            ):
                continue

            output, attention = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
            )
            attention_lst.append(attention)

            if self.output_hidden_states:
                hidden_state_lst.append(output)

        output = self.norm(output)

        if self.output_hidden_states:
            return output, attention_lst, hidden_state_lst
        return output, attention_lst