speechbrain.nnet.attention module

Library implementing attention modules.

Authors
  • Ju-Chieh Chou 2020

  • Jianyuan Zhong 2020

  • Loren Lugosch 2020

  • Samuele Cornell 2020

  • Shucong Zhang 2024

Summary

Classes:

ContentBasedAttention

This class implements a content-based attention module for seq2seq learning.

KeyValueAttention

This class implements a single-headed key-value attention module for seq2seq learning.

LocationAwareAttention

This class implements a location-aware attention module for seq2seq learning.

MemoiseAtLeastSize

Memoises a function whose first argument is the minimum value with which to call the underlying function.

MultiheadAttention

This class is a wrapper around torch.nn.MultiheadAttention.

PositionalwiseFeedForward

The class implements the position-wise feed-forward module from “Attention Is All You Need”.

PrecomputedRoPESinusoids

A cache for the sines and cosines needed to rotate the vectors for rotary position embeddings (RoPE).

RelPosEncXL

Relative positional encoding for the RelPosMHAXL.

RelPosMHAXL

This class implements relative multi-head attention similar to that of Transformer-XL: https://arxiv.org/pdf/1901.02860.pdf

RoPEMHA

This is an implementation of multihead self-attention with RoPE positional embeddings.

Functions:

masks_union

This is a utility function that combines the standard SpeechBrain key_padding_mask and attn_mask into a single mask for scaled_dot_product_attention.

memoise_at_least

Decorator that memoises a function whose first argument is the minimum value with which to call the underlying function.

Reference

class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)

Bases: Module

This class implements a content-based attention module for seq2seq learning.

Reference: Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al. https://arxiv.org/pdf/1409.0473.pdf

Parameters:
  • enc_dim (int) – Size of encoder layer.

  • dec_dim (int) – Size of decoder layer.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

  • scaling (float) – The factor controls the sharpening degree (default: 1.0).
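As a rough illustration of the additive (Bahdanau-style) scoring this module is based on, the pure-Python sketch below scores one decoder state against each encoder frame as v · tanh(W_enc h_t + W_dec s), masks frames beyond enc_len, multiplies the energies by scaling before the softmax, and returns the weighted sum of encoder states. The function name and the weights w_enc, w_dec and v are hypothetical stand-ins for learned parameters, and the final projection to output_dim is omitted; this is a sketch of the technique, not the SpeechBrain implementation.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(mi * xi for mi, xi in zip(row, x)) for row in m]


def additive_attention(
    enc_states: Matrix,
    enc_len: int,
    dec_state: Vector,
    w_enc: Matrix,
    w_dec: Matrix,
    v: Vector,
    scaling: float = 1.0,
) -> Tuple[Vector, Vector]:
    """Content-based (additive) attention for a single decoder step.

    w_enc (attn_dim x enc_dim), w_dec (attn_dim x dec_dim) and v (attn_dim)
    are hypothetical learned weights.
    """
    dec_proj = matvec(w_dec, dec_state)
    energies = []
    for t, h in enumerate(enc_states):
        if t >= enc_len:
            energies.append(-math.inf)  # padded frame: zero weight after softmax
            continue
        hidden = [math.tanh(a + b) for a, b in zip(matvec(w_enc, h), dec_proj)]
        # a larger scaling gives a sharper (more peaked) distribution
        energies.append(scaling * sum(vi * hi for vi, hi in zip(v, hidden)))
    # numerically stable softmax; exp(-inf) evaluates to 0.0
    peak = max(energies[:enc_len])
    exps = [math.exp(e - peak) for e in energies]
    total = sum(exps)
    weights = [e / total for e in exps]
    # context vector: attention-weighted sum of the encoder states
    context = [
        sum(w * h[d] for w, h in zip(weights, enc_states))
        for d in range(len(enc_states[0]))
    ]
    return context, weights
```

With two valid frames that score equally, each receives weight 0.5 and the padded third frame receives 0.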

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = ContentBasedAttention(
...     enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5
... )
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])

reset()

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)

Returns the output of the attention module.

Parameters:
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

Returns:

The context vector and the attention weights.

class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)

Bases: Module

This class implements a location-aware attention module for seq2seq learning.

Reference: Attention-Based Models for Speech Recognition, Chorowski et al. https://arxiv.org/pdf/1506.07503.pdf

Parameters:
  • enc_dim (int) – Size of encoder.

  • dec_dim (int) – Size of decoder.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

  • conv_channels (int) – Number of channels for the location feature.

  • kernel_size (int) – Kernel size of convolutional layer for location feature.

  • scaling (float) – The factor controls the sharpening degree (default: 1.0).
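Location awareness comes from convolving the previous decoder step's attention weights, so that each frame's energy can depend on where the model attended last time. The pure-Python sketch below shows only that feature-extraction step: each of the conv_channels kernels is slid over the previous weights with zero "same" padding, giving one conv_channels-dimensional feature per frame. The function name is hypothetical and the sketch assumes an odd kernel_size; in the module these features are further projected and combined with the encoder and decoder terms.

```python
from typing import List


def location_features(
    prev_weights: List[float], kernels: List[List[float]]
) -> List[List[float]]:
    """Convolve the previous attention weights with each kernel ("same" zero padding).

    Returns one feature vector of length len(kernels) (= conv_channels) per frame.
    """
    num_frames = len(prev_weights)
    feats = [[0.0] * len(kernels) for _ in range(num_frames)]
    for c, kernel in enumerate(kernels):
        half = len(kernel) // 2  # assumes an odd kernel_size
        for t in range(num_frames):
            acc = 0.0
            for k, weight in enumerate(kernel):
                src = t + k - half
                if 0 <= src < num_frames:  # positions outside the sequence count as zero
                    acc += weight * prev_weights[src]
            feats[t][c] = acc
    return feats
```

Because of the "same" padding, the number of feature vectors always equals the number of encoder frames.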

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = LocationAwareAttention(
...     enc_dim=20,
...     dec_dim=25,
...     attn_dim=30,
...     output_dim=5,
...     conv_channels=10,
...     kernel_size=100,
... )
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])

precomputed_enc_h: Tensor | None

reset()

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)

Returns the output of the attention module.

Parameters:
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

Returns:

The context vector and the attention weights.

class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)

Bases: Module

This class implements a single-headed key-value attention module for seq2seq learning.

Reference: “Attention Is All You Need” by Vaswani et al., sec. 3.2.1

Parameters:
  • enc_dim (int) – Size of the encoder feature vectors from which keys and values are computed.

  • dec_dim (int) – Size of the decoder feature vectors from which queries are computed.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.
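Following sec. 3.2.1 of Vaswani et al., the core computation is scaled dot-product attention. The pure-Python sketch below handles one query against keys and values that are assumed to be already projected (in the module they come from learned projections of dec_states and enc_states); frames beyond enc_len are masked before the softmax. The function name is hypothetical and the sketch is illustrative only.

```python
import math
from typing import List, Tuple


def key_value_attention(
    query: List[float], keys: List[List[float]], values: List[List[float]], enc_len: int
) -> Tuple[List[float], List[float]]:
    """Single-head scaled dot-product attention for one decoder query."""
    scale = math.sqrt(len(query))
    scores = [
        sum(q * k for q, k in zip(query, key)) / scale if t < enc_len else -math.inf
        for t, key in enumerate(keys)
    ]
    peak = max(scores[:enc_len])
    exps = [math.exp(s - peak) for s in scores]  # padded frames give exp(-inf) == 0.0
    total = sum(exps)
    weights = [e / total for e in exps]
    # context vector: attention-weighted sum of the value vectors
    context = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return context, weights
```

Dividing by the square root of the key size keeps the dot products from growing with the dimension, which would otherwise push the softmax towards one-hot weights.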

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = KeyValueAttention(
...     enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5
... )
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])

reset()

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)

Returns the output of the attention module.

Parameters:
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

Returns:

The context vector and the attention weights.

class speechbrain.nnet.attention.RelPosEncXL(emb_dim: int, dtype: dtype = torch.float32)

Bases: Module

Relative positional encoding for the RelPosMHAXL.

Parameters:
  • emb_dim (int) – Size of the embedding; this also sets the size of the last dimension of the positional embedding.

  • dtype (torch.dtype, optional) – If unspecified, defaults to torch.float32. Controls the data type of the output embedding (but does not affect the precision of the computations, which remain torch.float32).

make_pe(seq_len: int)

Builds the positional embedding tensor for a given sequence length.

Parameters:

seq_len (int) – The length of the sequence to create the position embedding for.

Returns:

Positional embedding tensor of shape [1, 2*seq_len-1, embed_dim]

Return type:

torch.Tensor
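make_pe returns one embedding per relative offset, which is why its length is 2*seq_len-1 rather than seq_len. The pure-Python sketch below builds such a table for offsets seq_len-1 down to -(seq_len-1), assuming the interleaved sin/cos channels and the base of 10000 from Vaswani et al.; the ordering and channel layout used by RelPosEncXL itself may differ, the function name is hypothetical, and the leading batch dimension of size 1 is omitted.

```python
import math
from typing import List


def relative_sinusoids(seq_len: int, emb_dim: int) -> List[List[float]]:
    """Sinusoidal embeddings for relative offsets seq_len-1, ..., 0, ..., -(seq_len-1)."""
    if emb_dim % 2 != 0:
        raise ValueError("this sketch assumes an even emb_dim")
    table = []
    for offset in range(seq_len - 1, -seq_len, -1):  # 2*seq_len - 1 offsets
        row = []
        for i in range(0, emb_dim, 2):
            freq = 10000.0 ** (-i / emb_dim)
            row.append(math.sin(offset * freq))  # even channel: sine
            row.append(math.cos(offset * freq))  # odd channel: cosine
        table.append(row)
    return table
```

Offset 0 lands in the middle row, and mirrored offsets differ only in the sign of their sine channels.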

forward(x: Tensor)

Builds the positional embedding tensor. Similar to make_pe() but uses the shape information from the provided tensor.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, seq_len, embed_dim).

Returns:

pos_emb – Positional embedding tensor of shape [1, 2*seq_len-1, embed_dim]

Return type:

torch.Tensor

class speechbrain.nnet.attention.RelPosMHAXL(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None, mask_pos_future=False)

Bases: Module

This class implements relative multi-head attention similar to that of Transformer-XL: https://arxiv.org/pdf/1901.02860.pdf

Parameters:
  • embed_dim (int) – Size of the encoder feature vectors from which keys and values are computed.

  • num_heads (int) – Number of attention heads.

  • dropout (float, optional) – Dropout rate.

  • vbias (bool, optional) – Whether to use bias for computing value.

  • vdim (int, optional) – Size of the value vectors. Defaults to embed_dim (note that each head has size embed_dim // num_heads).

  • mask_pos_future (bool, optional) – Whether to mask future positional encoding values. Must be True for causal applications, e.g. a decoder.

Example

>>> inputs = torch.rand([6, 60, 512])
>>> pos_emb = torch.rand([1, 2 * 60 - 1, 512])
>>> net = RelPosMHAXL(num_heads=8, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs, pos_emb)
>>> outputs.shape
torch.Size([6, 60, 512])

rel_shift(x)

Relative shift implementation.
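The relative shift re-indexes scores computed per relative offset into scores per absolute key position without an explicit gather. The pure-Python sketch below applies the pad-reshape-slice trick known from Transformer-XL to a single qlen x (2*qlen-1) matrix, assuming the columns are ordered from offset qlen-1 down to -(qlen-1). Batch and head dimensions and the mask_pos_future option are left out and the function name is hypothetical, so treat it as an illustration of the idea rather than of rel_shift's exact code.

```python
from typing import List

Matrix = List[List[float]]


def rel_shift_sketch(scores: Matrix) -> Matrix:
    """Re-index a (qlen x 2*qlen-1) relative-score matrix to (qlen x qlen).

    scores[i][j] is the score of query i for relative offset (qlen - 1) - j;
    result[i][k] is the score of query i for key k, i.e. offset i - k.
    """
    qlen = len(scores)
    pos_len = len(scores[0])
    assert pos_len == 2 * qlen - 1, "expected 2*qlen-1 relative offsets"
    # 1) prepend a zero column and flatten row by row (length qlen * 2*qlen)
    flat: List[float] = []
    for row in scores:
        flat.extend([0.0] + list(row))
    # 2) viewing the buffer as rows of width qlen, drop the first such row
    flat = flat[qlen:]
    # 3) view the rest as qlen rows of width pos_len and keep the first qlen columns
    return [flat[i * pos_len : i * pos_len + qlen] for i in range(qlen)]
```

The result matches direct indexing, result[i][k] == scores[i][qlen - 1 - i + k], but only needs padding, reshaping and slicing, which map to cheap view operations on tensors.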

forward(query, key, value, pos_embs, key_padding_mask=None, attn_mask=None, return_attn_weights=True)

Compute attention.

Parameters:
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • pos_embs (torch.Tensor) – bidirectional sinusoidal positional embedding tensor (1, 2*S-1, E) where S is the max length between source and target sequence lengths, and E is the embedding dimension.

  • key_padding_mask (torch.Tensor) – (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored while zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are left unchanged.

  • attn_mask (torch.Tensor) – 2D mask (L, S) where L is the target sequence length and S is the source sequence length, or 3D mask (N*num_heads, L, S) where N is the batch size. attn_mask ensures that position i attends only to the unmasked positions. If a ByteTensor is provided, non-zero positions are not allowed to attend while zero positions are left unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values are left unchanged. If a FloatTensor is provided, it is added to the attention weight.

  • return_attn_weights (bool) – Whether to additionally return the attention weights.
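Both boolean masks use the convention that True marks positions to ignore. The hypothetical helpers below sketch, with plain Python lists standing in for BoolTensors, how a (B, S) key_padding_mask can be built from per-utterance lengths and how a (L, S) causal attn_mask blocks attention to future positions in the self-attention case where L == S.

```python
from typing import List

BoolMatrix = List[List[bool]]


def key_padding_mask_from_lengths(lengths: List[int], max_len: int) -> BoolMatrix:
    """(B, S) mask: True marks padded key positions that must be ignored."""
    return [[t >= n for t in range(max_len)] for n in lengths]


def causal_attn_mask(size: int) -> BoolMatrix:
    """(L, S) mask with L == S: True above the diagonal, so query i cannot see key j > i."""
    return [[j > i for j in range(size)] for i in range(size)]
```

For example, lengths [2, 3] padded to 3 frames mark only the last frame of the first utterance as ignored.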

Returns:

  • out (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_score (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.

class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)

Bases: Module

This class is a wrapper around torch.nn.MultiheadAttention.

Reference: https://pytorch.org/docs/stable/nn.html

Parameters:
  • nhead (int) – Number of parallel attention heads.

  • d_model (int) – The size of the model layers.

  • dropout (float) – Dropout probability applied to attn_output_weights (default: 0.0).

  • bias (bool) – add bias as module parameter (default: True).

  • add_bias_kv (bool) – add bias to the key and value sequences at dim=0.

  • add_zero_attn (bool) – add a new batch of zeros to the key and value sequences at dim=1.

  • kdim (int) – total number of features in key (default: None).

  • vdim (int) – total number of features in value (default: None).
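torch.nn.MultiheadAttention splits the d_model features evenly across nhead heads of size d_model // nhead (PyTorch requires d_model to be divisible by nhead), attends within each head, and concatenates the per-head outputs. The hypothetical helpers below sketch only that split-and-merge bookkeeping for a single (L, E) sequence using plain lists; the learned input and output projections are omitted.

```python
from typing import List

Matrix = List[List[float]]


def split_heads(x: Matrix, num_heads: int) -> List[Matrix]:
    """Reshape an (L, E) sequence into num_heads sequences of shape (L, E // num_heads)."""
    embed_dim = len(x[0])
    if embed_dim % num_heads != 0:
        raise ValueError("embedding size must be divisible by the number of heads")
    head_dim = embed_dim // num_heads
    return [[row[h * head_dim : (h + 1) * head_dim] for row in x] for h in range(num_heads)]


def merge_heads(heads: List[Matrix]) -> Matrix:
    """Inverse of split_heads: concatenate the per-head features at each position."""
    return [[v for head in heads for v in head[t]] for t in range(len(heads[0]))]
```

With d_model=512 and nhead=8 as in the example, each head would work on 64 features per position.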

Example

>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])

forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool = True, pos_embs: Tensor | None = None)

Compute attention.

Parameters:
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • attn_mask (torch.Tensor, optional) – 2D mask (L, S) where L is the target sequence length and S is the source sequence length, or 3D mask (N*num_heads, L, S) where N is the batch size. attn_mask ensures that position i attends only to the unmasked positions. If a ByteTensor is provided, non-zero positions are not allowed to attend while zero positions are left unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values are left unchanged. If a FloatTensor is provided, it is added to the attention weight.

  • key_padding_mask (torch.Tensor, optional) – (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored while zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are left unchanged.

  • return_attn_weights (bool, optional) – True to additionally return the attention weights, False otherwise.

  • pos_embs (torch.Tensor, optional) – Positional embeddings of shape (L, S, E) or (L, S, 1), added to the attention map.

Returns:

  • attn_output (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_output_weights (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length. This is returned only if return_attn_weights=True (True by default).

class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.0, activation: type = <class 'torch.nn.modules.activation.ReLU'>)

Bases: Module

The class implements the position-wise feed-forward module from “Attention Is All You Need”.

Parameters:
  • d_ffn (int) – Hidden layer size.

  • input_shape (tuple, optional) – Expected shape of the input. Alternatively use input_size.

  • input_size (int, optional) – Expected size of the input. Alternatively use input_shape.

  • dropout (float, optional) – Dropout rate.

  • activation (type, optional) – Activation class to apply, e.g. torch.nn.ReLU (the default) or torch.nn.GELU.

Example

>>> inputs = torch.rand([8, 60, 512])
>>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1])
>>> outputs = net(inputs)
>>> outputs.shape
torch.Size([8, 60, 512])

forward(x)

Applies PositionalwiseFeedForward to the input tensor x.
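In Vaswani et al. the position-wise feed-forward network is FFN(x) = W2 act(W1 x + b1) + b2, applied identically and independently at every time step; d_ffn is the width of the hidden layer. The pure-Python sketch below uses hypothetical weight lists and ReLU and leaves out dropout, so it illustrates the computation rather than the module's code.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def linear(vec: Vector, w: Matrix, b: Vector) -> Vector:
    """Affine map; w has one row per output feature."""
    return [sum(wi * vi for wi, vi in zip(row, vec)) + bi for row, bi in zip(w, b)]


def relu(v: float) -> float:
    return max(0.0, v)


def positionwise_ffn(
    x: Matrix,
    w1: Matrix,
    b1: Vector,
    w2: Matrix,
    b2: Vector,
    activation: Callable[[float], float] = relu,
) -> Matrix:
    """Apply the same two-layer MLP to every frame of x (shape: time x features)."""
    return [linear([activation(h) for h in linear(frame, w1, b1)], w2, b2) for frame in x]
```

Since each frame is processed on its own, reordering the input frames simply reorders the outputs, and the output feature size equals the input size whenever w2 maps back to it.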

class speechbrain.nnet.attention.PrecomputedRoPESinusoids(max_length: int, input_size: int, dtype: dtype, device: device)

Bases: Module

A cache for the sines and cosines needed to rotate the vectors for rotary position embeddings (RoPE). This stores the nonzero entries of Eq. (15) in https://arxiv.org/pdf/2104.09864

Parameters:
  • max_length (int) – The allowed max length of the input sequence. For a fixed setting of the other arguments, the computation takes O(max_length) time.

  • input_size (int) – Size of each vector in the input sequence, i.e. the dimension of each attention head.

  • dtype (torch.dtype) – The dtype of the tensors.

  • device (torch.device) – The Torch device to put the tensors on.

Example

>>> precomputed = PrecomputedRoPESinusoids(
...     3, 8, torch.float32, torch.device("cpu")
... )
>>> precomputed.cosines.shape
torch.Size([3, 8])
>>> precomputed.sines.shape == precomputed.cosines.shape
True
>>> precomputed.cosines
tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.5403,  0.9950,  0.9950,  0.9999,  0.9999,  1.0000,  1.0000],
        [-0.4161, -0.4161,  0.9801,  0.9801,  0.9998,  0.9998,  1.0000,  1.0000]])
>>> precomputed.sines
tensor([[-0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
        [-0.8415,  0.8415, -0.0998,  0.0998, -0.0100,  0.0100, -0.0010,  0.0010],
        [-0.9093,  0.9093, -0.1987,  0.1987, -0.0200,  0.0200, -0.0020,  0.0020]])
>>> precomputed.index_swap
tensor([1, 0, 3, 2, 5, 4, 7, 6])
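The doctest values above are consistent with rotation angles pos * 10000^(-2i/d) for each channel pair i, where d is input_size: cosines repeats each value for both channels of a pair, sines carries a minus sign on the first channel, and index_swap exchanges the two channels of every pair. The pure-Python sketch below reproduces those tables and uses them to rotate a vector as x * cos + x[index_swap] * sin. The function names are hypothetical, and the base of 10000 is inferred from the printed values rather than taken from the class.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def rope_tables(
    max_length: int, input_size: int, base: float = 10000.0
) -> Tuple[Matrix, Matrix, List[int]]:
    """Per-position cosines, signed sines and the pair-swapping index for RoPE."""
    cosines, sines = [], []
    for pos in range(max_length):
        cos_row, sin_row = [], []
        for i in range(input_size // 2):
            angle = pos * base ** (-2 * i / input_size)
            cos_row += [math.cos(angle), math.cos(angle)]
            # minus sign on the first channel of each pair
            sin_row += [-math.sin(angle), math.sin(angle)]
        cosines.append(cos_row)
        sines.append(sin_row)
    index_swap = [j ^ 1 for j in range(input_size)]  # 1, 0, 3, 2, ...
    return cosines, sines, index_swap


def rope_rotate_cached(
    x: List[float], pos: int, cosines: Matrix, sines: Matrix, index_swap: List[int]
) -> List[float]:
    """Rotate every channel pair of x by the angle cached for position pos."""
    return [x[j] * cosines[pos][j] + x[index_swap[j]] * sines[pos][j] for j in range(len(x))]
```

For a pair (a, b) this yields (a cos θ - b sin θ, b cos θ + a sin θ), a plain 2D rotation, so vector norms are preserved.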

class speechbrain.nnet.attention.MemoiseAtLeastSize(function: Callable, round_up: Callable[[Any], Any])

Bases: object

Memoises a function whose first argument is the minimum value with which to call the underlying function.

Parameters:
  • function (Callable) – The function to call.

  • round_up (Callable[[Any], Any]) – A function that rounds up. The fewer values this rounds up to, the less likely it is that the function will be called repeatedly.

speechbrain.nnet.attention.memoise_at_least(round_up: Callable[[Any], Any]) → Callable[[Callable], MemoiseAtLeastSize]

Decorator that memoises a function whose first argument is the minimum value with which to call the underlying function. If the memo already holds the result of a matching previous call, the stored result is returned instead of calling the function again.

Parameters:

round_up (Callable[[Any], Any]) – A function that rounds up. This will be called with the first argument passed in. The underlying function will receive, instead of this first argument, the rounded-up version. The fewer values this rounds up to, the less likely it is that the function will be called repeatedly.

Return type:

The passed function but with MemoiseAtLeastSize capability.
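One way to picture this behaviour is to keep, per tuple of remaining arguments, the result for the largest rounded-up size computed so far, and to recompute only when a caller asks for more. The sketch below is a hypothetical re-creation written for illustration: its class and decorator names, the per-argument memo and the calls counter are assumptions, not the library's code. Because results are computed for the rounded-up size, they may be larger than requested, so callers are expected to slice them. Rounding up to a power of two, as in the usage below, bounds the number of recomputations by the logarithm of the largest size requested.

```python
import functools
from typing import Any, Callable, Dict, Tuple


class MemoiseAtLeastSketch:
    """Cache, per extra-argument tuple, the result for the largest size seen so far."""

    def __init__(self, function: Callable, round_up: Callable[[Any], Any]):
        self.function = function
        self.round_up = round_up
        self.memo: Dict[Tuple, Tuple[Any, Any]] = {}  # args -> (computed size, result)
        self.calls = 0  # number of times the wrapped function actually ran
        functools.update_wrapper(self, function)

    def __call__(self, size, *args):
        stored = self.memo.get(args)
        if stored is None or stored[0] < size:
            rounded = self.round_up(size)
            self.calls += 1
            # the wrapped function receives the rounded-up size, not the requested one
            self.memo[args] = (rounded, self.function(rounded, *args))
        return self.memo[args][1]


def memoise_at_least_sketch(
    round_up: Callable[[Any], Any],
) -> Callable[[Callable], MemoiseAtLeastSketch]:
    """Decorator factory: @memoise_at_least_sketch(round_up) wraps a function."""
    return lambda function: MemoiseAtLeastSketch(function, round_up)


def next_power_of_two(n: int) -> int:
    return 1 << max(0, (n - 1).bit_length())


@memoise_at_least_sketch(next_power_of_two)
def make_table(size: int) -> list:
    return list(range(size))
```

Here make_table(5) computes a table of 8 entries, make_table(3) reuses it, and only make_table(9) triggers a new computation of 16 entries.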

class speechbrain.nnet.attention.RoPEMHA(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None)

Bases: Module

This is an implementation of multihead self-attention with RoPE positional embeddings. Because it relies on PyTorch's built-in attention, it is significantly faster than RelPosMHAXL while offering the same or better accuracy.

Details about RoPE: https://arxiv.org/pdf/2104.09864.
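What makes RoPE a relative position encoding is that rotating a query to position m and a key to position n leaves their dot product dependent only on m - n. The self-contained sketch below checks this numerically, assuming the RoFormer-style rotation of consecutive channel pairs by angles pos * 10000^(-2i/d); the function names are hypothetical and the sketch illustrates the idea rather than RoPEMHA's code.

```python
import math
from typing import List


def rope_rotate(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each channel pair (x[2i], x[2i+1]) by the angle pos * base**(-2i/d)."""
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        angle = pos * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))
```

Rotating a query to positions 5 and 12 and a key to positions 2 and 9 gives the same score up to rounding, since both pairs are 3 positions apart.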

Parameters:
  • embed_dim (int) – Size of the encoder feature vectors from which keys and values are computed.

  • num_heads (int) – Number of attention heads.

  • dropout (float, optional) – Dropout rate.

  • vbias (bool, optional) – Whether to use bias for computing value.

  • vdim (int, optional) – Size of the value vectors. Defaults to embed_dim (note that each head has size embed_dim // num_heads).

Example

>>> inputs = torch.rand([6, 60, 512])
>>> num_heads = 8
>>> net = RoPEMHA(num_heads=num_heads, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([6, 60, 512])

forward(query, key, value, key_padding_mask=None, attn_mask=None, pos_embs=None, return_attn_weights=True)

Compute attention using PyTorch's built-in attention.

Parameters:
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • key_padding_mask (torch.Tensor) – (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored while zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are left unchanged.

  • attn_mask (torch.BoolTensor) – 2D mask (L, S) where L is the target sequence length, S is the source sequence length. The positions with the value of True will be ignored while the position with the value of False will be unchanged.

  • pos_embs (torch.Tensor) – Not used by this class; it is kept for interface compatibility with the other attention modules.

  • return_attn_weights (bool) – Whether to additionally return the attention weights.

Returns:

  • out (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_score (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.

speechbrain.nnet.attention.masks_union(bsz, klen, num_heads, attn_mask, key_padding_mask)

This is a utility function that combines the standard SpeechBrain key_padding_mask and attn_mask into a single mask for scaled_dot_product_attention. It does not support weighting of the attn_score, so it should not be used with float-valued masks.

Parameters:
  • bsz (int) – Batch size dimension.

  • klen (int) – Time dimension (sequence length) of the key tensor.

  • num_heads (int) – Number of heads of the attention module using these masks.

  • attn_mask (torch.BoolTensor) – 2D mask (L, S) where L is the target sequence length, S is the source sequence length. The positions with the value of True will be ignored while the position with the value of False will be unchanged.

  • key_padding_mask (torch.BoolTensor) – (B, S) where B is the batch size, S is the source sequence length. The positions with the value of True will be ignored while the position with the value of False will be unchanged.

Returns:

out – (bsz, num_heads, klen, klen) where False values are masked and True values are unmasked (the opposite convention to the input tensors).

Return type:

torch.BoolTensor
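The union amounts to a logical OR of the two "ignore" masks followed by a negation, because scaled_dot_product_attention treats True in a boolean mask as "take part in attention". The sketch below uses nested lists instead of tensors, assumes self-attention so that attn_mask has shape (klen, klen), and its name is hypothetical; it shows the logic, not the library's code.

```python
from typing import List, Optional

BoolMatrix = List[List[bool]]


def masks_union_sketch(
    bsz: int,
    klen: int,
    num_heads: int,
    attn_mask: Optional[BoolMatrix],
    key_padding_mask: Optional[BoolMatrix],
) -> List[List[BoolMatrix]]:
    """Merge a (klen, klen) attn_mask and a (bsz, klen) key_padding_mask.

    Inputs use True = ignore; the output uses True = may attend and has
    shape (bsz, num_heads, klen, klen).
    """
    out = []
    for b in range(bsz):
        allowed = [
            [
                not (
                    (attn_mask is not None and attn_mask[i][j])
                    or (key_padding_mask is not None and key_padding_mask[b][j])
                )
                for j in range(klen)
            ]
            for i in range(klen)
        ]
        # every head shares the same mask; copy rows so heads do not alias
        out.append([[row[:] for row in allowed] for _ in range(num_heads)])
    return out
```

Combining a causal mask with padding of the last frame, for instance, leaves each query free to attend only to earlier, non-padded keys.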