speechbrain.nnet.attention module
Library implementing attention modules.
- Authors
Ju-Chieh Chou 2020
Jianyuan Zhong 2020
Loren Lugosch 2020
Samuele Cornell 2020
Summary
Classes:
- ContentBasedAttention: This class implements a content-based attention module for seq2seq learning.
- KeyValueAttention: This class implements a single-headed key-value attention module for seq2seq learning.
- LocationAwareAttention: This class implements a location-aware attention module for seq2seq learning.
- MultiheadAttention: This class is a wrapper around torch.nn.MultiheadAttention.
- PositionalwiseFeedForward: This class implements the position-wise feed-forward module from “Attention Is All You Need”.
- RelPosEncXL: Relative sinusoidal positional encoding, producing the positional embeddings expected by RelPosMHAXL.
- RelPosMHAXL: This class implements relative-position multi-head attention, similar to that in Transformer-XL: https://arxiv.org/pdf/1901.02860.pdf
Reference
- class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)
Bases: torch.nn.Module
This class implements a content-based attention module for seq2seq learning.
Reference: Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al. https://arxiv.org/pdf/1409.0473.pdf
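- Parameters
enc_dim (int) – Size of the encoder feature vectors (the last dimension of enc_states).
dec_dim (int) – Size of the decoder state vectors (the last dimension of dec_states).
attn_dim (int) – Size of the attention feature.
output_dim (int) – Size of the output context vector.
scaling (float) – Factor that controls the sharpening degree of the attention weights (default: 1.0).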
Example
>>> import torch
>>> from speechbrain.nnet.attention import ContentBasedAttention
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)
Returns the context vector, of shape (batch, output_dim), and the attention weights.
- Parameters
enc_states (torch.Tensor) – The tensor to be attended, of shape (batch, time, enc_dim).
enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sequence in the batch.
dec_states (torch.Tensor) – The query tensor (the current decoder states), of shape (batch, dec_dim).
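For intuition, the additive scoring rule of Bahdanau et al. can be sketched in plain Python for a single sequence. This is a minimal illustration of content-based attention under stated assumptions, not SpeechBrain's implementation: the matrices W_enc and W_dec and the vector w are hypothetical stand-ins for the module's learned layers, and the final projection of the context vector to output_dim is omitted.

```python
import math


def softmax(scores):
    # Numerically stable softmax; -inf scores receive exactly zero weight.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W, x):
    # W is a list of rows, so this computes W @ x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def additive_attention(enc_states, enc_len, dec_state, W_enc, W_dec, w, scaling=1.0):
    """Bahdanau-style attention for ONE sequence (illustrative sketch).

    enc_states: T encoder vectors; only the first enc_len are real frames.
    dec_state: the decoder query vector.
    W_enc, W_dec, w: hypothetical learned parameters (projections to
    attn_dim and the scoring vector).
    """
    q = matvec(W_dec, dec_state)
    scores = []
    for t, h in enumerate(enc_states):
        if t >= enc_len:
            scores.append(float("-inf"))  # mask padded frames
            continue
        k = matvec(W_enc, h)
        energy = sum(wi * math.tanh(ki + qi) for wi, ki, qi in zip(w, k, q))
        scores.append(scaling * energy)  # larger scaling gives sharper weights
    weights = softmax(scores)
    # Context vector: attention-weighted sum of the encoder states.
    dim = len(enc_states[0])
    context = [sum(a * h[d] for a, h in zip(weights, enc_states)) for d in range(dim)]
    return context, weights


enc = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # the third frame is padding
ctx, attn = additive_attention(
    enc, 2, [1.0, 1.0],
    W_enc=[[1.0, 0.0], [0.0, 1.0]], W_dec=[[0.5, 0.0], [0.0, 0.5]], w=[1.0, -1.0],
)
print(attn)  # the padded frame receives weight 0.0
```

Masking with -inf before the softmax is what lets enc_len exclude padded frames without changing the normalization over the real ones.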
- class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)
Bases: torch.nn.Module
This class implements a location-aware attention module for seq2seq learning.
Reference: Attention-Based Models for Speech Recognition, Chorowski et al. https://arxiv.org/pdf/1506.07503.pdf
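- Parameters
enc_dim (int) – Size of the encoder feature vectors (the last dimension of enc_states).
dec_dim (int) – Size of the decoder state vectors (the last dimension of dec_states).
attn_dim (int) – Size of the attention feature.
output_dim (int) – Size of the output context vector.
conv_channels (int) – Number of channels for the location feature.
kernel_size (int) – Kernel size of the convolutional layer for the location feature.
scaling (float) – Factor that controls the sharpening degree of the attention weights (default: 1.0).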
Example
>>> import torch
>>> from speechbrain.nnet.attention import LocationAwareAttention
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = LocationAwareAttention(
...     enc_dim=20,
...     dec_dim=25,
...     attn_dim=30,
...     output_dim=5,
...     conv_channels=10,
...     kernel_size=100)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)
Returns the context vector, of shape (batch, output_dim), and the attention weights.
- Parameters
enc_states (torch.Tensor) – The tensor to be attended, of shape (batch, time, enc_dim).
enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sequence in the batch.
dec_states (torch.Tensor) – The query tensor (the current decoder states), of shape (batch, dec_dim).
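What distinguishes location-aware attention (Chorowski et al.) is that the previous step's attention weights are convolved and the resulting location features enter the energy alongside the content terms. The plain-Python sketch below illustrates that idea for one sequence under simplifying assumptions: the encoder and decoder states are taken as already projected to attn_dim (enc_proj, dec_proj), and kernels (one per location channel, i.e. conv_channels kernels of length kernel_size), U, and w are hypothetical stand-ins for learned parameters. It is not SpeechBrain's implementation.

```python
import math


def softmax(scores):
    # Numerically stable softmax; -inf scores receive exactly zero weight.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def conv1d_same(signal, kernels):
    """Cross-correlate a 1-D signal with each odd-length kernel (as PyTorch
    convolutions do), zero-padded so every position gets one feature per kernel."""
    T = len(signal)
    feats = []
    for t in range(T):
        f = []
        for k in kernels:
            half = len(k) // 2
            acc = 0.0
            for j, kj in enumerate(k):
                src = t + j - half
                if 0 <= src < T:
                    acc += kj * signal[src]
            f.append(acc)
        feats.append(f)
    return feats


def location_aware_attention(enc_proj, dec_proj, prev_attn, enc_len, kernels, U, w, scaling=1.0):
    """One step of location-aware attention (illustrative sketch)."""
    loc = conv1d_same(prev_attn, kernels)  # location features from previous weights
    scores = []
    for t in range(len(enc_proj)):
        if t >= enc_len:
            scores.append(float("-inf"))  # mask padded frames
            continue
        loc_proj = [sum(u * f for u, f in zip(row, loc[t])) for row in U]
        energy = sum(
            wi * math.tanh(k + q + l)
            for wi, k, q, l in zip(w, enc_proj[t], dec_proj, loc_proj)
        )
        scores.append(scaling * energy)
    return softmax(scores)


# With zero content terms and an identity kernel, the new weights follow the old ones.
T = 4
new_attn = location_aware_attention(
    enc_proj=[[0.0]] * T, dec_proj=[0.0], prev_attn=[0.0, 1.0, 0.0, 0.0],
    enc_len=T, kernels=[[0.0, 1.0, 0.0]], U=[[1.0]], w=[1.0],
)
print(new_attn)
```

The toy call shows the property that motivates the design: the location term lets the model prefer frames near where it attended at the previous step, which encourages monotonic alignments in speech.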
- class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)
Bases: torch.nn.Module
This class implements a single-headed key-value attention module for seq2seq learning.
Reference: “Attention Is All You Need” by Vaswani et al., sec. 3.2.1
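- Parameters
enc_dim (int) – Size of the encoder feature vectors (the last dimension of enc_states).
dec_dim (int) – Size of the decoder state vectors (the last dimension of dec_states).
attn_dim (int) – Size of the attention feature.
output_dim (int) – Size of the output context vector.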
Example
>>> import torch
>>> from speechbrain.nnet.attention import KeyValueAttention
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = KeyValueAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)
Returns the context vector, of shape (batch, output_dim), and the attention weights.
- Parameters
enc_states (torch.Tensor) – The tensor to be attended, of shape (batch, time, enc_dim).
enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sequence in the batch.
dec_states (torch.Tensor) – The query tensor (the current decoder states), of shape (batch, dec_dim).
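Section 3.2.1 of Vaswani et al. defines scaled dot-product attention: the softmax of q·k / sqrt(d_k) weights the value vectors. A minimal single-query sketch in plain Python, with padding masked via enc_len, is below. It assumes the query, keys, and values have already been produced (in the module, presumably by learned projections sized by attn_dim and output_dim; that part is an assumption and is not shown), so it illustrates the technique rather than reproducing KeyValueAttention.

```python
import math


def scaled_dot_product_attention(query, keys, values, enc_len):
    """Single-query scaled dot-product attention (illustrative sketch).

    query: vector of size d_k; keys: T vectors of size d_k;
    values: T vectors; only the first enc_len positions are real frames.
    """
    d_k = len(query)
    scores = []
    for t, k in enumerate(keys):
        if t >= enc_len:
            scores.append(float("-inf"))  # padded positions get zero weight
        else:
            scores.append(sum(q * ki for q, ki in zip(query, k)) / math.sqrt(d_k))
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Output: attention-weighted sum of the value vectors.
    dim = len(values[0])
    context = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return context, weights


keys = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]  # the third position is padding
values = [[10.0], [20.0], [99.0]]
context, weights = scaled_dot_product_attention([1.0, 0.0], keys, values, enc_len=2)
print(context, weights)
```

Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax into a nearly one-hot regime.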
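- class speechbrain.nnet.attention.RelPosEncXL(emb_dim)
Bases: torch.nn.Module
Relative sinusoidal positional encoding, producing the positional embeddings expected by RelPosMHAXL.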
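- forward(x: Tensor)
- Parameters
x (torch.Tensor) – Input tensor with shape (batch_size, seq_len, embed_dim).
- Returns
pos_emb – Relative positional embedding of shape (1, 2*seq_len-1, embed_dim), as expected by the pos_embs argument of RelPosMHAXL.
- Return type
torch.Tensor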
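The relative encoding pairs with RelPosMHAXL, whose positional input has shape (1, 2*S-1, E): one embedding per relative offset between a query and a key. The sketch below builds such a table in plain Python with the standard sinusoid from “Attention Is All You Need”. The ordering (offsets from seq_len-1 down to -(seq_len-1)) and the interleaved sin/cos layout are assumptions made for illustration and may differ from RelPosEncXL.

```python
import math


def relative_sinusoidal_embedding(seq_len, emb_dim):
    """Return a nested list of shape (1, 2*seq_len - 1, emb_dim).

    Row r holds the sinusoidal embedding of the relative offset
    seq_len - 1 - r, i.e. offsets run from seq_len-1 down to -(seq_len-1).
    (The ordering and sin/cos interleaving are assumptions of this sketch.)
    """
    if emb_dim % 2 != 0:
        raise ValueError("emb_dim must be even")
    table = []
    for offset in range(seq_len - 1, -seq_len, -1):
        row = []
        for i in range(0, emb_dim, 2):
            freq = 1.0 / (10000 ** (i / emb_dim))
            row.append(math.sin(offset * freq))
            row.append(math.cos(offset * freq))
        table.append(row)
    return [table]  # leading dimension of size 1, broadcast over the batch


pos_emb = relative_sinusoidal_embedding(seq_len=60, emb_dim=512)
print(len(pos_emb), len(pos_emb[0]), len(pos_emb[0][0]))  # 1 119 512
```

Because sine is odd and cosine is even, the rows for offsets +m and -m differ only in the sign of their sine components, which is what lets the attention distinguish past from future positions.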
- class speechbrain.nnet.attention.RelPosMHAXL(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None, mask_pos_future=False)
Bases: torch.nn.Module
This class implements relative-position multi-head attention, similar to that in Transformer-XL: https://arxiv.org/pdf/1901.02860.pdf
- Parameters
embed_dim (int) – Size of the encoder feature vectors from which keys and values are computed.
num_heads (int) – Number of attention heads.
dropout (float, optional) – Dropout rate.
vbias (bool, optional) – Whether to use bias for computing value.
vdim (int, optional) – Size of the value vectors. Defaults to embed_dim (note that each head has size embed_dim // num_heads).
mask_pos_future (bool, optional) – Whether to mask future positional encoding values. Must be True for causal applications, e.g., a decoder.
Example
>>> import torch
>>> from speechbrain.nnet.attention import RelPosMHAXL
>>> inputs = torch.rand([6, 60, 512])
>>> pos_emb = torch.rand([1, 2*60-1, 512])
>>> net = RelPosMHAXL(num_heads=8, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs, pos_emb)
>>> outputs.shape
torch.Size([6, 60, 512])
- forward(query, key, value, pos_embs, key_padding_mask=None, attn_mask=None, return_attn_weights=True)
- Parameters
query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
pos_embs (torch.Tensor) – Bidirectional sinusoidal positional embedding tensor of shape (1, 2*S-1, E), where S is the maximum of the source and target sequence lengths and E is the embedding dimension.
key_padding_mask (torch.Tensor, optional) – (B, S) where B is the batch size and S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored while zero positions are unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are unchanged.
attn_mask (torch.Tensor, optional) – Either a 2D mask (L, S), where L is the target sequence length and S is the source sequence length, or a 3D mask (N*num_heads, L, S), where N is the batch size. attn_mask ensures that position i can attend only to the unmasked positions. If a ByteTensor is provided, non-zero positions are not allowed to attend while zero positions are unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values are unchanged. If a FloatTensor is provided, it is added to the attention weights.
return_attn_weights (bool, optional) – Whether to also return the attention weights (default: True).
- Returns
out (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
attn_score (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
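In Transformer-XL, the score between query i and key j combines a content term with a position term that looks up the embedding for the offset i - j. The sketch below shows that lookup for a single head in plain Python, assuming the (2L-1)-row table is ordered from offset L-1 down to -(L-1), so offset i - j sits at row L - 1 - i + j. The global biases u and v follow the Transformer-XL paper; the projection of the positional embeddings, masking (including mask_pos_future), and dropout are omitted, so this illustrates the idea rather than reproducing RelPosMHAXL.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def relative_scores(q, k, r, u, v):
    """Unnormalized single-head Transformer-XL style scores (sketch).

    q, k: L vectors of size d (self-attention case, so S == L)
    r: 2L-1 relative embeddings; row L-1-m holds offset m = i - j (assumed order)
    u, v: global content and position biases of size d
    """
    L, d = len(q), len(q[0])
    scores = []
    for i in range(L):
        row = []
        for j in range(L):
            content = dot(add(q[i], u), k[j])
            position = dot(add(q[i], v), r[L - 1 - i + j])  # offset i - j
            row.append((content + position) / math.sqrt(d))
        scores.append(row)
    return scores


# Toy check: 1-D embeddings equal to the offset itself, with zero keys and biases,
# so each score reduces to the relative offset i - j.
L = 4
r = [[float(L - 1 - idx)] for idx in range(2 * L - 1)]
scores = relative_scores(q=[[1.0]] * L, k=[[0.0]] * L, r=r, u=[0.0], v=[0.0])
for row in scores:
    print(row)
```

The index L - 1 - i + j ranges from 0 (i = L-1, j = 0) to 2L-2 (i = 0, j = L-1), which is why a table of 2L-1 rows covers every query/key pair.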
- class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
Bases: torch.nn.Module
This class is a wrapper around torch.nn.MultiheadAttention.
Reference: https://pytorch.org/docs/stable/nn.html
- Parameters
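nhead (int) – Number of parallel attention heads.
d_model (int) – Size of the input feature vectors (the embedding dimension E); must be divisible by nhead.
dropout (float) – Dropout probability applied to attn_output_weights (default: 0.0).
bias (bool) – Whether to add bias as a module parameter (default: True).
add_bias_kv (bool) – Whether to add bias to the key and value sequences at dim=0 (default: False).
add_zero_attn (bool) – Whether to add a new batch of zeros to the key and value sequences at dim=1 (default: False).
kdim (int) – Total number of features in key (default: None, in which case d_model is used).
vdim (int) – Total number of features in value (default: None, in which case d_model is used).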
Example
>>> import torch
>>> from speechbrain.nnet.attention import MultiheadAttention
>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
- forward(query, key, value, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None, return_attn_weights: Optional[Tensor] = True, pos_embs: Optional[Tensor] = None)
- Parameters
query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.
key_padding_mask (torch.Tensor, optional) – (B, S) where B is the batch size and S is the source sequence length. If a ByteTensor is provided, non-zero positions are ignored while zero positions are unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are unchanged.
attn_mask (torch.Tensor, optional) – Either a 2D mask (L, S), where L is the target sequence length and S is the source sequence length, or a 3D mask (N*num_heads, L, S), where N is the batch size. attn_mask ensures that position i can attend only to the unmasked positions. If a ByteTensor is provided, non-zero positions are not allowed to attend while zero positions are unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values are unchanged. If a FloatTensor is provided, it is added to the attention weights.
return_attn_weights (bool, optional) – Whether to also return the attention weights (default: True).
pos_embs (torch.Tensor, optional) – Positional embeddings added to the attention map, of shape (L, S, E) or (L, S, 1).
- Returns
attn_output (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.
attn_output_weights (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
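The masking conventions described for key_padding_mask and attn_mask (a Boolean True means “ignore” or “not allowed to attend”; a float mask is added to the attention scores) can be illustrated with a small dependency-free sketch. Nested lists of bools stand in for the (B, S) and (L, S) tensors, which in practice would be torch BoolTensors. The helper names are hypothetical and not part of SpeechBrain.

```python
def key_padding_mask_from_lengths(lengths, max_len):
    """Hypothetical helper: (B, S) mask where True marks padded key positions."""
    return [[t >= n for t in range(max_len)] for n in lengths]


def causal_attn_mask(tgt_len, src_len):
    """Hypothetical helper: (L, S) mask where True marks future keys that
    query i may not attend to."""
    return [[j > i for j in range(src_len)] for i in range(tgt_len)]


def bool_to_float_mask(mask):
    """Equivalent float mask: 0.0 keeps a position and -inf removes it once the
    mask is added to the attention scores before the softmax."""
    return [[float("-inf") if m else 0.0 for m in row] for row in mask]


padding = key_padding_mask_from_lengths([3, 1], max_len=3)
causal = causal_attn_mask(3, 3)
print(padding)  # [[False, False, False], [False, True, True]]
print(causal)   # [[False, True, True], [False, False, True], [False, False, False]]
print(bool_to_float_mask(causal)[0])  # [0.0, -inf, -inf]
```

A causal attn_mask like this is what a decoder needs so that each position only sees earlier positions, while the padding mask is per-utterance and comes from the sequence lengths.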
- class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.0, activation=torch.nn.ReLU)
Bases: torch.nn.Module
This class implements the position-wise feed-forward module from “Attention Is All You Need”.
- Parameters
d_ffn (int) – Hidden layer size.
input_shape (tuple, optional) – Expected shape of the input. Alternatively, use input_size.
input_size (int, optional) – Expected size of the input. Alternatively, use input_shape.
dropout (float, optional) – Dropout rate.
activation (torch.nn.Module class, optional) – Activation function class to apply (default: torch.nn.ReLU; ReLU or GELU is recommended).
Example
>>> import torch
>>> from speechbrain.nnet.attention import PositionalwiseFeedForward
>>> inputs = torch.rand([8, 60, 512])
>>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1])
>>> outputs = net(inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
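The position-wise feed-forward computation from “Attention Is All You Need”, FFN(x) = W2 act(W1 x + b1) + b2, applies the same two-layer network to every time step independently, mapping input_size to d_ffn and back to input_size (as in the 512 → 256 → 512 example above). A minimal plain-Python sketch with hypothetical toy weights is shown below; dropout is omitted, and it illustrates the technique rather than SpeechBrain's implementation.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(W, b, x):
    # W is a list of rows (out_features x in_features); computes W @ x + b.
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def positionwise_ffn(seq, W1, b1, W2, b2, activation=relu):
    """Apply the same two-layer MLP to each time step of seq (sketch)."""
    return [linear(W2, b2, activation(linear(W1, b1, x))) for x in seq]


# Hypothetical toy weights: input_size = 2, d_ffn = 3.
W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b1 = [0.0, 0.0, -1.0]
W2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
b2 = [0.0, 0.0]
out = positionwise_ffn([[1.0, 2.0], [-1.0, 0.0]], W1, b1, W2, b2)
print(out)  # [[1.0, 4.0], [0.0, 0.0]]
```

Because no information flows between time steps inside this module, mixing across positions is left entirely to the attention layers.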