Source code for speechbrain.lobes.models.beats

"""This lobe enables the integration of pretrained BEATs: Audio Pre-Training with Acoustic Tokenizers.

Reference: https://arxiv.org/abs/2212.09058
Based on Github source: https://github.com/microsoft/unilm/tree/master/beats

You can download the checkpoints from: https://github.com/microsoft/unilm/tree/master/beats

Author
 * Pooneh Mousavi 2024

"""

import logging
import math
import os
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as ta_kaldi
from torch import Tensor, nn
from torch.nn import LayerNorm, Parameter

from speechbrain.dataio.dataio import length_to_mask

logger = logging.getLogger(__name__)


class BEATs(nn.Module):
    """BEATs: Audio Pre-Training with Acoustic Tokenizers.

    This class implements the BEATs model, which processes audio signals for
    feature extraction or downstream tasks. The model supports loading from a
    checkpoint, applying normalization, and optionally freezing parameters.

    Arguments
    ---------
    ckp_path : str, optional
        Path to the checkpoint file. If None, the model initializes without
        pre-trained weights. You can download the checkpoints from:
        https://github.com/microsoft/unilm/tree/master/beats
    freeze : bool, optional (default: True)
        If True, the model parameters are frozen and the model is set to
        evaluation mode.
    output_all_hiddens : bool, optional (default: False)
        If True, the forward function outputs hidden states from all transformer
        layers. For example, BEATs_iter3 has 12 transformer layers and the output
        is of shape (13, B, T, C), where a projection of the CNN output is added
        to the beginning. If False, the forward function outputs the hidden
        states only from the last transformer layer.

    Example
    -------
    >>> audio = torch.randn(4, 10000)  # Batch of 4 audio signals
    >>> length = torch.tensor([1.0, 0.5, 0.75, 1.0])
    >>> model = BEATs()
    >>> outputs = model.extract_features(audio, length)[0]
    >>> outputs.shape
    torch.Size([4, 24, 768])
    """

    def __init__(
        self,
        ckp_path: str = None,
        freeze: bool = True,
        output_all_hiddens: bool = False,
    ) -> None:
        super().__init__()

        # Load configuration and checkpoint
        cfg, checkpoint = None, None
        if ckp_path:
            if not os.path.exists(ckp_path):
                raise FileNotFoundError(
                    f"Checkpoint file '{ckp_path}' does not exist."
                )
            checkpoint = torch.load(ckp_path)
            cfg = checkpoint.get("cfg", None)

        # Initialize model configuration
        self.cfg = BEATsConfig(cfg)
        logger.info(f"BEATs Config: {self.cfg.__dict__}")

        # Model attributes
        self.freeze = freeze
        self.output_all_hiddens = output_all_hiddens
        self.embed = self.cfg.embed_dim

        # Define layers and modules
        self.post_extract_proj = (
            nn.Linear(self.embed, self.cfg.encoder_embed_dim)
            if self.embed != self.cfg.encoder_embed_dim
            else None
        )
        self.input_patch_size = self.cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=self.cfg.conv_bias,
        )
        self.dropout_input = nn.Dropout(self.cfg.dropout_input)

        # Configuration checks
        assert not (
            self.cfg.deep_norm and self.cfg.layer_norm_first
        ), "Configuration error: 'deep_norm' and 'layer_norm_first' cannot both be True."

        # Initialize encoder and layer normalization
        self.encoder = TransformerEncoder(self.cfg)
        self.layer_norm = LayerNorm(self.embed)

        # Define predictor for fine-tuned models
        if self.cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(self.cfg.predictor_dropout)
            self.predictor = nn.Linear(
                self.cfg.encoder_embed_dim, self.cfg.predictor_class
            )
        else:
            self.predictor = None

        # Load weights from the checkpoint if available
        if checkpoint:
            self.load_state_dict(checkpoint["model"])

        # Set the model to evaluation mode if frozen
        if self.freeze:
            self.eval()
    def forward_padding_mask(
        self, features: torch.Tensor, padding_mask: torch.Tensor
    ) -> torch.Tensor:
        """Adjusts the padding mask for the given features.

        Arguments
        ---------
        features : torch.Tensor
            Input features after patch embedding.
        padding_mask : torch.Tensor
            Original padding mask for input signals.

        Returns
        -------
        torch.Tensor
            Adjusted padding mask.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        return padding_mask.all(-1)
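    # Illustrative sketch of the mask downsampling above, with toy shapes chosen
    # for clarity (the method does not use `self`, so it is called unbound here).
    # A frame counts as padding only if every underlying position is padding:
    #
    # >>> mask = torch.tensor([[False] * 6 + [True] * 4])  # (B=1, 10 positions)
    # >>> feats = torch.zeros(1, 5, 8)                     # (B=1, 5 frames, C=8)
    # >>> BEATs.forward_padding_mask(None, feats, mask)
    # tensor([[False, False, False,  True,  True]])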
    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Preprocesses the input waveform by extracting filter banks and applying normalization.

        Arguments
        ---------
        source : torch.Tensor
            Input waveform signals.
        fbank_mean : float, optional
            Mean value for filter bank normalization (default: 15.41663).
        fbank_std : float, optional
            Standard deviation for filter bank normalization (default: 6.55582).

        Returns
        -------
        torch.Tensor
            Normalized filter banks.
        """
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2**15
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        return (fbank - fbank_mean) / (2 * fbank_std)
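    # A minimal sketch of what `preprocess` computes for one 1-second, 16 kHz
    # waveform: the 2**15 factor rescales [-1, 1] audio to the 16-bit PCM range
    # expected by the Kaldi-compatible fbank, and the default 25 ms / 10 ms
    # framing yields 1 + (16000 - 400) // 160 = 98 frames of 128 mel bins:
    #
    # >>> waveform = torch.randn(1, 16000)  # (channels, samples)
    # >>> fbank = ta_kaldi.fbank(
    # ...     waveform * 2**15,
    # ...     num_mel_bins=128,
    # ...     sample_frequency=16000,
    # ...     frame_length=25,
    # ...     frame_shift=10,
    # ... )
    # >>> fbank.shape
    # torch.Size([98, 128])
    # >>> normalized = (fbank - 15.41663) / (2 * 6.55582)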
    def forward(
        self,
        wav: torch.Tensor,
        wav_lens: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        """Takes an input waveform and returns its corresponding BEATs encoding.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.
        wav_lens : torch.Tensor
            The relative length of the wav given in SpeechBrain format.
        fbank_mean : float, optional
            Mean value for filter bank normalization (default: 15.41663).
        fbank_std : float, optional
            Standard deviation for filter bank normalization (default: 6.55582).

        Returns
        -------
        BEATs encoded features.
        """
        # If we freeze, we simply remove all grads from the graph.
        if self.freeze:
            with torch.no_grad():
                return self.extract_features(
                    wav, wav_lens, fbank_mean, fbank_std
                )

        return self.extract_features(wav, wav_lens, fbank_mean, fbank_std)
    def extract_features(
        self,
        wav: torch.Tensor,
        wav_lens: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Extracts features from the input waveform.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.
        wav_lens : torch.Tensor
            The relative length of the wav given in SpeechBrain format.
        fbank_mean : float, optional
            Mean value for filter bank normalization (default: 15.41663).
        fbank_std : float, optional
            Standard deviation for filter bank normalization (default: 6.55582).

        Returns
        -------
        torch.Tensor
            Extracted features from the BEATs model.
        """
        fbank = self.preprocess(wav, fbank_mean, fbank_std)

        padding_mask = None  # no mask when wav_lens is not provided
        if wav_lens is not None:
            max_len = wav.size(-1)
            padding_mask = ~length_to_mask(
                wav_lens * max_len, max_len, device=wav.device
            ).bool()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(
            features.shape[0], features.shape[1], -1
        ).transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)

        x, layer_results = self.encoder(
            features,
            padding_mask=padding_mask,
            output_all_hiddens=self.output_all_hiddens,
        )

        if self.predictor is not None:
            x_d = self.predictor_dropout(x)
            logits = self.predictor(x_d)

            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(
                    -1
                ).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            if self.output_all_hiddens:
                x = torch.stack(layer_results, dim=0)

            return x, lprobs, padding_mask

        if self.output_all_hiddens:
            x = torch.stack(layer_results, dim=0)

        return (x,)
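# A worked sketch of how the (4, 24, 768) shape in the class docstring arises:
# 10000 samples give 1 + (10000 - 400) // 160 = 61 fbank frames of 128 mel bins,
# and the 16x16 patch embedding (stride 16) yields (61 // 16) * (128 // 16) = 24
# tokens, each projected to encoder_embed_dim = 768 (randomly initialised model,
# no checkpoint loaded):
#
# >>> model = BEATs()
# >>> wav = torch.randn(4, 10000)
# >>> model.extract_features(wav, torch.ones(4))[0].shape
# torch.Size([4, 24, 768])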
def gelu_accurate(x):
    """Applies the Gaussian Error Linear Unit (GELU) activation function using an accurate approximation.

    Arguments
    ---------
    x : torch.Tensor
        Input tensor on which to apply the GELU activation.

    Returns
    -------
    torch.Tensor
        Tensor with GELU activation applied element-wise.
    """
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5
        * x
        * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Applies the Gaussian Error Linear Unit (GELU) activation function.

    Arguments
    ---------
    x : torch.Tensor
        Input tensor to apply the GELU activation.

    Returns
    -------
    torch.Tensor
        Tensor with GELU activation applied element-wise.
    """
    return torch.nn.functional.gelu(x.float()).type_as(x)
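# Quick sanity check relating the two variants above: the tanh-based
# `gelu_accurate` closely tracks the exact GELU used by `gelu`.
#
# >>> x = torch.linspace(-3.0, 3.0, steps=7)
# >>> torch.allclose(gelu(x), gelu_accurate(x), atol=1e-3)
# True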
def get_activation_fn(activation: str):
    """Returns the activation function corresponding to the provided activation name.

    Arguments
    ---------
    activation : str
        Name of the activation function. Supported values:
        - "relu": Applies ReLU activation.
        - "gelu": Applies the GELU activation.
        - "gelu_fast": Alias for `gelu_accurate` with a deprecation warning.
        - "gelu_accurate": Applies the accurate GELU activation.
        - "tanh": Applies the Tanh activation.
        - "linear": Applies the identity function.
        - "glu": Applies the identity function (GLU placeholder).

    Returns
    -------
    Callable[[torch.Tensor], torch.Tensor]
        The corresponding activation function to apply to input tensors.

    Raises
    ------
    RuntimeError
        If the specified activation function is not supported.
    """
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        logger.warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "glu":
        return lambda x: x
    else:
        raise RuntimeError(
            "--activation-fn {} not supported".format(activation)
        )
class SamePad(nn.Module):
    """Implements a module that adjusts the padding of a tensor after convolution
    to maintain its original size, with an option for causal padding.

    This is particularly useful for handling padding in convolutional layers
    where the kernel size or causality affects the output size.

    Arguments
    ---------
    kernel_size : int
        The size of the convolutional kernel.
    causal : bool, optional (default=False)
        If True, applies causal padding by removing `(kernel_size - 1)` elements
        from the end of the tensor. If False, removes elements to center-align
        the padding, ensuring the output size matches the input size.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0
    def forward(self, x):
        """Adjusts the padding of the input tensor `x`.

        If `self.remove > 0`, the method slices the tensor along the last
        dimension to remove excess padding based on the `kernel_size` and
        `causal` settings.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to adjust padding for.

        Returns
        -------
        torch.Tensor
            The tensor with adjusted padding.
        """
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
class Swish(nn.Module):
    """Implements the Swish activation function as a PyTorch module.

    Swish is a smooth, non-monotonic activation function defined as:
        Swish(x) = x * sigmoid(x)

    It is often used in deep learning for its ability to improve training
    performance in certain architectures.
    """

    def __init__(self):
        super(Swish, self).__init__()
        self.act = torch.nn.Sigmoid()
    def forward(self, x):
        """Applies the Swish activation function to the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to which the Swish activation is applied.

        Returns
        -------
        torch.Tensor
            The input tensor after applying the Swish activation.
        """
        return x * self.act(x)
class GLU_Linear(nn.Module):
    """Implements a Gated Linear Unit (GLU) combined with a linear transformation.

    Arguments
    ---------
    input_dim : int
        The dimensionality of the input features.
    output_dim : int
        The dimensionality of the output features.
    glu_type : str, optional (default="sigmoid")
        The type of activation function used for gating. Supported values are:
        - "sigmoid": Uses the sigmoid activation function.
        - "swish": Uses the Swish activation function.
        - "relu": Uses the ReLU activation function.
        - "gelu": Uses the GELU activation function.
    bias_in_glu : bool, optional (default=True)
        Whether to include a bias term in the linear transformation.
    """

    def __init__(
        self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True
    ):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        if bias_in_glu:
            self.linear = nn.Linear(input_dim, output_dim * 2, True)
        else:
            self.linear = nn.Linear(input_dim, output_dim * 2, False)
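# The gating step itself is not shown on this page; in the upstream BEATs code
# the doubled projection is split into a value half and an activated gate half,
# roughly as in this illustrative (hypothetical) helper:
#
# >>> def glu_linear_forward(layer, x):
# ...     x = layer.linear(x)  # (..., 2 * output_dim)
# ...     value, gate = x[..., : layer.output_dim], x[..., layer.output_dim :]
# ...     return value * layer.glu_act(gate)
# >>> glu = GLU_Linear(16, 8, glu_type="swish")
# >>> glu_linear_forward(glu, torch.randn(2, 4, 16)).shape
# torch.Size([2, 4, 8])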
class GradMultiply(torch.autograd.Function):
    """A custom autograd function that scales gradients during the backward pass.

    This is useful for scenarios where gradient scaling is required without
    affecting the forward pass output. The forward pass returns the input as-is,
    while the backward pass scales the gradients by a specified factor.
    """
    @staticmethod
    def forward(ctx, x, scale):
        """Performs the forward pass of the GradMultiply function.

        Arguments
        ---------
        ctx : torch.autograd.Function
            The context object to store information for the backward computation.
        x : torch.Tensor
            The input tensor to be forwarded unchanged.
        scale : float
            The factor by which the gradients will be scaled during the backward
            pass.

        Returns
        -------
        torch.Tensor
            A new tensor identical to the input tensor.
        """
        ctx.scale = scale
        res = x.new(x)
        return res
    @staticmethod
    def backward(ctx, grad):
        """Performs the backward pass, scaling the gradients by the stored factor.

        Arguments
        ---------
        ctx : torch.autograd.Function
            The context object containing the stored scaling factor.
        grad : torch.Tensor
            The gradient tensor from the subsequent layer.

        Returns
        -------
        Tuple[torch.Tensor, None]
            The scaled gradient tensor and None (for the scale input, which has
            no gradient).
        """
        return grad * ctx.scale, None
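# A minimal sketch of the effect: the forward pass is an identity in value,
# while gradients are multiplied by `scale` on the way back. The Transformer
# encoder below uses this with `layer_wise_gradient_decay_ratio`.
#
# >>> x = torch.ones(3, requires_grad=True)
# >>> GradMultiply.apply(x, 0.5).sum().backward()
# >>> x.grad
# tensor([0.5000, 0.5000, 0.5000])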
def quant_noise(module, p, block_size):
    """Wraps modules and applies quantization noise to their weights for
    subsequent quantization using Iterative Product Quantization (iPQ).

    This approach is described in the paper:
    "Training with Quantization Noise for Extreme Model Compression."
    It introduces quantization noise during training to improve model robustness
    for extreme weight compression scenarios.

    Arguments
    ---------
    module : nn.Module
        The module to which quantization noise will be applied. Supported
        modules are Linear, Embedding, and Conv2d.
    p : float
        The amount of quantization noise to apply. Typically a probability or
        scaling factor.
    block_size : int
        The size of the blocks for subsequent quantization with iPQ.

    Returns
    -------
    nn.Module
        The module, returned unchanged.
    """
    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert (
                k % block_size == 0
            ), "Kernel size must be a multiple of block size"

    # The noise-injection forward pre-hook of the original fairseq implementation
    # is not reproduced here; the validated module is returned unchanged so that
    # callers (e.g. the attention projections below) still receive it.
    return module
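# Usage sketch: with p == 0 (the `q_noise=0.0` default used by MultiheadAttention
# below) the module is returned untouched, so quant_noise is effectively a no-op:
#
# >>> proj = quant_noise(nn.Linear(768, 768), p=0.0, block_size=8)
# >>> isinstance(proj, nn.Linear)
# True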
[docs] class TransformerEncoder(nn.Module): """ Implements the Transformer Encoder module. Arguments --------- args : Namespace or dict A collection of model hyperparameters and configurations. """ def __init__(self, args): super().__init__() self.dropout = args.dropout self.embedding_dim = args.encoder_embed_dim self.pos_conv = nn.Conv1d( self.embedding_dim, self.embedding_dim, kernel_size=args.conv_pos, padding=args.conv_pos // 2, groups=args.conv_pos_groups, ) dropout = 0 std = math.sqrt( (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim) ) nn.init.normal_(self.pos_conv.weight, mean=0, std=std) nn.init.constant_(self.pos_conv.bias, 0) self.pos_conv = nn.utils.weight_norm( self.pos_conv, name="weight", dim=2 ) self.pos_conv = nn.Sequential( self.pos_conv, SamePad(args.conv_pos), nn.GELU() ) if hasattr(args, "relative_position_embedding"): self.relative_position_embedding = args.relative_position_embedding self.num_buckets = args.num_buckets self.max_distance = args.max_distance else: self.relative_position_embedding = False self.num_buckets = 0 self.max_distance = 0 self.layers = nn.ModuleList( [ TransformerSentenceEncoderLayer( embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, deep_norm=args.deep_norm, has_relative_attention_bias=self.relative_position_embedding, num_buckets=self.num_buckets, max_distance=self.max_distance, gru_rel_pos=args.gru_rel_pos, encoder_layers=args.encoder_layers, ) for i in range(args.encoder_layers) ] ) if self.relative_position_embedding: for i in range(1, args.encoder_layers): del self.layers[i].self_attn.relative_attention_bias self.layers[i].self_attn.relative_attention_bias = self.layers[ 0 ].self_attn.relative_attention_bias self.layer_norm_first = args.layer_norm_first self.layer_norm = LayerNorm(self.embedding_dim) self.layerdrop = args.encoder_layerdrop self.apply(init_bert_params) if args.deep_norm: deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) for i in range(args.encoder_layers): nn.init.xavier_normal_( self.layers[i].self_attn.k_proj.weight, gain=1 ) nn.init.xavier_normal_( self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta ) nn.init.xavier_normal_( self.layers[i].self_attn.q_proj.weight, gain=1 ) nn.init.xavier_normal_( self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta, ) nn.init.xavier_normal_( self.layers[i].fc1.weight, gain=deep_norm_beta ) nn.init.xavier_normal_( self.layers[i].fc2.weight, gain=deep_norm_beta ) self.layer_wise_gradient_decay_ratio = getattr( args, "layer_wise_gradient_decay_ratio", 1 )
[docs] def forward(self, x, padding_mask=None, output_all_hiddens=None): """ Processes the input sequence through the Transformer Encoder layers. Arguments --------- x : torch.Tensor The input tensor of shape `(seq_len, batch_size, embed_dim)` containing the input embeddings. padding_mask : torch.Tensor, optional A binary mask of shape `(batch_size, seq_len)` indicating which positions are padding and should be ignored in attention computations. Default is `None`. output_all_hiddens : bool, optional If True, returns the hidden states from all encoder layers in addition to the final output. Default is `None`. Returns ------- Tuple[torch.Tensor, List[torch.Tensor]] - The final output tensor of shape `(seq_len, batch_size, embed_dim)`. """ x, layer_results = self.extract_features( x, padding_mask, output_all_hiddens ) if self.layer_norm_first and output_all_hiddens: x = self.layer_norm(x) return x, layer_results
[docs] def extract_features(self, x, padding_mask=None, output_all_hiddens=None): """ Extracts features from the input sequence using positional convolution, layer normalization, dropout, and a series of Transformer Encoder layers. Arguments --------- x : torch.Tensor The input tensor of shape `(batch_size, seq_len, embed_dim)` containing the input embeddings. padding_mask : torch.Tensor, optional A binary mask of shape `(batch_size, seq_len)` indicating which positions are padding and should be ignored in computations. Default is `None`. output_all_hiddens : bool, optional If True, collects and returns the hidden states from all encoder layers in addition to the final output. Default is `None`. Returns ------- Tuple[torch.Tensor, List[torch.Tensor]] - The final output tensor of shape `(batch_size, seq_len, embed_dim)`. """ if padding_mask is not None: x[padding_mask] = 0 x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) layer_results = [] z = None if output_all_hiddens: layer_results.append(x) r = None pos_bias = None for i, layer in enumerate(self.layers): if self.layer_wise_gradient_decay_ratio != 1.0: x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) dropout_probability = np.random.random() if not self.training or (dropout_probability > self.layerdrop): x, z, pos_bias = layer( x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias, ) # if tgt_layer is not None: layer_results.append(x) if r is not None: x = r # T x B x C -> B x T x C x = x.transpose(0, 1) return x, layer_results
[docs] class TransformerSentenceEncoderLayer(nn.Module): """ Implements a single Transformer Sentence Encoder layer. Arguments --------- embedding_dim : float, optional (default=768) The dimensionality of input embeddings. ffn_embedding_dim : float, optional (default=3072) The dimensionality of the feed-forward network's hidden layer. num_attention_heads : float, optional (default=8) The number of attention heads for self-attention. dropout : float, optional (default=0.1) The dropout rate applied to the output of the feed-forward network and attention layers. attention_dropout : float, optional (default=0.1) The dropout rate applied within the attention mechanism. activation_dropout : float, optional (default=0.1) The dropout rate applied after the activation function in the feed-forward network. activation_fn : str, optional (default="relu") The activation function used in the feed-forward network. Supported values include "relu" and "gelu". layer_norm_first : bool, optional (default=False) If True, applies layer normalization before attention and feed-forward layers; otherwise, applies it afterward. deep_norm : bool, optional (default=False) If True, uses deep normalization scaling for residual connections. has_relative_attention_bias : bool, optional (default=False) If True, includes relative position bias in the attention mechanism. num_buckets : int, optional (default=0) The number of buckets used for relative attention bias (if enabled). max_distance : int, optional (default=0) The maximum distance for relative attention bias (if enabled). rescale_init : bool, optional (default=False) If True, rescales parameter initialization for improved stability. gru_rel_pos : bool, optional (default=False) If True, incorporates GRU-style relative position encoding. encoder_layers : int, optional (default=0) The number of encoder layers in the Transformer. """ def __init__( self, embedding_dim: float = 768, ffn_embedding_dim: float = 3072, num_attention_heads: float = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, activation_fn: str = "relu", layer_norm_first: bool = False, deep_norm: bool = False, has_relative_attention_bias: bool = False, num_buckets: int = 0, max_distance: int = 0, rescale_init: bool = False, gru_rel_pos: bool = False, encoder_layers: int = 0, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.dropout = dropout self.activation_dropout = activation_dropout self.activation_name = activation_fn self.activation_fn = get_activation_fn(activation_fn) self.self_attn = MultiheadAttention( self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True, has_relative_attention_bias=has_relative_attention_bias, num_buckets=num_buckets, max_distance=max_distance, rescale_init=rescale_init, gru_rel_pos=gru_rel_pos, ) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(self.activation_dropout) self.dropout3 = nn.Dropout(dropout) self.layer_norm_first = layer_norm_first self.self_attn_layer_norm = LayerNorm(self.embedding_dim) if self.activation_name == "glu": self.fc1 = GLU_Linear( self.embedding_dim, ffn_embedding_dim, "swish" ) else: self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) self.final_layer_norm = LayerNorm(self.embedding_dim) self.deep_norm = deep_norm if self.deep_norm: self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) else: self.deep_norm_alpha = 1
[docs] def forward( self, x: torch.Tensor, self_attn_mask: torch.Tensor = None, self_attn_padding_mask: torch.Tensor = None, need_weights: bool = False, pos_bias=None, ): """ Processes the input tensor through the Transformer sentence encoder layer. Arguments --------- x : torch.Tensor Input tensor of shape `(seq_len, batch_size, embed_dim)`. self_attn_mask : torch.Tensor, optional Mask for the self-attention mechanism, typically used for causal or padding masking. Default is `None`. self_attn_padding_mask : torch.Tensor, optional Padding mask of shape `(batch_size, seq_len)`, indicating which tokens should be ignored in attention computations. Default is `None`. need_weights : bool, optional (default=False) Whether to return attention weights. If `True`, attention weights are included in the output. pos_bias : optional Positional bias for relative attention, if applicable. Default is `None`. Returns ------- Tuple[torch.Tensor, torch.Tensor, optional] - `x` (torch.Tensor): The output tensor of shape `(seq_len, batch_size, embed_dim)` after applying the encoder layer. """ residual = x if self.layer_norm_first: x = self.self_attn_layer_norm(x) x, attn, pos_bias = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False, attn_mask=self_attn_mask, position_bias=pos_bias, ) x = self.dropout1(x) x = residual + x residual = x x = self.final_layer_norm(x) if self.activation_name == "glu": x = self.fc1(x) else: x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) x = self.dropout3(x) x = residual + x else: x, attn, pos_bias = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=need_weights, attn_mask=self_attn_mask, position_bias=pos_bias, ) x = self.dropout1(x) x = residual * self.deep_norm_alpha + x x = self.self_attn_layer_norm(x) residual = x if self.activation_name == "glu": x = self.fc1(x) else: x = self.activation_fn(self.fc1(x)) x = self.dropout2(x) x = self.fc2(x) x = self.dropout3(x) x = residual * self.deep_norm_alpha + x x = self.final_layer_norm(x) return x, attn, pos_bias
[docs] class MultiheadAttention(nn.Module): """ Implements multi-headed attention with support for advanced features like relative position embeddings and gated relative position embedding (GRU-based). Arguments --------- embed_dim : int Total number of dimensions for input embeddings. num_heads : int Number of attention heads. kdim : int, optional Dimensionality of key embeddings. Defaults to `embed_dim`. vdim : int, optional Dimensionality of value embeddings. Defaults to `embed_dim`. dropout : float, optional Dropout probability for attention weights. Defaults to 0.0. bias : bool, optional Whether to include a bias term in projections. Defaults to True. add_bias_kv : bool, optional Whether to include bias for key and value projections. Defaults to False. add_zero_attn : bool, optional Whether to include zero attention vectors. Defaults to False. self_attention : bool, optional Whether the layer is for self-attention. Defaults to False. encoder_decoder_attention : bool, optional Whether the layer is for encoder-decoder attention. Defaults to False. q_noise : float, optional Noise level for quantization. Defaults to 0.0. qn_block_size : int, optional Block size for quantization. Defaults to 8. has_relative_attention_bias : bool, optional Whether to use relative position embeddings. Defaults to False. num_buckets : int, optional Number of buckets for relative position embeddings. Defaults to 32. max_distance : int, optional Maximum distance for relative position embeddings. Defaults to 128. gru_rel_pos : bool, optional Whether to use gated relative position embeddings. Defaults to False. rescale_init : bool, optional Whether to rescale the initialization of weights. Defaults to False. """ # Initialization method def __init__( self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8, has_relative_attention_bias=False, num_buckets=32, max_distance=128, gru_rel_pos=False, rescale_init=False, ): super().__init__() # Attribute initialization self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout_module = nn.Dropout(dropout) self.has_relative_attention_bias = has_relative_attention_bias self.num_buckets = num_buckets self.max_distance = max_distance # Relative position bias setup if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) self.head_dim = embed_dim // num_heads self.q_head_dim = self.head_dim self.k_head_dim = self.head_dim assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 # Self-attention and encoder-decoder attention flags self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention assert ( not self.self_attention or self.qkv_same_dim ), "Self-attention requires query, key, and value to be of the same size." 
# Initialize projection layers with optional quantization noise self.k_proj = quant_noise( nn.Linear(self.kdim, embed_dim, bias=(not rescale_init)), q_noise, qn_block_size, ) self.v_proj = quant_noise( nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size ) self.q_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) self.out_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) # Bias terms for key and value, if applicable if add_bias_kv: self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) else: self.bias_k = self.bias_v = None # Additional settings self.add_zero_attn = add_zero_attn self.gru_rel_pos = gru_rel_pos if self.gru_rel_pos: self.grep_linear = nn.Linear(self.q_head_dim, 8) self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) # Reset parameters self.reset_parameters()
[docs] def reset_parameters(self): """ Initializes the weights for the projection layers and relative position embeddings. """ if self.qkv_same_dim: nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) else: nn.init.xavier_uniform_(self.k_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.q_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k) if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v) if self.has_relative_attention_bias: nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket( self, relative_positions, bidirectional=True ): """Computes bucket indices for relative positions for relative attention bias. Arguments --------- relative_positions : torch.Tensor A tensor of relative positions, where negative values indicate positions to the left and positive values indicate positions to the right. bidirectional : bool, optional, (default: True) If True, separate buckets are used for positive and negative positions. Returns ------- torch.Tensor A tensor of the same shape as `relative_positions`, where each value is the bucket index corresponding to the relative position. """ num_buckets = self.num_buckets max_distance = self.max_distance relative_buckets = 0 if bidirectional: # Halve buckets for bidirectional attention num_buckets = num_buckets // 2 relative_buckets += (relative_positions > 0).to( torch.long ) * num_buckets relative_positions = torch.abs(relative_positions) else: relative_positions = -torch.min( relative_positions, torch.zeros_like(relative_positions) ) max_exact = num_buckets // 2 is_small = relative_positions < max_exact relative_position_if_large = max_exact + ( torch.log(relative_positions.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.long) relative_position_if_large = torch.min( relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1), ) relative_buckets += torch.where( is_small, relative_positions, relative_position_if_large ) return relative_buckets
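# A hedged illustration of the bucketing above, using the MultiheadAttention
# defaults (num_buckets=32, max_distance=128; the BEATs configuration itself uses
# 320/1280). Offsets with |pos| < 8 keep their own bucket, larger offsets share
# logarithmically spaced buckets, and positive offsets are shifted by
# num_buckets // 2 = 16:
#
# >>> attn = MultiheadAttention(768, 12, has_relative_attention_bias=True)
# >>> attn._relative_positions_bucket(torch.tensor([-100, -3, 0, 3, 100]))
# tensor([15,  3,  0, 19, 31])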
[docs] def compute_bias(self, query_length: int, key_length: int) -> torch.Tensor: """ Computes relative position bias for attention scores. Arguments --------- query_length : int The length of the query sequence. key_length : int The length of the key sequence. Returns ------- torch.Tensor A tensor of shape `(num_heads, query_length, key_length)` containing the relative position bias values for each attention head. """ # Compute the relative position between each query and key token context_position = torch.arange(query_length, dtype=torch.long)[:, None] memory_position = torch.arange(key_length, dtype=torch.long)[None, :] relative_position = memory_position - context_position # Map relative positions to bucket indices relative_position_bucket = self._relative_positions_bucket( relative_position, bidirectional=True ) # Move bucket indices to the device of the bias embeddings relative_position_bucket = relative_position_bucket.to( self.relative_attention_bias.weight.device ) # Fetch bias values from the relative position embedding layer values = self.relative_attention_bias(relative_position_bucket) # Rearrange dimensions to match expected output shape values = values.permute( [2, 0, 1] ) # Shape: (num_heads, query_length, key_length) return values
[docs] def forward( self, query: Tensor, key: Optional[Tensor], value: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, incremental_state: Optional[ Dict[str, Dict[str, Optional[Tensor]]] ] = None, need_weights: bool = True, static_kv: bool = False, attn_mask: Optional[Tensor] = None, before_softmax: bool = False, need_head_weights: bool = False, position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: """ Forward pass for multi-head attention with support for relative position embeddings, caching, and optional dropout. This method implements the core functionality of multi-head attention with optional features such as relative position bias, incremental decoding, and support for various masking options. Arguments --------- query : torch.Tensor Query tensor of shape `(target_length, batch_size, embed_dim)`. key : torch.Tensor, optional Key tensor of shape `(source_length, batch_size, embed_dim)`. Defaults to `None`. value : torch.Tensor, optional Value tensor of shape `(source_length, batch_size, embed_dim)`. Defaults to `None`. key_padding_mask : torch.Tensor, optional Mask to exclude padding keys, of shape `(batch_size, source_length)`, where padding elements are indicated by 1s. Defaults to `None`. incremental_state : dict, optional Stores cached key and value tensors for incremental decoding. Defaults to `None`. need_weights : bool, optional If True, returns the attention weights. Defaults to `True`. static_kv : bool, optional If True, the key and value tensors remain static for incremental decoding. Defaults to `False`. attn_mask : torch.Tensor, optional Attention mask to prevent certain positions from attending, typically for causal attention. Shape: `(target_length, source_length)`. Defaults to `None`. before_softmax : bool, optional If True, returns raw attention scores before softmax. Defaults to `False`. need_head_weights : bool, optional If True, returns attention weights for each head. Implies `need_weights=True`. Defaults to `False`. position_bias : torch.Tensor, optional Precomputed position bias tensor. If `None`, it is computed during the forward pass. Returns ------- attn : torch.Tensor Attention output of shape `(target_length, batch_size, embed_dim)`. attn_weights : torch.Tensor, optional Attention weights of shape `(batch_size, num_heads, target_length, source_length)`, averaged across heads if `need_head_weights=False`. position_bias : torch.Tensor, optional Computed or passed relative position bias of shape `(num_heads, target_length, source_length)`. 
""" if need_head_weights: need_weights = True tgt_len, bsz, embed_dim = query.size() src_len = tgt_len assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] if key is not None: src_len, key_bsz, _ = key.size() if not torch.jit.is_scripting(): assert key_bsz == bsz assert value is not None assert src_len, bsz == value.shape[:2] if self.has_relative_attention_bias and position_bias is None: position_bias = self.compute_bias(tgt_len, src_len) position_bias = ( position_bias.unsqueeze(0) .repeat(bsz, 1, 1, 1) .view(bsz * self.num_heads, tgt_len, src_len) ) if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if saved_state is not None and "prev_key" in saved_state: # previous time steps are cached - no need to recompute # key and value if they are static if static_kv: assert ( self.encoder_decoder_attention and not self.self_attention ) key = value = None else: saved_state = None alpha = 32 q, k, v, attn_mask, key_padding_mask = self._prepare_attention_inputs( query, key, value, bsz, tgt_len, key_padding_mask, attn_mask, alpha=32, ) if saved_state is not None: # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) if "prev_key" in saved_state: _prev_key = saved_state["prev_key"] assert _prev_key is not None prev_key = _prev_key.view( bsz * self.num_heads, -1, self.head_dim ) if static_kv: k = prev_key else: assert k is not None k = torch.cat([prev_key, k], dim=1) src_len = k.size(1) if "prev_value" in saved_state: _prev_value = saved_state["prev_value"] assert _prev_value is not None prev_value = _prev_value.view( bsz * self.num_heads, -1, self.head_dim ) if static_kv: v = prev_value else: assert v is not None v = torch.cat([prev_value, v], dim=1) prev_key_padding_mask: Optional[Tensor] = None if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"] assert k is not None and v is not None key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=bsz, src_len=k.size(1), static_kv=static_kv, ) saved_state["prev_key"] = k.view( bsz, self.num_heads, -1, self.head_dim ) saved_state["prev_value"] = v.view( bsz, self.num_heads, -1, self.head_dim ) saved_state["prev_key_padding_mask"] = key_padding_mask # In this branch incremental_state is never None assert incremental_state is not None incremental_state = self._set_input_buffer( incremental_state, saved_state ) assert k is not None assert k.size(1) == src_len attn_weights, attn_mask = self._process_attention_weights( q, k, v, attn_mask, key_padding_mask, bsz, tgt_len, src_len, alpha ) if before_softmax: return attn_weights, v, position_bias attn, attn_weights = self._compute_attention_output( q, v, attn_weights, position_bias, bsz, tgt_len, src_len, embed_dim, need_weights, need_head_weights, alpha, ) return attn, attn_weights, position_bias
def _compute_attention_output( self, q, v, attn_weights, position_bias, bsz, tgt_len, src_len, embed_dim, need_weights, need_head_weights, alpha, ): """ Computes the final attention output, including relative position bias adjustments, attention weight computation, and attention projection. Arguments --------- q : torch.Tensor Query tensor. v : torch.Tensor Value tensor. attn_weights : torch.Tensor Attention weights tensor. position_bias : Optional[torch.Tensor] Relative position bias tensor. bsz : int Batch size. tgt_len : int Target sequence length. src_len : int Source sequence length. embed_dim : int Embedding dimension. need_weights : bool Whether to return attention weights. need_head_weights : bool Whether to return head-specific weights. alpha : float Scaling factor for relative position. Returns ------- Tuple[torch.Tensor, Optional[torch.Tensor]] Final attention output and optional attention weights. """ # Apply relative position bias if available if position_bias is not None: attn_mask_rel_pos = position_bias if self.gru_rel_pos == 1: query_layer = ( q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling ) _B, _H, _L, __ = query_layer.size() gate_a, gate_b = torch.sigmoid( self.grep_linear(query_layer) .view(_B, _H, _L, 2, 4) .sum(-1, keepdim=False) ).chunk(2, dim=-1) gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 attn_mask_rel_pos = ( gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias ) attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) attn_weights = attn_weights + attn_mask_rel_pos # Apply softmax and dropout attn_weights_float = F.softmax(attn_weights, dim=-1) attn_weights = attn_weights_float.type_as(attn_weights) attn_probs = self.dropout_module(attn_weights) # Compute final attention assert v is not None attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [ bsz * self.num_heads, tgt_len, self.head_dim, ] # Reshape and project attention output attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) # Optionally return attention weights attn_weights_out: Optional[Tensor] = None if need_weights: attn_weights_out = attn_weights_float.view( bsz, self.num_heads, tgt_len, src_len ).transpose(1, 0) if not need_head_weights: attn_weights_out = attn_weights_out.mean(dim=0) return attn, attn_weights_out def _process_attention_weights( self, q, k, v, attn_mask, key_padding_mask, bsz, tgt_len, src_len, alpha ): """ Processes attention weights, including handling key padding masks, adding zero attention if required, and computing the attention weights with masking. Arguments --------- q : torch.Tensor Query tensor. k : torch.Tensor Key tensor. v : torch.Tensor Value tensor. attn_mask : torch.Tensor Attention mask key_padding_mask : torch.Tensor Key padding mask. bsz : int Batch size. tgt_len : int Target sequence length. src_len : int Source sequence length. alpha : float Scaling factor for relative position. Returns ------- Tuple[torch.Tensor, Optional[torch.Tensor]] Computed attention weights and the updated attention mask. 
""" is_tpu = q.device.type == "xla" # Handle zero-dimension key padding mask if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None # Validate key padding mask dimensions if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len # Add zero attention if required if self.add_zero_attn: assert v is not None src_len += 1 k = torch.cat( [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1 ) v = torch.cat( [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1 ) if attn_mask is not None: attn_mask = torch.cat( [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1, ) if key_padding_mask is not None: key_padding_mask = torch.cat( [ key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as( key_padding_mask ), ], dim=1, ) # Compute attention weights attn_weights = torch.bmm(q, k.transpose(1, 2)) attn_weights = ( attn_weights - attn_weights.max(dim=-1, keepdim=True)[0] ) * alpha attn_weights = self.apply_sparse_mask( attn_weights, tgt_len, src_len, bsz ) # Validate attention weights dimensions assert list(attn_weights.size()) == [ bsz * self.num_heads, tgt_len, src_len, ] # Apply attention mask if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) attn_weights += attn_mask # Apply key padding mask if key_padding_mask is not None: attn_weights = attn_weights.view( bsz, self.num_heads, tgt_len, src_len ) if not is_tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), ) else: attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.masked_fill( key_padding_mask, float("-inf") ) attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.view( bsz * self.num_heads, tgt_len, src_len ) return attn_weights, attn_mask
[docs] def apply_bias(self, k, v, bsz, attn_mask=None, key_padding_mask=None): """ Applies bias_k and bias_v to the key and value tensors, updating the attention mask and key padding mask accordingly. Arguments --------- k : torch.Tensor Key tensor. v : torch.Tensor Value tensor. bsz : int Batch size. attn_mask : torch.Tensor Attention mask key_padding_mask : torch.Tensor Key padding mask. Returns ------- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: Updated key, value, attention mask, and key padding mask. """ if self.bias_k is not None: assert ( self.bias_v is not None ), "bias_k and bias_v must both be provided." # Apply biases to key and value k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)], dim=0) v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)], dim=0) # Update attention mask if attn_mask is not None: attn_mask = torch.cat( [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1, ) # Update key padding mask if key_padding_mask is not None: key_padding_mask = torch.cat( [ key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1), ], dim=1, ) return k, v, attn_mask, key_padding_mask
def _prepare_attention_inputs( self, query, key, value, bsz, tgt_len, key_padding_mask=None, attn_mask=None, alpha=32, ): """ Prepares and scales the projections, applies biases, and reshapes the query, key, and value tensors for multi-head attention. Arguments --------- query : torch.Tensor Query tensor. key : torch.Tensor Key tensor. value : torch.Tensor Value tensor. bsz : int Batch size. tgt_len : int Target sequence length. key_padding_mask : torch.Tensor Key padding mask. attn_mask : torch.Tensor Attention mask alpha : float, optional Scaling factor for relative position. Default is 32. Returns ------- Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] Scaled and reshaped query, key, and value tensors, along with updated attention and key padding masks. """ # Compute scaled projections if self.self_attention: q = self.q_proj(query) k = self.k_proj(query) v = self.v_proj(query) elif self.encoder_decoder_attention: q = self.q_proj(query) if key is None: assert value is None k = v = None else: k = self.k_proj(key) v = self.v_proj(key) else: assert key is not None and value is not None q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) # Apply scaling q *= self.scaling q *= 1 / alpha # Reshape and transpose for multi-head attention q = ( q.contiguous() .view(tgt_len, bsz * self.num_heads, self.q_head_dim) .transpose(0, 1) ) if k is not None: k = ( k.contiguous() .view(-1, bsz * self.num_heads, self.k_head_dim) .transpose(0, 1) ) if v is not None: v = ( v.contiguous() .view(-1, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) return q, k, v, attn_mask, key_padding_mask @staticmethod def _append_prev_key_padding_mask( key_padding_mask: Optional[Tensor], prev_key_padding_mask: Optional[Tensor], batch_size: int, src_len: int, static_kv: bool, ) -> Optional[Tensor]: """ Combines the previous and current key padding masks to create a unified mask. Arguments --------- key_padding_mask : Optional[torch.Tensor] The current key padding mask of shape `(batch_size, seq_len)`, or `None`. prev_key_padding_mask : Optional[torch.Tensor] The previous key padding mask of shape `(batch_size, seq_len)`, or `None`. batch_size : int The batch size of the input. src_len : int The source sequence length to which the masks need to align. static_kv : bool If `True`, indicates that the key-value pairs are static and only the previous key padding mask should be used. Returns ------- Optional[torch.Tensor] The combined key padding mask of shape `(batch_size, src_len)`, or `None` if both input masks are `None`. 
""" # saved key padding masks have shape (bsz, seq_len) if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat( [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 ) # During incremental decoding, as the padding token enters and # leaves the frame, there will be a time when prev or current # is None elif prev_key_padding_mask is not None: if src_len > prev_key_padding_mask.size(1): filler = torch.zeros( (batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device, ) new_key_padding_mask = torch.cat( [prev_key_padding_mask.float(), filler.float()], dim=1 ) else: new_key_padding_mask = prev_key_padding_mask.float() elif key_padding_mask is not None: if src_len > key_padding_mask.size(1): filler = torch.zeros( (batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device, ) new_key_padding_mask = torch.cat( [filler.float(), key_padding_mask.float()], dim=1 ) else: new_key_padding_mask = key_padding_mask.float() else: new_key_padding_mask = prev_key_padding_mask return new_key_padding_mask def _get_input_buffer( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], ) -> Dict[str, Optional[Tensor]]: """ Retrieves the input buffer for incremental decoding. Arguments --------- incremental_state : Optional[Dict[str, Dict[str, Optional[Tensor]]]] The state dictionary used for incremental decoding. It stores intermediate computation states, such as attention states, for efficient sequential processing. Returns ------- Dict[str, Optional[Tensor]] The attention state dictionary containing keys and values for incremental decoding. If no state exists, an empty dictionary is returned. """ result = self.get_incremental_state(incremental_state, "attn_state") if result is not None: return result else: empty_result: Dict[str, Optional[Tensor]] = {} return empty_result def _set_input_buffer( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], buffer: Dict[str, Optional[Tensor]], ): """ Updates the input buffer for incremental decoding. Arguments --------- incremental_state : Dict[str, Dict[str, Optional[Tensor]]] The state dictionary used for incremental decoding. It stores intermediate computation states, such as attention states. buffer : Dict[str, Optional[Tensor]] The attention state dictionary containing keys and values to be stored for incremental decoding. Returns ------- None """ return self.set_incremental_state( incremental_state, "attn_state", buffer )
[docs] def apply_sparse_mask( self, attn_weights, tgt_len: int, src_len: int, bsz: int ): """ Applies a sparse mask to the attention weights. Arguments --------- attn_weights : torch.Tensor The attention weights tensor of shape `(batch_size * num_heads, tgt_len, src_len)`. tgt_len : int The target sequence length. src_len : int The source sequence length. bsz : int The batch size. Returns ------- torch.Tensor The (potentially modified) attention weights tensor. By default, this is the same as the input tensor. """ return attn_weights
[docs] def init_bert_params(module: nn.Module) -> None: """ Initializes weights and biases for modules in the BERT model. Arguments --------- module : nn.Module The module to initialize. Can be one of `nn.Linear`, `nn.Embedding`, or `MultiheadAttention`. """ def normal_(data: torch.Tensor) -> None: """ Initializes a tensor with values drawn from a normal distribution. Arguments --------- data : torch.Tensor The tensor to initialize. """ # Handle FSDP initialization data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) if isinstance(module, nn.Linear): # Initialize weights and biases for linear layers normal_(module.weight.data) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): # Initialize weights for embedding layers normal_(module.weight.data) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, MultiheadAttention): # Initialize weights for multi-head attention projections normal_(module.q_proj.weight.data) normal_(module.k_proj.weight.data) normal_(module.v_proj.weight.data)
class BEATsConfig:
    """Configuration class for the BEATs model.

    This class defines the configuration for the BEATs model. It provides a
    default configuration that can be updated with custom settings via the
    `update` method.

    Arguments
    ---------
    cfg : dict, optional
        A dictionary containing custom configuration values. If provided, it
        will override the default settings.
    """

    def __init__(self, cfg=None):
        self.input_patch_size: int = 16  # patch size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = (
            3072  # encoder embedding dimension for FFN
        )
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = (
            1.0  # ratio for layer-wise gradient decay
        )
        self.layer_norm_first: bool = (
            False  # apply layernorm first in the transformer
        )
        self.deep_norm: bool = False  # apply deep_norm in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = (
            0.1  # dropout probability for attention weights
        )
        self.activation_dropout: float = (
            0.0  # dropout probability after activation in FFN
        )
        self.encoder_layerdrop: float = (
            0.0  # probability of dropping a transformer layer
        )
        self.dropout_input: float = (
            0.0  # dropout to apply to the input (after feat extr)
        )

        # positional embeddings
        self.conv_pos: int = (
            128  # number of filters for convolutional positional embeddings
        )
        self.conv_pos_groups: int = (
            16  # number of groups for convolutional positional embedding
        )

        # relative position embedding
        self.relative_position_embedding: bool = (
            False  # apply relative position embedding
        )
        self.num_buckets: int = (
            320  # number of buckets for relative position embedding
        )
        self.max_distance: int = (
            1280  # maximum distance for relative position embedding
        )
        self.gru_rel_pos: bool = (
            False  # apply gated relative position embedding
        )

        # label predictor
        self.finetuned_model: bool = (
            False  # whether the model is a fine-tuned model
        )
        self.predictor_dropout: float = (
            0.1  # dropout probability for the predictor
        )
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)
    def update(self, cfg: dict):
        """Updates the instance's attributes with key-value pairs from a given
        configuration dictionary.

        Arguments
        ---------
        cfg : dict
            A dictionary containing the configuration values to update the
            instance with.
        """
        self.__dict__.update(cfg)
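# A small sketch of how the configuration is used: defaults are kept unless a
# checkpoint (or the user) overrides them through `update`:
#
# >>> cfg = BEATsConfig({"finetuned_model": True, "predictor_class": 527})
# >>> (cfg.encoder_layers, cfg.finetuned_model, cfg.predictor_class)
# (12, True, 527)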