Source code for speechbrain.nnet.hypermixing

"""This module mixes information from different tokens via HyperMixing.
It can be viewed as a linear-time drop-in replacement for (self-)attention.

source: https://arxiv.org/abs/2203.03691

Authors
 * Florian Mai 2023
 * Juan Pablo Zuluaga 2023
"""
import math
from typing import Optional

import torch
from torch import nn


class HyperMixing(nn.Module):
    """This class implements multi-head HyperMixing.
    It is an implementation of the token-mixing component in HyperMixer, a
    linear-time drop-in replacement for self-attention. In contrast to the
    original HyperMixer, this module supports multiple heads, which improves
    the expressiveness of the model while decreasing the number of parameters.

    Reference: https://arxiv.org/abs/2203.03691

    Arguments
    ---------
    input_output_dim : int
        Number of features in keys, queries, and values.
    hypernet_size : int
        Determines the size of the hidden layer of the token-mixing MLP.
    tied : bool
        If True, then the generated weight matrices of the token-mixing MLP are tied.
    num_heads : int
        Number of parallel token-mixing MLPs (heads).
    fix_tm_hidden_size : bool
        If True, the hidden-layer size is equal to hypernet_size rather than
        hypernet_size / num_heads.
    max_length : int
        Maximum number of input tokens. Needed for generating sufficiently
        large position embeddings.

    Example
    -------
    >>> import torch
    >>> inputs = torch.rand([8, 60, 512])
    >>> net = HyperMixing(512, 2048, num_heads=8)
    >>> outputs, attn = net(inputs, inputs, inputs)
    >>> outputs.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_output_dim: int,
        hypernet_size: int,
        tied: bool = False,
        num_heads: int = 1,
        fix_tm_hidden_size=False,
        max_length=3000,
    ) -> None:
        super().__init__()
        self.input_output_dim = input_output_dim
        self.hyper = HyperNetwork(
            input_output_dim,
            hypernet_size,
            tied=tied,
            num_heads=num_heads,
            keep_output_size=fix_tm_hidden_size,
        )
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(input_output_dim)
        self.num_heads = num_heads

        # deferred import: the positional encoding is reused from the Transformer lobe
        from speechbrain.lobes.models.transformer.Transformer import (
            PositionalEncoding,
        )

        # add pos encoding
        self.positional_encoding = PositionalEncoding(
            input_output_dim, max_length
        )

    def _mlp_pass_from_components(self, out, W1, W2, activation):
        """Sticks the token-mixing MLP together manually from its generated components."""
        out = torch.bmm(out, W1)
        out = activation(out)
        out = torch.bmm(out, W2.transpose(1, 2))
        return out

    def forward(
        self,
        query,
        key,
        value,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        return_attn_weights: Optional[bool] = True,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """
        The signature of this method is deliberately chosen to be the same as
        for sb.nnet.attention.MultiHeadAttention for compatibility within
        SpeechBrain.

        NOTE: key, value, attn_mask and pos_embs have no effect. Query is used
        for all three. Thus, the module should currently only be used to
        replace self-attention.

        Arguments
        ---------
        query : torch.Tensor
            (B, L, E) where L is the target sequence length,
            B is the batch size, E is the embedding dimension.
        key : torch.Tensor
            (B, S, E) where S is the source sequence length,
            B is the batch size, E is the embedding dimension.
            Currently unused.
        value : torch.Tensor
            (B, S, E) where S is the source sequence length,
            B is the batch size, E is the embedding dimension.
            Currently unused.
        attn_mask : torch.Tensor, optional
            NOTE: Currently has NO effect.
        key_padding_mask : torch.Tensor, optional
            (B, S) where B is the batch size, S is the source sequence length.
            If a ByteTensor is provided, non-zero positions are ignored, while
            zero positions are left unchanged. If a BoolTensor is provided,
            positions with the value True are ignored, while positions with
            the value False are left unchanged.
        return_attn_weights : bool, optional
            NOTE: Currently has NO effect.
        pos_embs : torch.Tensor, optional
            NOTE: Currently has NO effect.

        Outputs
        -------
        attn_output : torch.Tensor
            (B, L, E) where L is the target sequence length,
            B is the batch size, E is the embedding dimension.
        attn_output_weights : torch.Tensor
            (B, L, S) where B is the batch size, L is the target sequence
            length, S is the source sequence length.
            NOTE: always returns all zeros.
        """
        # NOTE: keys and values are ignored, because HyperMixing can currently
        # only be used in the encoder, where query, key, and value are the same.
        out = query

        bsize = out.size(0)
        seq_len = out.size(1)

        if key_padding_mask is not None:
            float_mask = (
                torch.logical_not(key_padding_mask).unsqueeze(-1).float()
            )
            out = out * float_mask

        # add position embedding before passing to hypernetwork
        hyp_input = out + self.positional_encoding(out)
        W1, W2 = self.hyper(
            hyp_input
        )  # [bsize, num_heads, seq_len, hypernet_size // num_heads]

        if key_padding_mask is not None:
            # mask the generated weights
            W1 = W1 * float_mask.unsqueeze(1)
            W2 = W2 * float_mask.unsqueeze(1)

        # reshape the num_heads into the batch dimension for parallelizing
        out = out.transpose(1, 2)  # [bsize, input_output_dim, seq_len]
        out = out.reshape(
            (
                bsize * self.num_heads,
                self.input_output_dim // self.num_heads,
                seq_len,
            )
        )  # [bsize * num_heads, input_output_dim // num_heads, seq_len]
        W1 = W1.reshape((bsize * self.num_heads, seq_len, -1))
        W2 = W2.reshape((bsize * self.num_heads, seq_len, -1))

        # we stick the token-mixing MLP together manually
        out = self._mlp_pass_from_components(out, W1, W2, self.activation)

        # concatenate heads
        out = out.reshape((bsize, self.input_output_dim, seq_len))

        # transpose back
        out = out.transpose(1, 2)

        # apply layer norm on outputs of the TM-MLP
        out = self.layer_norm(out)

        dummy_att_weights = torch.zeros(
            (bsize, seq_len, seq_len), device=out.device
        )
        return out, dummy_att_weights
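

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how HyperMixing can
# stand in for sb.nnet.attention.MultiHeadAttention on a padded batch. The
# helper name `_demo_hypermixing_with_padding` and the chosen sizes are purely
# hypothetical.
def _demo_hypermixing_with_padding():
    """Run HyperMixing on a padded batch and return the output shape."""
    mixer = HyperMixing(input_output_dim=512, hypernet_size=2048, num_heads=8)
    x = torch.rand(4, 60, 512)  # (B, L, E), self-attention style input
    # Boolean padding mask: True marks padded positions to be ignored.
    key_padding_mask = torch.zeros(4, 60, dtype=torch.bool)
    key_padding_mask[:, 50:] = True  # last 10 frames of every item are padding
    out, attn = mixer(x, x, x, key_padding_mask=key_padding_mask)
    # `attn` is a dummy all-zeros tensor kept only for API compatibility.
    return out.shape  # torch.Size([4, 60, 512])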


class HyperNetwork(nn.Module):
    """This class implements the HyperNetwork. A hypernetwork is an approach
    in which one network is used to generate the weights of another network.
    Here, it is used to generate the weights of the token-mixing linear layers.

    Reference: https://arxiv.org/abs/1609.09106

    Arguments
    ---------
    input_output_dim : int
        Dimension of the linear layers
    hypernet_size : int
        Dimension of the HyperNetwork
    tied : bool, optional
        Define whether weights of layer 1 and layer 2 are shared
    num_heads : int, optional
        Number of heads, akin to heads in MultiHeadAttention
    keep_output_size : bool, optional
        Set whether to keep the same output size independent of number of heads
    """

    def __init__(
        self,
        input_output_dim: int,
        hypernet_size: int,
        tied=False,
        num_heads=1,
        keep_output_size=True,
    ) -> None:
        super().__init__()

        # Define whether the two linear layers have tied weights
        self.tied = tied
        self.w1_gen = ParallelMLPs(
            input_output_dim,
            input_output_dim,
            output_size=hypernet_size,
            num_mlps=num_heads,
            keep_output_size=keep_output_size,
        )
        if self.tied:
            self.w2_gen = self.w1_gen
        else:
            self.w2_gen = ParallelMLPs(
                input_output_dim,
                input_output_dim,
                output_size=hypernet_size,
                num_mlps=num_heads,
                keep_output_size=keep_output_size,
            )

    def forward(self, input_tensor: torch.Tensor):
        """Forward computation for a HyperNetwork.

        Arguments
        ---------
        input_tensor : torch.Tensor
            Input of shape [batchsize, max_positions, d].
            The HyperNetwork is supposed to generate an MLP of the form
            W_2(GELU(W_1 x)), where W_1 : N -> k and W_2 : k -> N,
            so it has to return the tensors W1 and W2.

        Outputs
        -------
        W1 : torch.Tensor
            Generated weights of Layer 1
        W2 : torch.Tensor
            Generated weights of Layer 2
        """
        W1 = self.w1_gen(input_tensor)
        if self.tied:
            W2 = W1
        else:
            W2 = self.w2_gen(input_tensor)

        return W1, W2
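

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the shapes produced
# by HyperNetwork. With keep_output_size=False (as used by HyperMixing when
# fix_tm_hidden_size=False), each head generates weight rows of size
# hypernet_size // num_heads. The helper name and sizes are purely hypothetical.
def _demo_hypernetwork_shapes():
    """Return the shapes of the generated weight tensors W1 and W2."""
    hyper = HyperNetwork(
        input_output_dim=512,
        hypernet_size=2048,
        num_heads=8,
        keep_output_size=False,
    )
    tokens = torch.rand(4, 60, 512)  # [batchsize, max_positions, d]
    W1, W2 = hyper(tokens)
    # Both tensors have shape [4, 8, 60, 256], i.e.
    # [batchsize, num_heads, max_positions, hypernet_size // num_heads];
    # with tied=True, W2 would simply be the same tensor as W1.
    return W1.shape, W2.shape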


class ParallelMLPs(nn.Module):
    """Class that implements a bank of parallel MLPs, which form the
    multi-head weight generator used by HyperMixer and HyperConformer.

    Arguments
    ---------
    input_size : int
        Dimension of the linear layers
    hidden_size : int
        Dimension of the hidden layer
    output_size : int
        Dimension of the HyperNetwork
    num_mlps : int
        Number of heads, akin to heads in MultiHeadAttention
    keep_output_size : bool, optional
        Set whether to keep the same output size independent of number of heads
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size=None,
        num_mlps=1,
        keep_output_size=True,
    ) -> None:
        super().__init__()

        if output_size is None:
            output_size = input_size
        self.original_in_size = input_size
        self.original_out_size = output_size

        assert input_size % num_mlps == 0
        assert output_size % num_mlps == 0
        assert hidden_size % num_mlps == 0

        # split the feature dimensions across the parallel MLPs (heads)
        input_size = input_size // num_mlps
        if not keep_output_size:
            output_size = output_size // num_mlps
        hidden_size = hidden_size // num_mlps

        self.input_size = input_size
        self.output_size = output_size
        self.num_mlps = num_mlps

        # set the weights and biases parameters
        self.fc1_weights = nn.Parameter(
            torch.empty(num_mlps, hidden_size, input_size)
        )
        self.fc1_biases = nn.Parameter(torch.empty(num_mlps, hidden_size))
        self.fc2_weights = nn.Parameter(
            torch.empty(num_mlps, output_size, hidden_size)
        )
        self.fc2_biases = nn.Parameter(torch.empty(num_mlps, output_size))

        # initialize the weights and biases
        nn.init.xavier_uniform_(self.fc1_weights, gain=math.sqrt(2.0))
        nn.init.xavier_uniform_(self.fc1_biases, gain=math.sqrt(2.0))
        nn.init.xavier_uniform_(self.fc2_weights, gain=math.sqrt(2.0))
        nn.init.xavier_uniform_(self.fc2_biases, gain=math.sqrt(2.0))

        self.activation = nn.GELU()

    def forward(self, x):
        """Performs the forward computation of multiple parallel MLPs.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape [bsize, seq_len, num_features],
            where num_features = num_mlps * input_size.

        Outputs
        -------
        x : torch.Tensor
            Output tensor of shape [bsize, num_mlps, seq_len, output_size].
        """
        # x [bsize, seq_len, num_features]
        bsize = x.size(0)
        seq_len = x.size(1)

        # Reshape the input tensor to match the number of parallel MLPs and their input size
        x = x.reshape((bsize, seq_len, self.num_mlps, self.input_size))

        # Perform the first linear transformation and add bias
        # Using einsum so we can do it for multiple MLPs in parallel
        x = torch.einsum(
            "blmf,mhf->bmlh", x, self.fc1_weights
        ) + self.fc1_biases.unsqueeze(0).unsqueeze(2)

        # Apply the activation function, then perform the second linear transformation and add bias
        x = self.activation(x)
        x = torch.einsum(
            "bmlh,mfh->bmlf", x, self.fc2_weights
        ) + self.fc2_biases.unsqueeze(0).unsqueeze(2)

        return x
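

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the einsum calls in
# ParallelMLPs.forward are equivalent to running num_mlps independent
# two-layer MLPs, one per head, on disjoint slices of the feature dimension.
# The helper name and sizes are purely hypothetical.
def _demo_parallel_mlps_equivalence():
    """Check one head of ParallelMLPs against an explicit per-head computation."""
    mlps = ParallelMLPs(
        input_size=512,
        hidden_size=512,
        output_size=2048,
        num_mlps=8,
        keep_output_size=False,
    )
    x = torch.rand(2, 10, 512)
    y = mlps(x)  # [2, 8, 10, 256] = [bsize, num_mlps, seq_len, output_size]

    # Recompute head 0 explicitly: it only sees features 0..63 of the input.
    head = 0
    x_head = x.reshape(2, 10, 8, 64)[:, :, head, :]  # [2, 10, 64]
    hidden = mlps.activation(
        x_head @ mlps.fc1_weights[head].T + mlps.fc1_biases[head]
    )
    y_head = hidden @ mlps.fc2_weights[head].T + mlps.fc2_biases[head]
    assert torch.allclose(y[:, head], y_head, atol=1e-5)
    return y.shape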