speechbrain.nnet.hypermixing module

This module mixes information from different tokens via HyperMixing. It can be viewed as a linear-time drop-in replacement for (self-)attention.

source: https://arxiv.org/abs/2203.03691
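The linear-time claim can be made concrete by counting multiply-accumulate operations in the token-mixing step alone. This is a rough sketch under simplifying assumptions: projections, the hypernetwork itself, biases, softmax and activations are ignored. The function name `token_mixing_macs` and the symbol k (the token-mixing hidden size, which does not depend on the sequence length N) are illustrative choices and not part of SpeechBrain.

```python
def token_mixing_macs(seq_len: int, dim: int, hidden: int) -> dict:
    """Rough multiply-accumulate counts for the token-mixing step only (illustrative)."""
    # Self-attention: Q @ K^T is (N x d) @ (d x N), then A @ V is (N x N) @ (N x d).
    attention = 2 * seq_len * seq_len * dim
    # HyperMixing: W1^T @ X is (k x N) @ (N x d), then W2 @ H is (N x k) @ (k x d).
    hypermixing = 2 * seq_len * hidden * dim
    return {"attention": attention, "hypermixing": hypermixing}


for n in (500, 1000, 2000):
    costs = token_mixing_macs(seq_len=n, dim=512, hidden=256)
    # The ratio grows as N / k because only attention is quadratic in N.
    print(n, costs["attention"] / costs["hypermixing"])
```

Doubling N doubles the HyperMixing estimate but quadruples the attention estimate.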

Authors
  • Florian Mai 2023

  • Juan Pablo Zuluaga 2023

Summary

Classes:

HyperMixing

This class implements multi-head HyperMixing.

HyperNetwork

This class implements the HyperNetwork.

ParallelMLPs

Class that implements multiple MLPs running in parallel, as used in the multi-head HyperMixer and HyperConformer.

Reference

class speechbrain.nnet.hypermixing.HyperMixing(input_output_dim: int, hypernet_size: int, tied: bool = False, num_heads: int = 1, fix_tm_hidden_size=False, max_length=3000)

Bases: Module

This class implements multi-head HyperMixing. It is an implementation of the token-mixing component in HyperMixer, a linear-time drop-in replacement for self-attention. In contrast to the original HyperMixer, this module supports multiple heads, which improves the expressiveness of the model while decreasing the number of parameters.

Reference: https://arxiv.org/abs/2203.03691
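A minimal sketch of multi-head token mixing, assuming the generated weights are already available as tensors of shape [B, H, N, k] and that each head corresponds to a contiguous group of E/H features. It omits position embeddings, padding masks, and any output normalization, and the function name is hypothetical; it illustrates the idea rather than reproducing SpeechBrain's implementation.

```python
import torch
import torch.nn.functional as F


def multi_head_token_mixing(x, w1, w2):
    """Mix tokens independently per head with generated weights (illustrative).

    x:  [B, N, E] token features, E divisible by the number of heads H.
    w1: [B, H, N, k] generated first-layer weights (token axis -> hidden).
    w2: [B, H, N, k] generated second-layer weights (hidden -> token axis).
    Returns a tensor of shape [B, N, E].
    """
    bsize, seq_len, dim = x.shape
    num_heads = w1.size(1)
    head_dim = dim // num_heads
    # Split the feature dimension into contiguous groups, one per head.
    xh = x.view(bsize, seq_len, num_heads, head_dim).transpose(1, 2)  # [B, H, N, E/H]
    # First layer mixes along the token axis: (k x N) @ (N x E/H) per head.
    hidden = F.gelu(torch.einsum("bhnk,bhne->bhke", w1, xh))  # [B, H, k, E/H]
    # Second layer maps the k hidden units back to the N token positions.
    out = torch.einsum("bhnk,bhke->bhne", w2, hidden)  # [B, H, N, E/H]
    # Concatenate the heads back into the feature dimension.
    return out.transpose(1, 2).reshape(bsize, seq_len, dim)


x = torch.rand(2, 10, 16)
w1 = torch.rand(2, 4, 10, 8)
w2 = torch.rand(2, 4, 10, 8)
print(multi_head_token_mixing(x, w1, w2).shape)  # torch.Size([2, 10, 16])
```

Because each head only sees its own feature group, changing the weights of one head leaves the features of the other heads untouched.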

Parameters:
  • input_output_dim (int) – number of features in keys, queries, and values

  • hypernet_size (int) – determines the size of the hidden layer of the token-mixing MLP.

  • tied (bool) – If True, then the generated weight matrices of the token-mixing MLP are tied.

  • num_heads (int) – number of parallel token-mixing MLPs (heads).

  • fix_tm_hidden_size (bool) – If True, the hidden-layer size is equal to hypernet_size rather than hypernet_size / num_heads.

  • max_length (int) – Maximum number of input tokens. Needed for generating sufficiently large position embeddings.
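The interplay of num_heads, hypernet_size and fix_tm_hidden_size can be written as a small arithmetic helper. `per_head_sizes` is hypothetical, and it assumes that the input features are split evenly across heads, which is why it rejects an input_output_dim that is not divisible by num_heads.

```python
def per_head_sizes(input_output_dim, hypernet_size, num_heads, fix_tm_hidden_size=False):
    """Hypothetical helper: (features per head, token-mixing hidden size per head)."""
    if input_output_dim % num_heads != 0:
        raise ValueError("input_output_dim must be divisible by num_heads")
    features_per_head = input_output_dim // num_heads
    # By default the hidden layer is shared out across heads; with
    # fix_tm_hidden_size=True every head keeps the full hypernet_size.
    hidden_per_head = hypernet_size if fix_tm_hidden_size else hypernet_size // num_heads
    return features_per_head, hidden_per_head


print(per_head_sizes(512, 2048, num_heads=8))                           # (64, 256)
print(per_head_sizes(512, 2048, num_heads=8, fix_tm_hidden_size=True))  # (64, 2048)
```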

Example

>>> import torch
>>> inputs = torch.rand([8, 60, 512])
>>> net = HyperMixing(512, 2048, num_heads=8)
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool | None = True, pos_embs: Tensor | None = None)

The signature of this method is deliberately chosen to be the same as for sb.nnet.attention.MultiHeadAttention for compatibility within SpeechBrain.

NOTE: key, value, attn_mask, return_attn_weights, and pos_embs have no effect; the query is used as query, key, and value. Thus, the module should currently only be used to replace self-attention.

Parameters:
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension. Currently unused. All

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension. Currently unused.

  • attn_mask (torch.Tensor, optional) – NOTE: Currently has NO effect.

  • key_padding_mask (torch.Tensor, optional) – (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions are ignored while the zero positions are unchanged. If a BoolTensor is provided, the positions with the value True are ignored while the positions with the value False are unchanged.

  • return_attn_weights (bool, optional) – NOTE: Currently has NO effect.

  • pos_embs (torch.Tensor, optional) – NOTE: Currently has NO effect.

Returns:
  • attn_output (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_output_weights (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length. NOTE: always returns all zeros.
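A BoolTensor key_padding_mask following the convention above (True means ignored) can be built from per-utterance lengths. The helper name is hypothetical, and zeroing padded frames is shown only as one illustrative way such a mask can be applied; it is not a description of the module's internals.

```python
import torch


def lengths_to_key_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: BoolTensor of shape (B, S), True at padded positions."""
    positions = torch.arange(max_len).unsqueeze(0)  # [1, S]
    return positions >= lengths.unsqueeze(1)        # [B, S], broadcast over the batch


lengths = torch.tensor([5, 3])  # number of valid frames per utterance
mask = lengths_to_key_padding_mask(lengths, max_len=5)
print(mask.tolist())  # [[False, False, False, False, False], [False, False, False, True, True]]

# ByteTensor form of the same mask: non-zero entries are the ignored positions.
byte_mask = mask.to(torch.uint8)

# Illustration only: zero out padded frames so they carry no information.
x = torch.rand(2, 5, 8)
x_masked = x * torch.logical_not(mask).unsqueeze(-1).float()
```

Such a mask would then be passed through the documented signature, e.g. `net(inputs, inputs, inputs, key_padding_mask=mask)`.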

training: bool
class speechbrain.nnet.hypermixing.HyperNetwork(input_output_dim: int, hypernet_size: int, tied=False, num_heads=1, keep_output_size=True)

Bases: Module

This class implements the HyperNetwork: an approach in which one network, known as a hypernetwork, generates the weights of another network. Here, it is used to generate the weights of linear layers.

Reference: https://arxiv.org/abs/1609.09106

Parameters:
  • input_output_dim (int) – Dimension of the linear layers

  • hypernet_size (int) – Dimension of the HyperNetwork

  • tied (bool, optional) – Whether the weights of layer 1 and layer 2 are shared

  • num_heads (int, optional) – Number of heads, akin to heads in MultiHeadAttention

  • keep_output_size (bool, optional) – Whether to keep the same output size independent of the number of heads
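A simplified, single-head sketch of the hypernetwork idea: each token vector is mapped by a position-wise MLP to a k-dimensional row, and stacking these rows over the N tokens yields weight matrices whose size follows the sequence length. The class name, the two-layer generator, and the omission of position embeddings and multiple heads are assumptions made for brevity; with `tied=True` a single generator is used and W2 is the same tensor as W1.

```python
import torch
import torch.nn as nn


def _generator(dim: int, hidden: int) -> nn.Module:
    # Position-wise MLP: applied to each token vector independently.
    return nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, hidden))


class TinyHyperNetwork(nn.Module):
    """Hypothetical single-head hypernetwork, not SpeechBrain's implementation."""

    def __init__(self, input_output_dim: int, hypernet_size: int, tied: bool = False):
        super().__init__()
        self.tied = tied
        self.w1_gen = _generator(input_output_dim, hypernet_size)
        # With tied=True there is no second generator; W2 reuses W1.
        self.w2_gen = None if tied else _generator(input_output_dim, hypernet_size)

    def forward(self, input_tensor: torch.Tensor):
        # input_tensor: [B, N, d] -> W1, W2: [B, N, k], one row per token,
        # so the generated matrices grow with the sequence length N.
        w1 = self.w1_gen(input_tensor)
        w2 = w1 if self.tied else self.w2_gen(input_tensor)
        return w1, w2


hyper = TinyHyperNetwork(input_output_dim=16, hypernet_size=8)
w1, w2 = hyper(torch.rand(2, 10, 16))
print(w1.shape, w2.shape)  # torch.Size([2, 10, 8]) torch.Size([2, 10, 8])
```

Tying halves the number of generator parameters, at the cost of forcing both layers of the token-mixing MLP to share weights.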

forward(input_tensor: Tensor)

Forward computation for a HyperNetwork.

Parameters:
  • input_tensor (torch.Tensor) – Input of shape [batchsize, max_positions, d]. The HyperNetwork is supposed to generate an MLP of the form W2(GELU(W1 x)), where W1 : N -> k and W2 : k -> N, so it has to return tensors W1 and W2.

Returns:
  • W1 (torch.Tensor) – Generated weights of layer 1

  • W2 (torch.Tensor) – Generated weights of layer 2
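Given generated W1 and W2 of shape [B, N, k], the MLP W2(GELU(W1 x)) acts along the token axis, independently for each feature channel. If W1 is stored with one row per token, its N -> k map is W1^T. The sketch below applies this with batched matrix products and cross-checks one channel against the per-column definition. The function name is hypothetical, and no biases are used in this sketch.

```python
import torch
import torch.nn.functional as F


def token_mixing_mlp(x, w1, w2):
    """Apply a generated MLP along the token axis (illustrative).

    x: [B, N, d]; w1, w2: [B, N, k].
    For each feature channel c, the column x[:, :, c] of length N is mapped
    N -> k by W1^T, passed through GELU, and mapped k -> N by W2.
    """
    hidden = F.gelu(torch.bmm(w1.transpose(1, 2), x))  # [B, k, d]
    return torch.bmm(w2, hidden)                        # [B, N, d]


x = torch.rand(2, 10, 16)
w1 = torch.rand(2, 10, 8)
w2 = torch.rand(2, 10, 8)
out = token_mixing_mlp(x, w1, w2)

# Cross-check one channel against the per-column definition W2 GELU(W1^T x_c).
c = 3
col = x[0, :, c]                      # [N]
ref = w2[0] @ F.gelu(w1[0].T @ col)   # [N]
print(torch.allclose(out[0, :, c], ref, atol=1e-5))  # True
```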

training: bool
class speechbrain.nnet.hypermixing.ParallelMLPs(input_size, hidden_size, output_size=None, num_mlps=1, keep_output_size=True)

Bases: Module

Class that implements multiple MLPs running in parallel, as used in the multi-head HyperMixer and HyperConformer.

Parameters:
  • input_size (int) – Dimension of the input features

  • hidden_size (int) – Dimension of the hidden layer

  • output_size (int, optional) – Output dimension of the MLPs. If None, input_size is used.

  • num_mlps (int) – Number of heads, akin to heads in MultiHeadAttention

  • keep_output_size (bool, optional) – Whether to keep the same output size independent of the number of heads
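One way to run several independent MLPs in a single batched call is to stack their weights along a leading dimension and contract with `torch.einsum`. In the sketch below the class name, the initialization, and the choice to treat output_size as the per-MLP output dimension (akin to keep_output_size=True) are assumptions; the last lines check that one slice matches an ordinary two-layer MLP applied to its own feature group.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyParallelMLPs(nn.Module):
    """Hypothetical sketch: num_mlps independent 2-layer MLPs, one per feature group."""

    def __init__(self, input_size, hidden_size, output_size, num_mlps):
        super().__init__()
        assert input_size % num_mlps == 0 and hidden_size % num_mlps == 0
        self.num_mlps = num_mlps
        self.in_per_mlp = input_size // num_mlps
        hid = hidden_size // num_mlps
        # Stacked weights: slice m along dim 0 belongs to MLP m.
        self.w1 = nn.Parameter(torch.randn(num_mlps, hid, self.in_per_mlp) * 0.1)
        self.b1 = nn.Parameter(torch.zeros(num_mlps, hid))
        self.w2 = nn.Parameter(torch.randn(num_mlps, output_size, hid) * 0.1)
        self.b2 = nn.Parameter(torch.zeros(num_mlps, output_size))

    def forward(self, x):
        # x: [B, N, input_size] -> [B, N, M, input_size / M]
        b, n, _ = x.shape
        x = x.reshape(b, n, self.num_mlps, self.in_per_mlp)
        h = torch.einsum("bnmf,mhf->bmnh", x, self.w1) + self.b1[None, :, None, :]
        h = F.gelu(h)
        # Output: [B, M, N, output_size], one slice per MLP.
        return torch.einsum("bmnh,moh->bmno", h, self.w2) + self.b2[None, :, None, :]


mlps = TinyParallelMLPs(input_size=16, hidden_size=32, output_size=8, num_mlps=4)
x = torch.rand(2, 10, 16)
y = mlps(x)
print(y.shape)  # torch.Size([2, 4, 10, 8])

# MLP m only sees features m*4 .. m*4+3 of every token.
m = 1
xm = x[..., m * 4:(m + 1) * 4]
ref = F.gelu(xm @ mlps.w1[m].T + mlps.b1[m]) @ mlps.w2[m].T + mlps.b2[m]
print(torch.allclose(y[:, m], ref, atol=1e-5))  # True
```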

forward(x)

Performs the forward computation of multiple parallel MLPs.

Parameters:
  • x (torch.Tensor) – Input tensor

Returns:
  • x (torch.Tensor) – Output tensor

training: bool