"""The Attentional RNN model for Grapheme-to-Phoneme
Authors
* Mirco Ravanelli 2021
* Artem Ploujnikov 2021
"""
import torch
from torch import nn
from speechbrain.lobes.models.transformer.Transformer import (
TransformerInterface,
get_key_padding_mask,
get_lookahead_mask,
)
from speechbrain.nnet import normalization
from speechbrain.nnet.linear import Linear
class AttentionSeq2Seq(nn.Module):
"""
The Attentional RNN encoder-decoder model
Arguments
---------
enc: torch.nn.Module
the encoder module
encoder_emb: torch.nn.Module
the encoder_embedding_module
emb: torch.nn.Module
the embedding module
dec: torch.nn.Module
the decoder module
lin: torch.nn.Module
the linear module
out: torch.nn.Module
the output layer (typically log_softmax)
bos_token: int
the index of the Beginning-of-Sentence token
use_word_emb: bool
whether or not to use word embedding
word_emb_enc: nn.Module
a module to encode word embeddings
"""
def __init__(
self,
enc,
encoder_emb,
emb,
dec,
lin,
out,
bos_token=0,
use_word_emb=False,
word_emb_enc=None,
):
super().__init__()
self.enc = enc
self.encoder_emb = encoder_emb
self.emb = emb
self.dec = dec
self.lin = lin
self.out = out
self.bos_token = bos_token
self.use_word_emb = use_word_emb
self.word_emb_enc = word_emb_enc if use_word_emb else None
def forward(self, grapheme_encoded, phn_encoded=None, word_emb=None):
"""Computes the forward pass
Arguments
---------
grapheme_encoded: torch.Tensor
graphemes encoded as a Torch tensor
phn_encoded: torch.Tensor
the encoded phonemes
word_emb: torch.Tensor
word embeddings (optional)
Returns
-------
p_seq: torch.Tensor
a (batch x position x token) tensor of token probabilities in each
position
char_lens: torch.Tensor
a tensor of character sequence lengths
encoder_out: torch.Tensor
the raw output of the encoder
w: torch.Tensor
the attention weights returned by the decoder
"""
chars, char_lens = grapheme_encoded
if phn_encoded is None:
phn_bos = get_dummy_phonemes(chars.size(0), chars.device)
else:
phn_bos, _ = phn_encoded
emb_char = self.encoder_emb(chars)
if self.use_word_emb:
emb_char = _apply_word_emb(self.word_emb_enc, emb_char, word_emb)
encoder_out, _ = self.enc(emb_char)
e_in = self.emb(phn_bos)
h, w = self.dec(e_in, encoder_out, char_lens)
logits = self.lin(h)
p_seq = self.out(logits)
return p_seq, char_lens, encoder_out, w
def _apply_word_emb(self, emb_char, word_emb):
"""Concatenate character embeddings with word embeddings,
possibly encoding the word embeddings if an encoder
is provided
Arguments
---------
emb_char: torch.Tensor
the character embedding tensor
word_emb: torch.Tensor
the word embedding tensor
Returns
-------
result: torch.Tensor
the concatenation of the two tensors"""
word_emb_enc = (
self.word_emb_enc(word_emb)
if self.word_emb_enc is not None
else word_emb
)
return torch.cat([emb_char, word_emb_enc], dim=-1)
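# Illustrative usage sketch (an assumption added for clarity, not part of the
# original module). The wrapped modules only need to satisfy the call patterns
# used in forward() above: `enc` returns an (output, hidden) tuple and `dec` is
# called as dec(emb(phn_bos), encoder_out, char_lens), returning (hidden, attn).
# `model` below is assumed to be an already constructed AttentionSeq2Seq.
#
#     chars = torch.randint(0, 28, (4, 12))       # (batch, char_len) grapheme ids
#     char_lens = torch.ones(4)                   # relative lengths in [0, 1]
#     phns = torch.randint(0, 44, (4, 9))         # BOS-prefixed target phonemes
#     p_seq, char_lens, encoder_out, w = model(
#         grapheme_encoded=(chars, char_lens),
#         phn_encoded=(phns, torch.ones(4)),
#     )
#     # p_seq: (batch, target_len, n_tokens) probabilities from `out`;
#     # if phn_encoded is None, a dummy BOS sequence is used instead.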
class WordEmbeddingEncoder(nn.Module):
"""A small encoder module that reduces the dimensionality
and normalizes word embeddings
Arguments
---------
word_emb_dim: int
the dimension of the original word embeddings
word_emb_enc_dim: int
the dimension of the encoded word embeddings
norm: torch.nn.Module
the normalization to be used
(e.g. speechbrain.nnet.normalization.LayerNorm)
norm_type: str
the type of normalization to be used: "batch", "layer" or "instance"
(if provided, it overrides norm)
"""
def __init__(
self, word_emb_dim, word_emb_enc_dim, norm=None, norm_type=None
):
super().__init__()
self.word_emb_dim = word_emb_dim
self.word_emb_enc_dim = word_emb_enc_dim
if norm_type:
self.norm = self._get_norm(norm_type, word_emb_dim)
else:
self.norm = norm
self.lin = Linear(n_neurons=word_emb_enc_dim, input_size=word_emb_dim)
self.activation = nn.Tanh()
def _get_norm(self, norm, dim):
"""Determines the type of normalizer
Arguments
---------
norm: str
the normalization type: "batch", "layer" or "instance"
dim: int
the dimensionality of the inputs
Returns
-------
The normalization module corresponding to the requested type.
"""
norm_cls = self.NORMS.get(norm)
if not norm_cls:
raise ValueError(f"Invalid norm: {norm}")
return norm_cls(input_size=dim)
def forward(self, emb):
"""Computes the forward pass of the embedding
Arguments
---------
emb: torch.Tensor
the original word embeddings
Returns
-------
emb_enc: torch.Tensor
encoded word embeddings
"""
x = self.norm(emb) if self.norm is not None else emb
x = self.lin(x)
x = self.activation(x)
return x
NORMS = {
"batch": normalization.BatchNorm1d,
"layer": normalization.LayerNorm,
"instance": normalization.InstanceNorm1d,
}
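# A minimal sketch of how WordEmbeddingEncoder is used (illustrative; the
# dimensions are assumptions, not values from a recipe). It projects word
# embeddings from word_emb_dim down to word_emb_enc_dim on the last dimension,
# optionally normalizing them first.
#
#     word_emb_enc = WordEmbeddingEncoder(
#         word_emb_dim=300, word_emb_enc_dim=128, norm_type="layer"
#     )
#     emb = torch.randn(4, 10, 300)   # (batch, words, word_emb_dim)
#     emb_enc = word_emb_enc(emb)     # -> (4, 10, 128), tanh-activated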
class TransformerG2P(TransformerInterface):
"""
A Transformer-based Grapheme-to-Phoneme model
Arguments
---------
emb: torch.nn.Module
the embedding module
encoder_emb: torch.nn.Module
the encoder embedding module
char_lin: torch.nn.Module
a linear module connecting the inputs
to the transformer
phn_lin: torch.nn.Module
a linear module connecting the outputs to
the transformer
out: torch.nn.Module
the output activation layer (typically log_softmax)
lin: torch.nn.Module
the linear module mapping decoder outputs to logits
d_model: int
The number of expected features in the encoder/decoder inputs (default=512).
nhead: int
The number of heads in the multi-head attention models (default=8).
num_encoder_layers: int, optional
The number of encoder layers in the encoder.
num_decoder_layers: int, optional
The number of decoder layers in the decoder.
d_ffn: int, optional
The dimension of the feedforward network model hidden layer.
dropout: float, optional
The dropout value.
activation: torch.nn.Module, optional
The activation function for Feed-Forward Network layer,
e.g., relu or gelu or swish.
custom_src_module: torch.nn.Module, optional
Module that processes the src features to expected feature dim.
custom_tgt_module: torch.nn.Module, optional
Module that processes the tgt features to expected feature dim.
positional_encoding: str, optional
Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
normalize_before: bool, optional
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
Defaults to True as this was shown to lead to better performance and training stability.
kernel_size: int, optional
Kernel size in convolutional layers when Conformer is used.
bias: bool, optional
Whether to use bias in Conformer convolutional layers.
encoder_module: str, optional
Choose between Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer.
conformer_activation: torch.nn.Module, optional
Activation module used after Conformer convolutional layers, e.g. Swish or ReLU. It has to be a torch Module.
attention_type: str, optional
Type of attention layer used in all Transformer or Conformer layers.
e.g. regularMHA or RelPosMHAXL.
max_length: int, optional
Max length for the target and source sequence in input.
Used for positional encodings.
causal: bool, optional
Whether the encoder should be causal or not (the decoder is always causal).
If causal the Conformer convolutional layer is causal.
pad_idx: int
the padding index (for masks)
encoder_kdim: int, optional
Dimension of the key for the encoder.
encoder_vdim: int, optional
Dimension of the value for the encoder.
decoder_kdim: int, optional
Dimension of the key for the decoder.
decoder_vdim: int, optional
Dimension of the value for the decoder.
"""
def __init__(
self,
emb,
encoder_emb,
char_lin,
phn_lin,
lin,
out,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ffn=2048,
dropout=0.1,
activation=nn.ReLU,
custom_src_module=None,
custom_tgt_module=None,
positional_encoding="fixed_abs_sine",
normalize_before=True,
kernel_size=15,
bias=True,
encoder_module="transformer",
attention_type="regularMHA",
max_length=2500,
causal=False,
pad_idx=0,
encoder_kdim=None,
encoder_vdim=None,
decoder_kdim=None,
decoder_vdim=None,
use_word_emb=False,
word_emb_enc=None,
):
super().__init__(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
d_ffn=d_ffn,
dropout=dropout,
activation=activation,
custom_src_module=custom_src_module,
custom_tgt_module=custom_tgt_module,
positional_encoding=positional_encoding,
normalize_before=normalize_before,
kernel_size=kernel_size,
bias=bias,
encoder_module=encoder_module,
attention_type=attention_type,
max_length=max_length,
causal=causal,
encoder_kdim=encoder_kdim,
encoder_vdim=encoder_vdim,
decoder_kdim=decoder_kdim,
decoder_vdim=decoder_vdim,
)
self.emb = emb
self.encoder_emb = encoder_emb
self.char_lin = char_lin
self.phn_lin = phn_lin
self.lin = lin
self.out = out
self.pad_idx = pad_idx
self.use_word_emb = use_word_emb
self.word_emb_enc = word_emb_enc
self._reset_params()
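# A hedged construction sketch (the vocabulary sizes and dimensions below are
# illustrative assumptions, not the values used in the SpeechBrain G2P recipes).
# char_lin/phn_lin project the grapheme/phoneme embeddings to d_model, and
# lin/out map decoder states to phoneme log-probabilities.
#
#     model = TransformerG2P(
#         encoder_emb=torch.nn.Embedding(28, 128),
#         emb=torch.nn.Embedding(44, 128),
#         char_lin=Linear(input_size=128, n_neurons=512),
#         phn_lin=Linear(input_size=128, n_neurons=512),
#         lin=Linear(input_size=512, n_neurons=44),
#         out=torch.nn.LogSoftmax(dim=-1),
#         d_model=512,
#     )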
def forward(self, grapheme_encoded, phn_encoded=None, word_emb=None):
"""Computes the forward pass
Arguments
---------
grapheme_encoded: torch.Tensor
graphemes encoded as a Torch tensor
phn_encoded: torch.Tensor
the encoded phonemes
word_emb: torch.Tensor
word embeddings (if applicable)
Returns
-------
p_seq: torch.Tensor
the log-probabilities of individual tokens in the sequence
char_lens: torch.Tensor
the character sequence lengths
encoder_out: torch.Tensor
the encoder state
attention: torch.Tensor
the attention state
"""
chars, char_lens = grapheme_encoded
if phn_encoded is None:
phn = get_dummy_phonemes(chars.size(0), chars.device)
else:
phn, _ = phn_encoded
emb_char = self.encoder_emb(chars)
if self.use_word_emb:
emb_char = _apply_word_emb(self.word_emb_enc, emb_char, word_emb)
src = self.char_lin(emb_char)
tgt = self.emb(phn)
tgt = self.phn_lin(tgt)
(
src_key_padding_mask,
tgt_key_padding_mask,
src_mask,
tgt_mask,
) = self.make_masks(src, tgt, char_lens, pad_idx=self.pad_idx)
pos_embs_encoder = None
if self.attention_type == "RelPosMHAXL":
pos_embs_encoder = self.positional_encoding(src)
elif self.positional_encoding_type == "fixed_abs_sine":
src = src + self.positional_encoding(src) # add the encodings here
pos_embs_encoder = None
encoder_out, _ = self.encoder(
src=src,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs_encoder,
)
if self.attention_type == "RelPosMHAXL":
# use standard sinusoidal pos encoding in decoder
tgt = tgt + self.positional_encoding_decoder(tgt)
src = src + self.positional_encoding_decoder(src)
pos_embs_encoder = None
pos_embs_target = None
elif self.positional_encoding_type == "fixed_abs_sine":
tgt = tgt + self.positional_encoding(tgt)
pos_embs_target = None
pos_embs_encoder = None
decoder_out, _, attention = self.decoder(
tgt=tgt,
memory=encoder_out,
memory_mask=src_mask,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
pos_embs_tgt=pos_embs_target,
pos_embs_src=pos_embs_encoder,
)
logits = self.lin(decoder_out)
p_seq = self.out(logits)
return p_seq, char_lens, encoder_out, attention
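# Illustrative forward-pass sketch (assumed shapes, continuing the hypothetical
# `model` above). Graphemes and phonemes are passed as (tensor, relative_length)
# tuples, as produced by SpeechBrain padded batches.
#
#     chars = torch.randint(0, 28, (4, 12))
#     phns = torch.randint(0, 44, (4, 9))
#     p_seq, char_lens, encoder_out, attention = model(
#         grapheme_encoded=(chars, torch.ones(4)),
#         phn_encoded=(phns, torch.ones(4)),
#     )
#     # p_seq: (4, 9, 44) log-probabilities; encoder_out: (4, 12, 512)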
def _reset_params(self):
"""Resets the parameters of the model"""
for p in self.parameters():
if p.dim() > 1:
torch.nn.init.xavier_normal_(p)
def make_masks(self, src, tgt, src_len=None, pad_idx=0):
"""This method generates the masks for training the transformer model.
Arguments
---------
src : torch.Tensor
The sequence to the encoder (required).
tgt : torch.Tensor
The sequence to the decoder (required).
src_len : torch.Tensor
Lengths corresponding to the src tensor.
pad_idx : int
The index for <pad> token (default=0).
Returns
-------
src_key_padding_mask: torch.Tensor
the source key padding mask
tgt_key_padding_mask: torch.Tensor
the target key padding masks
src_mask: torch.Tensor
the source mask
tgt_mask: torch.Tensor
the target mask
"""
src_key_padding_mask = None
if src_len is not None:
abs_len = torch.round(src_len * src.shape[1])
src_key_padding_mask = (
torch.arange(src.shape[1])[None, :].to(abs_len)
> abs_len[:, None]
)
tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
src_mask = None
tgt_mask = get_lookahead_mask(tgt)
return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
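# A worked example of the padding-mask computation above (illustrative numbers):
# relative lengths are scaled to absolute positions, and any index strictly
# greater than the scaled length is masked (True).
#
#     src_len = torch.tensor([1.0, 0.6])   # relative lengths, src.shape[1] == 5
#     abs_len = torch.round(src_len * 5)   # -> tensor([5., 3.])
#     torch.arange(5)[None, :].to(abs_len) > abs_len[:, None]
#     # -> tensor([[False, False, False, False, False],
#     #            [False, False, False, False,  True]])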
def decode(self, tgt, encoder_out, enc_lens):
"""This method implements a decoding step for the transformer model.
Arguments
---------
tgt : torch.Tensor
The sequence to the decoder.
encoder_out : torch.Tensor
Hidden output of the encoder.
enc_lens : torch.Tensor
The corresponding lengths of the encoder inputs.
Returns
-------
prediction: torch.Tensor
the raw decoder output (before the final linear and output layers)
attention: torch.Tensor
the multi-head attention matrix from the last decoder layer
(useful for plotting attention)
"""
tgt_mask = get_lookahead_mask(tgt)
tgt = self.emb(tgt)
tgt = self.phn_lin(tgt)
if self.attention_type == "RelPosMHAXL":
# we use fixed positional encodings in the decoder
tgt = tgt + self.positional_encoding_decoder(tgt)
encoder_out = encoder_out + self.positional_encoding_decoder(
encoder_out
)
elif self.positional_encoding_type == "fixed_abs_sine":
tgt = tgt + self.positional_encoding(tgt) # add the encodings here
prediction, self_attns, multihead_attns = self.decoder(
tgt,
encoder_out,
tgt_mask=tgt_mask,
pos_embs_tgt=None,
pos_embs_src=None,
)
attention = multihead_attns[-1]
return prediction, attention
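# A hedged sketch of a greedy search built on decode() (illustrative only; the
# SpeechBrain recipes normally use a dedicated beam searcher). decode() returns
# raw decoder states, so `lin` and `out` are applied here to obtain token
# probabilities. `model`, `encoder_out`, `enc_lens`, `batch_size`, `device` and
# `max_decode_steps` are assumed to be defined.
#
#     tokens = get_dummy_phonemes(batch_size, device)     # (batch, 1) BOS ids
#     for _ in range(max_decode_steps):
#         dec_out, attn = model.decode(tokens, encoder_out, enc_lens)
#         logits = model.lin(dec_out[:, -1:, :])          # last step only
#         next_token = model.out(logits).argmax(dim=-1)
#         tokens = torch.cat([tokens, next_token], dim=1)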
def input_dim(use_word_emb, embedding_dim, word_emb_enc_dim):
"""Computes the input dimension (intended for hparam files)
Arguments
---------
use_word_emb: bool
whether to use word embeddings
embedding_dim: int
the embedding dimension
word_emb_enc_dim: int
the dimension of encoded word embeddings
Returns
-------
input_dim: int
the input dimension
"""
return embedding_dim + use_word_emb * word_emb_enc_dim
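# For example, with 512-dimensional character embeddings and 128-dimensional
# encoded word embeddings (illustrative values):
#
#     input_dim(use_word_emb=True, embedding_dim=512, word_emb_enc_dim=128)   # -> 640
#     input_dim(use_word_emb=False, embedding_dim=512, word_emb_enc_dim=128)  # -> 512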
def _apply_word_emb(word_emb_enc, emb_char, word_emb):
"""
Concatenates character and word embeddings together, possibly
applying a custom encoding/transformation
Arguments
---------
word_emb_enc: callable
an encoder to apply (typically, speechbrain.lobes.models.g2p.model.WordEmbeddingEncoder)
emb_char: torch.Tensor
character embeddings
word_emb: torch.Tensor
word embeddings
Returns
-------
result: torch.Tensor
the resulting (concatenated) tensor
"""
word_emb_enc = (
word_emb_enc(word_emb.data)
if word_emb_enc is not None
else word_emb.data
)
return torch.cat([emb_char, word_emb_enc], dim=-1)
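# Shape-level sketch of the concatenation above (illustrative; word_emb is
# assumed to be a padded batch item exposing a .data tensor aligned with the
# character positions):
#
#     emb_char:      (batch, char_len, char_emb_dim)
#     word_emb.data: (batch, char_len, word_emb_dim), encoded to word_emb_enc_dim
#     result:        (batch, char_len, char_emb_dim + word_emb_enc_dim)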
def get_dummy_phonemes(batch_size, device):
"""
Creates a dummy phoneme sequence
Arguments
---------
batch_size: int
the batch size
device: str
the target device
Returns
-------
result: torch.Tensor
a (batch_size x 1) tensor filled with the index 0 (the default BOS token)
"""
return torch.tensor([0], device=device).expand(batch_size, 1)
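# For example, get_dummy_phonemes(3, "cpu") returns
#
#     tensor([[0],
#             [0],
#             [0]])
#
# i.e. a single index-0 token per batch element (the default BOS), used by the
# forward() methods above when no target phonemes are provided.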