speechbrain.nnet.attention module¶
Library implementing attention modules.
- Authors
Ju-Chieh Chou 2020
Jianyuan Zhong 2020
Loren Lugosch 2020
Summary¶
Classes:
ContentBasedAttention | This class implements a content-based attention module for seq2seq learning.
KeyValueAttention | This class implements a single-headed key-value attention module for seq2seq learning.
LocationAwareAttention | This class implements a location-aware attention module for seq2seq learning.
MultiheadAttention | This class is a wrapper of torch.nn.MultiheadAttention.
PositionalwiseFeedForward | This class implements the position-wise feed-forward module in “Attention Is All You Need”.
Reference¶
- class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)[source]¶
Bases:
torch.nn.modules.module.Module
This class implements a content-based attention module for seq2seq learning.
Reference: Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al. https://arxiv.org/pdf/1409.0473.pdf
- Parameters
enc_dim (int) – Size of the encoder feature vectors.
dec_dim (int) – Size of the decoder feature vectors.
attn_dim (int) – Size of the attention feature.
output_dim (int) – Size of the output context vector.
scaling (float) – Factor that controls the sharpening degree (default: 1.0).
Example
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)[source]¶
Returns the output of the attention module.
- Parameters
enc_states (torch.Tensor) – The tensor to be attended.
enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.
dec_states (torch.Tensor) – The query tensor.
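The scoring behind this module can be sketched in plain PyTorch. The snippet below is an illustrative sketch of Bahdanau-style additive attention with length masking, not the SpeechBrain implementation: all names (`AdditiveAttentionSketch`, `W_enc`, `W_dec`, `v`) are hypothetical, and the output projection to `output_dim` is omitted for brevity.

```python
import torch
import torch.nn as nn

class AdditiveAttentionSketch(nn.Module):
    """Sketch of content-based (additive) attention with length masking."""

    def __init__(self, enc_dim, dec_dim, attn_dim):
        super().__init__()
        self.W_enc = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(dec_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, enc_states, enc_len, dec_state):
        # enc_states: (B, T, enc_dim); dec_state: (B, dec_dim)
        energy = self.v(
            torch.tanh(self.W_enc(enc_states) + self.W_dec(dec_state).unsqueeze(1))
        ).squeeze(-1)  # (B, T)
        # Mask padded encoder positions before the softmax.
        mask = torch.arange(enc_states.size(1)).unsqueeze(0) >= enc_len.unsqueeze(1)
        energy = energy.masked_fill(mask, float("-inf"))
        weights = torch.softmax(energy, dim=-1)  # (B, T), rows sum to 1
        # Context vector: attention-weighted sum of encoder states.
        context = torch.bmm(weights.unsqueeze(1), enc_states).squeeze(1)  # (B, enc_dim)
        return context, weights
```

Masking with `-inf` before the softmax guarantees that padded positions receive exactly zero attention weight.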
- class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)[source]¶
Bases:
torch.nn.modules.module.Module
This class implements a location-aware attention module for seq2seq learning.
Reference: Attention-Based Models for Speech Recognition, Chorowski et al. https://arxiv.org/pdf/1506.07503.pdf
- Parameters
enc_dim (int) – Size of the encoder feature vectors.
dec_dim (int) – Size of the decoder feature vectors.
attn_dim (int) – Size of the attention feature.
output_dim (int) – Size of the output context vector.
conv_channels (int) – Number of channels for the location feature.
kernel_size (int) – Kernel size of the convolutional layer for the location feature.
scaling (float) – Factor that controls the sharpening degree (default: 1.0).
Example
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = LocationAwareAttention(
...     enc_dim=20,
...     dec_dim=25,
...     attn_dim=30,
...     output_dim=5,
...     conv_channels=10,
...     kernel_size=100)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- precomputed_enc_h: Optional[torch.Tensor]¶
- forward(enc_states, enc_len, dec_states)[source]¶
Returns the output of the attention module.
- Parameters
enc_states (torch.Tensor) – The tensor to be attended.
enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.
dec_states (torch.Tensor) – The query tensor.
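What makes this module "location-aware" is that the attention weights from the previous decoding step are convolved into location features and added to the additive-attention energy, as in Chorowski et al. The snippet below is an illustrative sketch of that scoring only, not the SpeechBrain implementation; all names (`LocationAwareSketch`, `W_loc`, etc.) are hypothetical, and length masking and the output projection are omitted.

```python
import torch
import torch.nn as nn

class LocationAwareSketch(nn.Module):
    """Sketch of location-aware attention scoring (Chorowski et al.)."""

    def __init__(self, enc_dim, dec_dim, attn_dim, conv_channels, kernel_size):
        super().__init__()
        self.W_enc = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(dec_dim, attn_dim, bias=False)
        # 1D convolution over the previous attention distribution.
        self.conv = nn.Conv1d(
            1, conv_channels, kernel_size, padding=kernel_size // 2, bias=False
        )
        self.W_loc = nn.Linear(conv_channels, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, enc_states, dec_state, prev_weights):
        # prev_weights: (B, T) attention weights from the previous step.
        loc = self.conv(prev_weights.unsqueeze(1)).transpose(1, 2)  # (B, T', C)
        loc = loc[:, : enc_states.size(1), :]  # align length for even kernels
        energy = self.v(
            torch.tanh(
                self.W_enc(enc_states)
                + self.W_dec(dec_state).unsqueeze(1)
                + self.W_loc(loc)
            )
        ).squeeze(-1)  # (B, T)
        return torch.softmax(energy, dim=-1)
```

The location term lets the model prefer positions near where it attended last, which encourages the monotonic alignments typical of speech recognition.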
- class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)[source]¶
Bases:
torch.nn.modules.module.Module
This class implements a single-headed key-value attention module for seq2seq learning.
Reference: “Attention Is All You Need” by Vaswani et al., sec. 3.2.1
- Parameters
enc_dim (int) – Size of the encoder feature vectors.
dec_dim (int) – Size of the decoder feature vectors.
attn_dim (int) – Size of the attention feature.
output_dim (int) – Size of the output context vector.
Example
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = KeyValueAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
- forward(enc_states, enc_len, dec_states)[source]¶
Returns the output of the attention module.
- Parameters
enc_states (torch.Tensor) – The tensor to be attended.
enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.
dec_states (torch.Tensor) – The query tensor.
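The core of key-value attention is the scaled dot-product scoring of sec. 3.2.1 of the referenced paper. The function below is a minimal sketch of that computation (hypothetical name, not part of SpeechBrain); the class above additionally learns the projections that produce the queries, keys, and values from the decoder and encoder states.

```python
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    """Sketch of single-headed scaled dot-product attention (Vaswani et al.)."""
    # q: (B, L, d_k); k: (B, S, d_k); v: (B, S, d_v)
    scores = torch.bmm(q, k.transpose(1, 2)) / (k.size(-1) ** 0.5)  # (B, L, S)
    if mask is not None:
        # mask is a bool tensor; True marks positions that may not be attended.
        scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return torch.bmm(weights, v), weights
```

Scaling by the square root of the key dimension keeps the dot products in a range where the softmax still has usable gradients.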
- class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)[source]¶
Bases:
torch.nn.modules.module.Module
This class is a wrapper of torch.nn.MultiheadAttention.
Reference: https://pytorch.org/docs/stable/nn.html
- Parameters
nhead (int) – Number of parallel attention heads.
d_model (int) – Total dimension of the model (size of the input embeddings).
dropout (float) – Dropout probability applied to attn_output_weights (default: 0.0).
bias (bool) – Add bias as a module parameter (default: True).
add_bias_kv (bool) – Add bias to the key and value sequences at dim=0.
add_zero_attn (bool) – Add a new batch of zeros to the key and value sequences at dim=1.
kdim (int) – Total number of features in the keys (default: None).
vdim (int) – Total number of features in the values (default: None).
Example
>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
- forward(query, key, value, attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None)[source]¶
- Parameters
query (tensor) – (N, L, E) where L is the target sequence length, N is the batch size, E is the embedding dimension.
key (tensor) – (N, S, E) where S is the source sequence length, N is the batch size, E is the embedding dimension.
value (tensor) – (N, S, E) where S is the source sequence length, N is the batch size, E is the embedding dimension.
key_padding_mask (tensor) – (N, S) where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the positions with the value of False will be unchanged.
attn_mask (tensor) – 2D mask (L, S) where L is the target sequence length, S is the source sequence length, or 3D mask (N*num_heads, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend only the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with the value of True are not allowed to attend while positions with the value of False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
- Returns
attn_output (tensor) – (N, L, E) where N is the batch size, L is the target sequence length, E is the embedding dimension.
attn_output_weights (tensor) – (N, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length.
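The mask semantics above can be illustrated directly with PyTorch's own torch.nn.MultiheadAttention, which this wrapper delegates to. This is an illustrative sketch, not SpeechBrain code; note that the `batch_first=True` argument used here to get (N, S, E) inputs requires PyTorch >= 1.9.

```python
import torch
import torch.nn as nn

batch, src_len, d_model = 2, 5, 16
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)
x = torch.rand(batch, src_len, d_model)

# Boolean key_padding_mask: True marks padded (ignored) key positions.
# Here the last two time steps of the first sample are padding.
key_padding_mask = torch.zeros(batch, src_len, dtype=torch.bool)
key_padding_mask[0, 3:] = True

# Self-attention: query, key, and value are all the same tensor.
out, attn = mha(x, x, x, key_padding_mask=key_padding_mask)

# Masked keys receive exactly zero attention weight.
assert torch.all(attn[0, :, 3:] == 0)
```

Because masking is applied as `-inf` before the softmax, the attention weights on padded keys are exactly zero rather than merely small.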
- class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>)[source]¶
Bases:
torch.nn.modules.module.Module
This class implements the position-wise feed-forward module in “Attention Is All You Need”.
- Parameters
d_ffn (int) – Dimension of the representation space of this position-wise feed-forward module.
input_shape (tuple) – Expected shape of the input. Alternatively use input_size.
input_size (int) – Expected size of the input. Alternatively use input_shape.
dropout (float) – Fraction of outputs to drop.
activation (torch class) – Activation function to be applied (recommended: ReLU, GELU).
Example
>>> inputs = torch.rand([8, 60, 512])
>>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1])
>>> outputs = net(inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
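The position-wise feed-forward block of the referenced paper is just two linear layers with a nonlinearity in between, applied independently at every sequence position. The snippet below is an illustrative sketch under that definition (the name `FFNSketch` and the exact layer ordering are assumptions, not the SpeechBrain internals).

```python
import torch
import torch.nn as nn

class FFNSketch(nn.Module):
    """Sketch of the position-wise feed-forward block: FFN(x) = W2(act(W1(x)))."""

    def __init__(self, input_size, d_ffn, dropout=0.1, activation=nn.ReLU):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, d_ffn),   # expand to the FFN dimension
            activation(),
            nn.Dropout(dropout),
            nn.Linear(d_ffn, input_size),   # project back to the model dimension
        )

    def forward(self, x):
        # (B, T, input_size) -> (B, T, input_size); each position is
        # transformed independently, so sequence length is unconstrained.
        return self.net(x)
```

Because nn.Linear acts on the last dimension, the same weights are shared across all positions, which is what "position-wise" means.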