speechbrain.nnet.attention module

Library implementing attention modules.

Authors
  • Ju-Chieh Chou 2020

  • Jianyuan Zhong 2020

  • Loren Lugosch 2020

  • Samuele Cornell 2020

Summary

Classes:

ContentBasedAttention

This class implements a content-based attention module for seq2seq learning.

KeyValueAttention

This class implements a single-headed key-value attention module for seq2seq learning.

LocationAwareAttention

This class implements a location-aware attention module for seq2seq learning.

MultiheadAttention

This class is a wrapper around torch.nn.MultiheadAttention.

PositionalwiseFeedForward

This class implements the position-wise feed-forward module from “Attention Is All You Need”.

RelPosEncXL

Relative positional encoding for RelPosMHAXL (Transformer-XL style).

RelPosMHAXL

This class implements relative multi-head attention similar to that in Transformer-XL, https://arxiv.org/pdf/1901.02860.pdf

Reference

class speechbrain.nnet.attention.ContentBasedAttention(enc_dim, dec_dim, attn_dim, output_dim, scaling=1.0)[source]

Bases: Module

This class implements a content-based attention module for seq2seq learning.

Reference: Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al., https://arxiv.org/pdf/1409.0473.pdf

Parameters
  • enc_dim (int) – Size of the encoder feature vectors.

  • dec_dim (int) – Size of the decoder feature vectors (queries).

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

  • scaling (float) – Factor that controls the sharpening of the attention distribution (default: 1.0).

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
reset()[source]

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)[source]

Returns the context vector and the attention weights.

Parameters
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence (see the padding sketch after this parameter list).

  • dec_states (torch.Tensor) – The query tensor.
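The following is a minimal, hedged sketch (not part of the official documentation) of how enc_len is used with a padded batch: lengths are given in frames, and positions beyond each length are assumed to be excluded from the attention.

>>> # Hedged sketch: batch of 2 utterances, the second padded from 6 to 10 frames.
>>> enc_tensor = torch.rand([2, 10, 20])
>>> enc_len = torch.tensor([10.0, 6.0])  # true (unpadded) lengths in frames
>>> dec_tensor = torch.rand([2, 25])
>>> net = ContentBasedAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([2, 5])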

training: bool
class speechbrain.nnet.attention.LocationAwareAttention(enc_dim, dec_dim, attn_dim, output_dim, conv_channels, kernel_size, scaling=1.0)[source]

Bases: Module

This class implements a location-aware attention module for seq2seq learning.

Reference: Attention-Based Models for Speech Recognition, Chorowski et al., https://arxiv.org/pdf/1506.07503.pdf

Parameters
  • enc_dim (int) – Size of the encoder feature vectors.

  • dec_dim (int) – Size of the decoder feature vectors (queries).

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

  • conv_channels (int) – Number of channels for the location feature.

  • kernel_size (int) – Kernel size of the convolutional layer that extracts the location feature.

  • scaling (float) – Factor that controls the sharpening of the attention distribution (default: 1.0).

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = LocationAwareAttention(
...     enc_dim=20,
...     dec_dim=25,
...     attn_dim=30,
...     output_dim=5,
...     conv_channels=10,
...     kernel_size=100)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
precomputed_enc_h: Optional[Tensor]
reset()[source]

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)[source]

Returns the context vector and the attention weights.

Parameters
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.
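A possible usage pattern, shown as a minimal sketch (assuming, as the reset() method suggests, that the module keeps state between calls so the location features can accumulate): reset once per batch, then call the module step by step during decoding.

>>> # Hedged sketch of step-wise decoding; reset() clears the internal state
>>> # before a new batch, then each call attends using the location features.
>>> net = LocationAwareAttention(
...     enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5,
...     conv_channels=10, kernel_size=100)
>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> net.reset()
>>> for _ in range(3):
...     dec_tensor = torch.rand([4, 25])
...     context, weights = net(enc_tensor, enc_len, dec_tensor)
>>> context.shape
torch.Size([4, 5])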

class speechbrain.nnet.attention.KeyValueAttention(enc_dim, dec_dim, attn_dim, output_dim)[source]

Bases: Module

This class implements a single-headed key-value attention module for seq2seq learning.

Reference: “Attention Is All You Need” by Vaswani et al., sec. 3.2.1

Parameters
  • enc_dim (int) – Size of the encoder feature vectors from which keys and values are computed.

  • dec_dim (int) – Size of the decoder feature vectors from which queries are computed.

  • attn_dim (int) – Size of the attention feature.

  • output_dim (int) – Size of the output context vector.

Example

>>> enc_tensor = torch.rand([4, 10, 20])
>>> enc_len = torch.ones([4]) * 10
>>> dec_tensor = torch.rand([4, 25])
>>> net = KeyValueAttention(enc_dim=20, dec_dim=25, attn_dim=30, output_dim=5)
>>> out_tensor, out_weight = net(enc_tensor, enc_len, dec_tensor)
>>> out_tensor.shape
torch.Size([4, 5])
reset()[source]

Reset the memory in the attention module.

forward(enc_states, enc_len, dec_states)[source]

Returns the context vector and the attention weights.

Parameters
  • enc_states (torch.Tensor) – The tensor to be attended.

  • enc_len (torch.Tensor) – The real length (without padding) of enc_states for each sentence.

  • dec_states (torch.Tensor) – The query tensor.

training: bool
class speechbrain.nnet.attention.RelPosEncXL(emb_dim)[source]

Bases: Module

Relative positional encoding for RelPosMHAXL (Transformer-XL style).

forward(x: Tensor)[source]
Parameters
  • x (torch.Tensor) – Input tensor with shape (batch_size, seq_len, embed_dim).

Returns

pos_emb

Return type

torch.Tensor

training: bool
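Since this class has no example of its own, the following is a minimal, hedged sketch of pairing RelPosEncXL with RelPosMHAXL; the (1, 2*seq_len-1, emb_dim) output shape is an assumption inferred from the pos_embs description of RelPosMHAXL below.

>>> # Hedged sketch: feed RelPosEncXL output to RelPosMHAXL (shape assumed, see above).
>>> x = torch.rand([6, 60, 512])
>>> pos_enc = RelPosEncXL(emb_dim=512)
>>> pos_emb = pos_enc(x)  # assumed shape: (1, 2*60-1, 512)
>>> mha = RelPosMHAXL(embed_dim=512, num_heads=8)
>>> out, attn = mha(x, x, x, pos_emb)
>>> out.shape
torch.Size([6, 60, 512])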
class speechbrain.nnet.attention.RelPosMHAXL(embed_dim, num_heads, dropout=0.0, vbias=False, vdim=None, mask_pos_future=False)[source]

Bases: Module

This class implements relative multi-head attention similar to that in Transformer-XL, https://arxiv.org/pdf/1901.02860.pdf

Parameters
  • embed_dim (int) – Total embedding dimension of the input features (queries, keys, and values).

  • num_heads (int) – Number of attention heads.

  • dropout (float, optional) – Dropout rate.

  • vbias (bool, optional) – Whether to use a bias when computing the values.

  • vdim (int, optional) – Size of the values. Defaults to embed_dim (note that each head has size embed_dim // num_heads).

  • mask_pos_future (bool, optional) – Whether to mask future positional encoding values. Must be True for causal applications, e.g., the decoder.

Example

>>> inputs = torch.rand([6, 60, 512])
>>> pos_emb = torch.rand([1, 2*60-1, 512])
>>> net = RelPosMHAXL(num_heads=8, embed_dim=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs, pos_emb)
>>> outputs.shape
torch.Size([6, 60, 512])
rel_shift(x)[source]

Relative shift implementation.

forward(query, key, value, pos_embs, key_padding_mask=None, attn_mask=None, return_attn_weights=True)[source]
Parameters
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • pos_embs (torch.Tensor) – Bidirectional sinusoidal positional embedding tensor (1, 2*S-1, E), where S is the larger of the source and target sequence lengths and E is the embedding dimension.

  • key_padding_mask (torch.Tensor) – (B, S) where B is the batch size and S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored, while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value True will be ignored, while the positions with the value False will be unchanged (see the masking sketch below).

  • attn_mask (torch.Tensor) – 2D mask (L, S) or 3D mask (N*num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend, while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend, while positions with False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.

Outputs
  • out (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_score (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
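As a minimal sketch (not from the official documentation), a key_padding_mask can be built from per-utterance lengths and passed alongside the relative positional embeddings, using the BoolTensor convention described above (True marks padded positions to ignore).

>>> # Hedged sketch: mask padded key positions with a BoolTensor (True = ignore).
>>> x = torch.rand([2, 6, 512])
>>> lengths = torch.tensor([6, 4])  # second sequence has 2 padded frames
>>> key_padding_mask = torch.arange(6).unsqueeze(0) >= lengths.unsqueeze(1)
>>> pos_emb = torch.rand([1, 2*6-1, 512])
>>> net = RelPosMHAXL(embed_dim=512, num_heads=8)
>>> out, attn = net(x, x, x, pos_emb, key_padding_mask=key_padding_mask)
>>> out.shape
torch.Size([2, 6, 512])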

training: bool
class speechbrain.nnet.attention.MultiheadAttention(nhead, d_model, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)[source]

Bases: Module

This class is a wrapper around torch.nn.MultiheadAttention.

Reference: https://pytorch.org/docs/stable/nn.html

Parameters
  • nhead (int) – Number of parallel attention heads.

  • d_model (int) – Total dimension of the model (embedding dimension of the inputs).

  • dropout (float) – Dropout probability applied to attn_output_weights (default: 0.0).

  • bias (bool) – Add bias as a module parameter (default: True).

  • add_bias_kv (bool) – Add bias to the key and value sequences at dim=0.

  • add_zero_attn (bool) – Add a new batch of zeros to the key and value sequences at dim=1.

  • kdim (int) – Total number of features in the keys (default: None, which uses d_model).

  • vdim (int) – Total number of features in the values (default: None, which uses d_model).

Example

>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=inputs.shape[-1])
>>> outputs, attn = net(inputs, inputs, inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
forward(query, key, value, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None, return_attn_weights: Optional[Tensor] = True, pos_embs: Optional[Tensor] = None)[source]
Parameters
  • query (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • key (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • value (torch.Tensor) – (B, S, E) where S is the source sequence length, B is the batch size, E is the embedding dimension.

  • key_padding_mask (torch.Tensor, optional) – (B, S) where B is the batch size and S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored, while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value True will be ignored, while the positions with the value False will be unchanged.

  • attn_mask (torch.Tensor, optional) – 2D mask (L, S) or 3D mask (N*num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions (see the causal-mask sketch below). If a ByteTensor is provided, the non-zero positions are not allowed to attend, while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend, while positions with False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.

  • pos_embs (torch.Tensor, optional) – Positional embeddings added to the attention map of shape (L, S, E) or (L, S, 1).

Outputs
  • attn_output (torch.Tensor) – (B, L, E) where L is the target sequence length, B is the batch size, E is the embedding dimension.

  • attn_output_weights (torch.Tensor) – (B, L, S) where B is the batch size, L is the target sequence length, S is the source sequence length.
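As a minimal sketch (not part of the official documentation), a causal attn_mask can be built with torch.triu so that each position only attends to itself and earlier positions, following the BoolTensor convention described above (True = not allowed to attend).

>>> # Hedged sketch: causal (upper-triangular) BoolTensor mask, True = blocked.
>>> inputs = torch.rand([8, 60, 512])
>>> net = MultiheadAttention(nhead=8, d_model=512)
>>> causal_mask = torch.triu(torch.ones(60, 60, dtype=torch.bool), diagonal=1)
>>> outputs, attn = net(inputs, inputs, inputs, attn_mask=causal_mask)
>>> outputs.shape
torch.Size([8, 60, 512])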

training: bool
class speechbrain.nnet.attention.PositionalwiseFeedForward(d_ffn, input_shape=None, input_size=None, dropout=0.0, activation=<class 'torch.nn.modules.activation.ReLU'>)[source]

Bases: Module

This class implements the position-wise feed-forward module from “Attention Is All You Need”.

Parameters
  • d_ffn (int) – Hidden layer size.

  • input_shape (tuple, optional) – Expected shape of the input. Alternatively use input_size.

  • input_size (int, optional) – Expected size of the input. Alternatively use input_shape.

  • dropout (float, optional) – Dropout rate.

  • activation (torch.nn.Module, optional) – Activation function to be applied (recommended: ReLU or GELU).

Example

>>> inputs = torch.rand([8, 60, 512])
>>> net = PositionalwiseFeedForward(256, input_size=inputs.shape[-1])
>>> outputs = net(inputs)
>>> outputs.shape
torch.Size([8, 60, 512])
training: bool
forward(x)[source]

Applies PositionalwiseFeedForward to the input tensor x.
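In the notation of “Attention Is All You Need”, the position-wise feed-forward network is FFN(x) = max(0, xW1 + b1)W2 + b2, applied independently at each position, with d_ffn as the inner dimension. The following is a minimal, hedged sketch of typical Transformer-style usage; the surrounding residual connection and LayerNorm are illustrative and not part of this class.

>>> # Hedged sketch: residual connection + layer norm around the feed-forward block.
>>> import torch.nn as nn
>>> x = torch.rand([8, 60, 512])
>>> ffn = PositionalwiseFeedForward(d_ffn=2048, input_size=512, dropout=0.1)
>>> norm = nn.LayerNorm(512)
>>> y = norm(x + ffn(x))
>>> y.shape
torch.Size([8, 60, 512])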