speechbrain.nnet.attention module
Library implementing attention modules.
- Authors
Ju-Chieh Chou 2020
Jianyuan Zhong 2020
Loren Lugosch 2020
Samuele Cornell 2020
Shucong Zhang 2024
Summary
Classes:
- ContentBasedAttention: This class implements a content-based attention module for seq2seq learning.
- KeyValueAttention: This class implements a single-headed key-value attention module for seq2seq learning.
- LocationAwareAttention: This class implements a location-aware attention module for seq2seq learning.
- MemoiseAtLeastSize: Memoises a function which has as its first argument a value that indicates a minimum value to call the underlying function with.
- MultiheadAttention: This class is a wrapper around torch.nn.MultiheadAttention.
- PositionalwiseFeedForward: This class implements the position-wise feed-forward module from "Attention Is All You Need".
- PrecomputedRoPESinusoids: A cache for the sines and cosines needed to rotate the vectors for rotary position embeddings (RoPE).
- RelPosEncXL: Relative positional encoding for RelPosMHAXL.
- RelPosMHAXL: This class implements relative multi-head attention similar to that in Transformer-XL: https://arxiv.org/pdf/1901.02860.pdf
- RoPEMHA: This is an implementation of multi-head self-attention with RoPE positional embeddings.
Functions:
- masks_union: A utility function that combines SpeechBrain's standard key_padding_mask and attn_mask into a single mask for scaled_dot_product_attention.
- memoise_at_least: Decorator that memoises a function which has as its first argument a value that indicates a minimum value to call the underlying function with.
Reference
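- class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)
Bases: torch.nn.Module
This class implements a content-based attention module for seq2seq learning.
Reference: "Neural Machine Translation by Jointly Learning to Align and Translate", Bahdanau et al. https://arxiv.org/pdf/1409.0473.pdf
- Parameters:
enc_dim (int): Size of the encoder features.
dec_dim (int): Size of the decoder features.
attn_dim (int): Size of the attention feature.
output_dim (int): Size of the output context vector.
scaling (float): Factor that controls the sharpening degree (default: 1.0).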
Example
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = ContentBasedAttention(
...     enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5
... )
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)
Returns the output of the attention module.
- Parameters:
enc_states (torch.Tensor): The tensor to be attended.
enc_len (torch.Tensor): The real length (without padding) of enc_states for each sentence.
dec_states (torch.Tensor): The query tensor.
- Returns:
The output of the attention module: a tuple of the context vector, of shape (batch, output_dim), and the attention weights.
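For intuition, here is a minimal pure-Python sketch of Bahdanau-style additive (content-based) scoring for a single sequence, e_t = v^T tanh(W_enc h_t + W_dec s), followed by a softmax restricted to the first enc_len frames. The helper names and weight arguments are hypothetical; the actual module works on batched tensors with learned linear layers and also projects the context to output_dim, which this sketch omits.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def additive_attention(
    enc_states: Matrix,  # T encoder vectors of size enc_dim
    enc_len: int,  # number of valid (non-padded) frames
    dec_state: Vector,  # decoder vector of size dec_dim
    w_enc: Matrix,  # attn_dim x enc_dim (hypothetical weights)
    w_dec: Matrix,  # attn_dim x dec_dim (hypothetical weights)
    v: Vector,  # attn_dim scoring vector (hypothetical)
    scaling: float = 1.0,  # assumed > 0
) -> Tuple[Vector, Vector]:
    q = matvec(w_dec, dec_state)
    energies = []
    for t, h in enumerate(enc_states):
        if t >= enc_len:
            energies.append(-math.inf)  # padded frames receive zero weight
        else:
            k = matvec(w_enc, h)
            energies.append(sum(vi * math.tanh(ki + qi) for vi, ki, qi in zip(v, k, q)))
    # Softmax sharpened by `scaling` (the max is subtracted for numerical stability).
    top = max(energies)
    exps = [math.exp(scaling * (e - top)) for e in energies]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Context vector: attention-weighted sum of the encoder states.
    context = [
        sum(w * h[d] for w, h in zip(weights, enc_states))
        for d in range(len(enc_states[0]))
    ]
    return context, weights
```

Larger scaling values make the resulting weights peakier, which is what "sharpening" refers to.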
- class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)
Bases: torch.nn.Module
This class implements a location-aware attention module for seq2seq learning.
Reference: "Attention-Based Models for Speech Recognition", Chorowski et al. https://arxiv.org/pdf/1506.07503.pdf
- Parameters:
enc_dim (int): Size of the encoder features.
dec_dim (int): Size of the decoder features.
attn_dim (int): Size of the attention feature.
output_dim (int): Size of the output context vector.
conv_channels (int): Number of channels for the location feature.
kernel_size (int): Kernel size of the convolutional layer for the location feature.
scaling (float): Factor that controls the sharpening degree (default: 1.0).
Example
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = LocationAwareAttention(
...     enc_dim=20,
...     dec_dim=25,
...     attn_dim=30,
...     output_dim=5,
...     conv_channels=10,
...     kernel_size=100,
... )
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)
Returns the output of the attention module.
- Parameters:
enc_states (torch.Tensor): The tensor to be attended.
enc_len (torch.Tensor): The real length (without padding) of enc_states for each sentence.
dec_states (torch.Tensor): The query tensor.
- Returns:
The output of the attention module: a tuple of the context vector, of shape (batch, output_dim), and the attention weights.
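Location awareness means the energies also depend on features obtained by convolving the previous attention weights, which encourages the alignment to move forward smoothly. The sketch below is a simplified, hypothetical illustration with a single convolution filter and attn_dim = 1 (so each energy is a scalar); the content term is passed in precomputed, and the helper names are assumptions rather than SpeechBrain APIs.

```python
import math
from typing import List


def conv1d_same(signal: List[float], kernel: List[float]) -> List[float]:
    """Zero-padded 1-D cross-correlation that keeps the input length (odd kernel)."""
    half = len(kernel) // 2
    out = []
    for t in range(len(signal)):
        acc = 0.0
        for j, w in enumerate(kernel):
            idx = t + j - half
            if 0 <= idx < len(signal):
                acc += w * signal[idx]
        out.append(acc)
    return out


def location_aware_weights(
    content: List[float],  # per-frame content term W_dec s + W_enc h_t (attn_dim = 1)
    prev_weights: List[float],  # attention weights from the previous decoder step
    kernel: List[float],  # hypothetical location filter
    u: float,  # hypothetical projection of the location feature
) -> List[float]:
    loc = conv1d_same(prev_weights, kernel)
    energies = [math.tanh(c + u * f) for c, f in zip(content, loc)]
    top = max(energies)
    exps = [math.exp(e - top) for e in energies]
    total = sum(exps)
    return [e / total for e in exps]
```

With identical content terms, frames near the previous focus receive more weight, which is the behaviour location features are meant to add.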
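- class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)
Bases: torch.nn.Module
This class implements a single-headed key-value attention module for seq2seq learning.
Reference: "Attention Is All You Need" by Vaswani et al., sec. 3.2.1
- Parameters:
enc_dim (int): Size of the encoder features.
dec_dim (int): Size of the decoder features.
attn_dim (int): Size of the attention feature.
output_dim (int): Size of the output context vector.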
Example
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = KeyValueAttention(
...     enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5
... )
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)
Returns the output of the attention module.
- Parameters:
enc_states (torch.Tensor): The tensor to be attended.
enc_len (torch.Tensor): The real length (without padding) of enc_states for each sentence.
dec_states (torch.Tensor): The query tensor.
- Returns:
The output of the attention module: a tuple of the context vector, of shape (batch, output_dim), and the attention weights.
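Section 3.2.1 of the paper defines scaled dot-product attention, softmax(q k^T / sqrt(d)) V. The pure-Python sketch below applies it to one query over a single padded sequence, assuming the query, keys and values have already been produced by the module's linear projections (omitted here). The function name is hypothetical.

```python
import math
from typing import List, Tuple


def scaled_dot_attention(
    query: List[float],  # d values
    keys: List[List[float]],  # T x d
    values: List[List[float]],  # T x d_v
    length: int,  # number of valid (non-padded) frames
) -> Tuple[List[float], List[float]]:
    d = len(query)
    # Dot-product scores scaled by sqrt(d); padded frames get -inf.
    scores = [
        sum(q * k for q, k in zip(query, key)) / math.sqrt(d) if t < length else -math.inf
        for t, key in enumerate(keys)
    ]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # exp(-inf) == 0.0
    total = sum(exps)
    weights = [e / total for e in exps]
    context = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]
    return context, weights
```

Dividing by sqrt(d) keeps the scores from growing with the feature size, which would otherwise push the softmax towards saturation.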
- class speechbrain.nnet.attention.RelPosEncXL(emb_dim: int, dtype: torch.dtype = torch.float32)
Bases: torch.nn.Module
Relative positional encoding for RelPosMHAXL.
- Parameters:
emb_dim (int): Size of the embedding, which also controls the size of the last dimension of the positional embedding.
dtype (torch.dtype, optional): If unspecified, defaults to torch.float32. Controls the data type of the output embedding (but does not affect the precision of the computations, which remain torch.float32).
- make_pe(seq_len: int)
Builds the positional embedding tensor for a given sequence length.
- Parameters:
seq_len (int): The length of the sequence to create the position embedding for.
- Returns:
Positional embedding tensor of shape [1, 2*seq_len-1, embed_dim]
- Return type:
torch.Tensor
- forward(x: Tensor)
Builds the positional embedding tensor. Similar to make_pe() but uses the shape information from the provided tensor.
- Parameters:
x (torch.Tensor): Input tensor of shape [batch_size, seq_len, embed_dim].
- Returns:
pos_emb: Positional embedding tensor of shape [1, 2*seq_len-1, embed_dim]
- Return type:
torch.Tensor
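A pure-Python sketch of a sinusoidal encoding over the 2*seq_len-1 relative positions, which is where the [1, 2*seq_len-1, embed_dim] shape comes from. The row ordering (from seq_len-1 down to -(seq_len-1)) and the even-sine/odd-cosine interleaving are assumptions made for illustration, and the function name is hypothetical.

```python
import math
from typing import List


def rel_pos_encoding(seq_len: int, emb_dim: int) -> List[List[List[float]]]:
    """Encodings for relative positions seq_len-1, ..., 0, ..., -(seq_len-1).

    Assumes emb_dim is even; even columns hold sines and odd columns cosines.
    Returns a nested list shaped [1, 2*seq_len-1, emb_dim].
    """
    inv_freq = [10000 ** (-i / emb_dim) for i in range(0, emb_dim, 2)]
    rows = []
    for pos in range(seq_len - 1, -seq_len, -1):  # 2*seq_len-1 positions
        row: List[float] = []
        for f in inv_freq:
            row.extend([math.sin(pos * f), math.cos(pos * f)])
        rows.append(row)
    return [rows]
```

Because sine is odd and cosine is even, the encodings of +k and -k differ only in the sign of their sine columns, so the attention can distinguish past from future offsets.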
- class speechbrain.nnet.attention.RelPosMHAXL(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None, mask_pos_future=False)
Bases: torch.nn.Module
This class implements relative multi-head attention similar to that in Transformer-XL: https://arxiv.org/pdf/1901.02860.pdf
- Parameters:
embed_dim (int): Size of the encoder feature vectors from which keys and values are computed.
num_heads (int): Number of attention heads.
dropout (float, optional): Dropout rate.
vbias (bool, optional): Whether to use a bias when computing the values.
vdim (int, optional): Size of the values. Defaults to embed_dim (note that each head has size embed_dim // num_heads).
mask_pos_future (bool, optional): Whether to mask future positional encoding values. Must be True for causal applications, e.g. a decoder.
Example
>>> inputs = torch.rand([6, 60, 512])
>>> pos_emb = torch.rand([1, 2 * 60 - 1, 512])
>>> net = RelPosMHAXL(num_heads=8, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs, pos_emb)
>>> outputs.shape
torch.Size([6, 60, 512])
- forward(query, key, value, pos_embs, key_padding_mask=None, attn_mask=None, return_attn_weights=True)
Compute attention.
- Parameters:
query (torch.Tensor): (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
key (torch.Tensor): (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
value (torch.Tensor): (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
pos_embs (torch.Tensor): Bidirectional sinusoidal positional embedding tensor (1, 2*S-1, E) where S is the max length between source and target sequence lengths, and E is the embedding dimension.
key_padding_mask (torch.Tensor): (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored and zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored and positions with the value False are left unchanged.
attn_mask (torch.Tensor): Either a 2D mask (L, S), where L is the target sequence length and S is the source sequence length, or a 3D mask (N*num_heads, L, S), where N is the batch size. attn_mask ensures that position i can attend only to the unmasked positions. If a ByteTensor is provided, non-zero positions are not allowed to attend while zero positions are left unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values are left unchanged. If a FloatTensor is provided, it is added to the attention weights.
return_attn_weights (bool): Whether to additionally return the attention weights.
- Returns:
out (torch.Tensor): (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
attn_score (torch.Tensor): (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
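In Transformer-XL style attention, the positional term of the score is first computed against every relative position (an (L, 2L-1) matrix per head) and then realigned so that entry (i, j) uses the encoding of distance i - j. Efficient implementations typically perform this realignment with a pad-and-reshape trick; the hypothetical helper below does it by explicit indexing, assuming the columns are ordered from relative position L-1 down to -(L-1).

```python
from typing import List


def rel_shift(bd: List[List[float]]) -> List[List[float]]:
    """Realign positional scores from relative-position columns to key columns.

    bd[i][k] is the score of query i against relative position L-1-k.
    Returns an L x L matrix whose entry [i][j] uses relative distance i - j.
    """
    L = len(bd)
    assert all(len(row) == 2 * L - 1 for row in bd), "expected (L, 2L-1) scores"
    # Distance i - j sits in column k = L-1-(i-j).
    return [[bd[i][L - 1 - i + j] for j in range(L)] for i in range(L)]
```

Filling each column with its own relative position is a quick sanity check: after the shift, entry (i, j) holds exactly i - j.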
- class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
Bases: torch.nn.Module
This class is a wrapper around torch.nn.MultiheadAttention.
Reference: https://pytorch.org/docs/stable/nn.html
- Parameters:
nhead (int): Number of parallel attention heads.
d_model (int): The size of the model layers.
dropout (float): Dropout probability applied to attn_output_weights (default: 0.0).
bias (bool): Whether to add bias as a module parameter (default: True).
add_bias_kv (bool): Whether to add bias to the key and value sequences at dim=0.
add_zero_attn (bool): Whether to add a new batch of zeros to the key and value sequences at dim=1.
kdim (int): Total number of features in key (default: None).
vdim (int): Total number of features in value (default: None).
Example
>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
- forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool = True, pos_embs: Tensor | None = None)
Compute attention.
- Parameters:
query (torch.Tensor): (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
key (torch.Tensor): (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
value (torch.Tensor): (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
attn_mask (torch.Tensor, optional): Either a 2D mask (L, S), where L is the target sequence length and S is the source sequence length, or a 3D mask (N*num_heads, L, S), where N is the batch size. attn_mask ensures that position i can attend only to the unmasked positions. If a ByteTensor is provided, non-zero positions are not allowed to attend while zero positions are left unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values are left unchanged. If a FloatTensor is provided, it is added to the attention weights.
key_padding_mask (torch.Tensor, optional): (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored and zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored and positions with the value False are left unchanged.
return_attn_weights (bool, optional): True to additionally return the attention weights, False otherwise.
pos_embs (torch.Tensor, optional): Positional embeddings of shape (L, S, E) or (L, S, 1), added to the attention map.
- Returns:
attn_output (torch.Tensor): (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
attn_output_weights (torch.Tensor): (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length. This is returned only if return_attn_weights=True (True by default).
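Both boolean masks use the convention that True marks positions to ignore. The pure-Python sketch below builds a key_padding_mask from sequence lengths and a causal attn_mask; the helper names are hypothetical. With tensors, the padding mask can be written as torch.arange(max_len)[None, :] >= lengths[:, None].

```python
from typing import List, Optional


def lengths_to_key_padding_mask(
    lengths: List[int], max_len: Optional[int] = None
) -> List[List[bool]]:
    """(B, S) mask: True marks padded key positions that attention should ignore."""
    if max_len is None:
        max_len = max(lengths)
    return [[t >= n for t in range(max_len)] for n in lengths]


def causal_attn_mask(length: int) -> List[List[bool]]:
    """(L, L) mask: True above the diagonal blocks attention to future positions."""
    return [[j > i for j in range(length)] for i in range(length)]
```

The two masks are complementary: the padding mask varies per batch item but is shared by all queries, while the causal mask varies per query but is shared by all batch items.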
- class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.0, activation: type = torch.nn.ReLU)
Bases: torch.nn.Module
This class implements the position-wise feed-forward module from "Attention Is All You Need".
- Parameters:
d_ffn (int): Hidden layer size.
input_shape (tuple, optional): Expected shape of the input. Alternatively, use input_size.
input_size (int, optional): Expected size of the input. Alternatively, use input_shape.
dropout (float, optional): Dropout rate.
activation (torch.nn.Module, optional): Activation function to apply (recommended: ReLU or GELU).
Example
>>> inputs = torch.rand([8, 60, 512])
>>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1])
>>> outputs = net(inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
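The Transformer feed-forward block applies the same two-layer network, FFN(x) = W2 act(W1 x + b1) + b2, independently at every time step, so the output keeps the input's shape, as in the example above. Below is a pure-Python sketch with hypothetical weight arguments; dropout is omitted.

```python
from typing import Callable, List

Matrix = List[List[float]]


def relu(v: float) -> float:
    return max(0.0, v)


def positionwise_ffn(
    x: Matrix,  # T x d_in input frames
    w1: Matrix,  # d_ffn x d_in (hypothetical weights)
    b1: List[float],  # d_ffn
    w2: Matrix,  # d_in x d_ffn, so the output size equals the input size
    b2: List[float],  # d_in
    act: Callable[[float], float] = relu,
) -> Matrix:
    out = []
    for frame in x:  # the same weights are applied at every time step
        hidden = [act(sum(w * v for w, v in zip(row, frame)) + b) for row, b in zip(w1, b1)]
        out.append([sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(w2, b2)])
    return out
```

Because no information flows between frames here, all mixing across time happens in the attention layers.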
- class speechbrain.nnet.attention.PrecomputedRoPESinusoids(max_length: int, input_size: int, dtype: torch.dtype, device: torch.device)
Bases: torch.nn.Module
A cache for the sines and cosines needed to rotate the vectors for rotary position embeddings (RoPE). It stores the non-zero entries of Eq. (15) in https://arxiv.org/pdf/2104.09864
- Parameters:
max_length (int): The maximum allowed length of the input sequence. For a fixed setting of the other arguments, the computation takes O(max_length) time.
input_size (int): Size of each vector in the input sequence, i.e. the dimension of each attention head.
dtype (torch.dtype): The dtype of the tensors.
device (torch.device): The Torch device to put the tensors on.
Example
>>> precomputed = PrecomputedRoPESinusoids(
...     3, 8, torch.float32, torch.device("cpu")
... )
>>> precomputed.cosines.shape
torch.Size([3, 8])
>>> precomputed.sines.shape == precomputed.cosines.shape
True
>>> precomputed.cosines
tensor([[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [ 0.5403, 0.5403, 0.9950, 0.9950, 0.9999, 0.9999, 1.0000, 1.0000],
        [-0.4161, -0.4161, 0.9801, 0.9801, 0.9998, 0.9998, 1.0000, 1.0000]])
>>> precomputed.sines
tensor([[-0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -0.0000, 0.0000],
        [-0.8415, 0.8415, -0.0998, 0.0998, -0.0100, 0.0100, -0.0010, 0.0010],
        [-0.9093, 0.9093, -0.1987, 0.1987, -0.0200, 0.0200, -0.0020, 0.0020]])
>>> precomputed.index_swap
tensor([1, 0, 3, 2, 5, 4, 7, 6])
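The sketch below recomputes the cached quantities in pure Python, assuming the standard RoPE frequencies theta_i = 10000^(-2i/d) and rotation of adjacent dimension pairs; these assumptions are consistent with the values printed in the example above. It also shows one way such a cache can be used: x * cosines + x[index_swap] * sines rotates each pair (x0, x1) by its angle. The function names are hypothetical.

```python
import math
from typing import List, Tuple


def rope_sinusoids(
    max_length: int, input_size: int, base: float = 10000.0
) -> Tuple[List[List[float]], List[List[float]], List[int]]:
    """Cosines, signed sines and pair-swap indices for RoPE (adjacent pairs assumed)."""
    thetas = [base ** (-2 * i / input_size) for i in range(input_size // 2)]
    cosines, sines = [], []
    for pos in range(max_length):
        cos_row: List[float] = []
        sin_row: List[float] = []
        for theta in thetas:
            c, s = math.cos(pos * theta), math.sin(pos * theta)
            cos_row += [c, c]
            sin_row += [-s, s]  # sign pattern removes the need for a separate negation
        cosines.append(cos_row)
        sines.append(sin_row)
    index_swap = [i + 1 if i % 2 == 0 else i - 1 for i in range(input_size)]
    return cosines, sines, index_swap


def apply_rope(x, pos, cosines, sines, index_swap):
    # Each pair (x0, x1) becomes (x0*c - x1*s, x1*c + x0*s): a 2-D rotation.
    return [x[d] * cosines[pos][d] + x[index_swap[d]] * sines[pos][d] for d in range(len(x))]
```

Since each pair is rotated rather than scaled, the vector norm is unchanged by the embedding.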
- class speechbrain.nnet.attention.MemoiseAtLeastSize(function: Callable, round_up: Callable[[Any], Any])
Bases: object
Memoises a function which has as its first argument a value that indicates a minimum value to call the underlying function with.
- Parameters:
function (Callable): The function to call.
round_up (Callable[[Any], Any]): A function that rounds up. The fewer distinct values this rounds up to, the less likely it is that the function will be called repeatedly.
- speechbrain.nnet.attention.memoise_at_least(round_up: Callable[[Any], Any]) -> Callable[[Callable], MemoiseAtLeastSize]
Decorator that memoises a function which has as its first argument a value that indicates a minimum value to call the underlying function with. If the memo holds the result of a matching previous call, the stored result is returned instead of calling the function again.
- Parameters:
round_up (Callable[[Any], Any]): A function that rounds up. It is called with the first argument passed in, and the underlying function receives the rounded-up value instead of that first argument. The fewer distinct values this rounds up to, the less likely it is that the function will be called repeatedly.
- Returns:
A decorator that wraps the passed function with MemoiseAtLeastSize capability.
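To make the behaviour concrete, here is a hypothetical stand-alone decorator written from the description above; it is not SpeechBrain's code, and keying the memo on the remaining positional arguments is an assumption. A power-of-two round_up means a cached result is reused for every request up to the next power of two.

```python
import math
from typing import Any, Callable, Dict, Tuple


def memoise_at_least_sketch(round_up: Callable[[Any], Any]) -> Callable[[Callable], Callable]:
    """Hypothetical re-implementation of the memoisation idea.

    The wrapped function's first argument is a minimum size. A stored result is
    reused when it was computed for a size at least as large as the request,
    so the caller must cope with (e.g. slice) a result larger than asked for.
    """

    def decorator(function: Callable) -> Callable:
        memo: Dict[Tuple, Tuple[Any, Any]] = {}  # other args -> (size, result)

        def wrapper(size, *args):
            stored = memo.get(args)
            if stored is not None and stored[0] >= size:
                return stored[1]
            rounded = round_up(size)
            result = function(rounded, *args)
            memo[args] = (rounded, result)
            return result

        return wrapper

    return decorator


calls = []


@memoise_at_least_sketch(round_up=lambda n: 2 ** math.ceil(math.log2(max(n, 1))))
def make_table(length, scale):
    calls.append(length)
    return [scale * i for i in range(length)]


t1 = make_table(5, 2)  # computes a table of length 8
t2 = make_table(7, 2)  # reuses the stored length-8 table
t3 = make_table(9, 2)  # 9 > 8, so it recomputes with length 16
```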
- class speechbrain.nnet.attention.RoPEMHA(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None)
Bases: torch.nn.Module
This is an implementation of multi-head self-attention with RoPE positional embeddings. Because it relies on PyTorch's built-in attention, it is significantly faster than RelPosMHAXL while offering the same or better accuracy.
Details about RoPE: https://arxiv.org/pdf/2104.09864.
- Parameters:
embed_dim (int): Size of the encoder feature vectors from which keys and values are computed.
num_heads (int): Number of attention heads.
dropout (float, optional): Dropout rate.
vbias (bool, optional): Whether to use a bias when computing the values.
vdim (int, optional): Size of the values. Defaults to embed_dim (note that each head has size embed_dim // num_heads).
Example
>>> max_len = 64
>>> inputs = torch.rand([6, 60, 512])
>>> num_heads = 8
>>> net = RoPEMHA(num_heads=num_heads, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([6, 60, 512])
- forward(query, key, value, key_padding_mask=None, attn_mask=None, pos_embs=None, return_attn_weights=True)
Computes attention using PyTorch's attention implementation.
- Parameters:
query (torch.Tensor): (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
key (torch.Tensor): (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
value (torch.Tensor): (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
key_padding_mask (torch.Tensor): (B, S) where B is the batch size, S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored and zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored and positions with the value False are left unchanged.
attn_mask (torch.BoolTensor): 2D mask (L, S) where L is the target sequence length and S is the source sequence length. Positions with the value True are ignored and positions with the value False are left unchanged.
pos_embs (torch.Tensor): Not used by this class; it is kept for interface compatibility.
return_attn_weights (bool): Whether to additionally return the attention weights.
- Returns:
out (torch.Tensor): (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
attn_score (torch.Tensor): (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
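RoPE rotates queries and keys by position-dependent angles so that their dot product depends only on the offset between the two positions; this is presumably what lets RoPEMHA hand the rotated queries and keys to a standard attention kernel without passing pos_embs. A self-contained pure-Python check of that property, assuming adjacent-pair rotation and base 10000 (the helper names are hypothetical):

```python
import math
from typing import List


def rotate(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each adjacent pair (x[2j], x[2j+1]) by the angle pos * base**(-2j/d)."""
    d = len(x)
    out: List[float] = []
    for i in range(0, d, 2):  # i is the even dimension index 2j
        angle = pos * base ** (-i / d)
        c, s = math.cos(angle), math.sin(angle)
        out += [x[i] * c - x[i + 1] * s, x[i + 1] * c + x[i] * s]
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 2.0]
# Both pairs of positions are 2 apart, so the scores agree up to rounding.
s1 = dot(rotate(q, 5), rotate(k, 3))
s2 = dot(rotate(q, 12), rotate(k, 10))
```

Algebraically, R(m)q · R(n)k = q · R(n - m)k, so only the difference of the positions enters the score.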
- speechbrain.nnet.attention.masks_union(bsz, klen, num_heads, attn_mask, key_padding_mask)
This is a utility function that combines SpeechBrain's standard key_padding_mask and attn_mask into a single mask for scaled_dot_product_attention. It does not support weighting of the attn_score; if you wish to use float values as masks, do not use this function.
- Parameters:
bsz (int): Batch size.
klen (int): Time dimension (sequence length) of the key tensor.
num_heads (int): Number of heads of the attention module using these masks.
attn_mask (torch.BoolTensor): 2D mask (L, S) where L is the target sequence length and S is the source sequence length. Positions with the value True are ignored and positions with the value False are left unchanged.
key_padding_mask (torch.BoolTensor): (B, S) where B is the batch size and S is the source sequence length. Positions with the value True are ignored and positions with the value False are left unchanged.
- Returns:
out: (bsz, num_heads, klen, klen) where False values are masked and True values are unmasked (the opposite of the input tensors).
- Return type:
torch.BoolTensor
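A pure-Python sketch of the combination described above. The helper name is hypothetical, and accepting None for either mask is an assumption that may differ from the real function. A position is attendable only if neither input mask ignores it, and the result follows the torch.nn.functional.scaled_dot_product_attention convention, where True means the position takes part in attention.

```python
from typing import List, Optional

BoolMatrix = List[List[bool]]


def masks_union_sketch(
    bsz: int,
    klen: int,
    num_heads: int,
    attn_mask: Optional[BoolMatrix] = None,  # (klen, klen), True = ignore
    key_padding_mask: Optional[BoolMatrix] = None,  # (bsz, klen), True = ignore
) -> List[List[BoolMatrix]]:
    out = []
    for b in range(bsz):
        per_batch = [
            [
                not (
                    (attn_mask is not None and attn_mask[i][j])
                    or (key_padding_mask is not None and key_padding_mask[b][j])
                )
                for j in range(klen)
            ]
            for i in range(klen)
        ]
        # Every head shares the same mask; rows are copied so heads are independent lists.
        out.append([[row[:] for row in per_batch] for _ in range(num_heads)])
    return out  # (bsz, num_heads, klen, klen), True = attend
```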