Source code for speechbrain.lobes.models.transformer.TransformerSE

"""CNN Transformer model for SE in the SpeechBrain style.

* Chien-Feng Liao 2020
import torch  # noqa E402
from torch import nn
from speechbrain.nnet.linear import Linear
from speechbrain.lobes.models.transformer.Transformer import (
    TransformerInterface,
    get_lookahead_mask,
)

class CNNTransformerSE(TransformerInterface):
    """This is an implementation of transformer model with CNN pre-encoder for SE.

    Arguments
    ---------
    d_model : int
        The number of expected features in the encoder inputs.
    output_size : int
        The number of neurons in the output layer.
    output_activation : torch class
        The activation function of the output layer (default=ReLU).
        It is instantiated once with no arguments in ``__init__``.
    nhead : int
        The number of heads in the multi-head attention models (default=8).
    num_layers : int
        The number of sub-layers in the transformer (default=8).
    d_ffn : int
        The number of expected features in the encoder layers (default=512).
    dropout : float
        The dropout value (default=0.1).
    activation : torch class
        The activation function of intermediate layers (default=LeakyReLU).
    causal : bool
        True for causal setting, the model is forbidden to see future frames
        (default=True).
    custom_emb_module : torch.nn.Module, optional
        Module instance applied to the input features before the transformer
        model (default=None, i.e. the input is fed to the encoder directly).
    normalize_before : bool
        Forwarded unchanged to ``TransformerInterface`` (default=False).
        NOTE(review): presumably selects pre-norm vs. post-norm encoder
        layers -- confirm against the base class.

    Example
    -------
    >>> src = torch.rand([8, 120, 256])
    >>> net = CNNTransformerSE(d_model=256, output_size=257)
    >>> out = net(src)
    >>> out.shape
    torch.Size([8, 120, 257])
    """

    def __init__(
        self,
        d_model,
        output_size,
        output_activation=nn.ReLU,
        nhead=8,
        num_layers=8,
        d_ffn=512,
        dropout=0.1,
        activation=nn.LeakyReLU,
        causal=True,
        custom_emb_module=None,
        normalize_before=False,
    ):
        # Encoder-only model: no decoder layers are requested, and ``forward``
        # only uses ``self.encoder``. NOTE(review): positional_encoding=None
        # presumably disables the base class's positional encoding -- confirm.
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=0,
            d_ffn=d_ffn,
            dropout=dropout,
            activation=activation,
            positional_encoding=None,
            normalize_before=normalize_before,
            causal=causal,
        )

        self.custom_emb_module = custom_emb_module

        # Projects encoder features (d_model) to the requested output size.
        self.output_layer = Linear(output_size, input_size=d_model, bias=False)
        self.output_activation = output_activation()

    def forward(self, x, src_key_padding_mask=None):
        """Processes the input tensor x and returns an output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input features. Per the class example, (batch, time, features).
        src_key_padding_mask : torch.Tensor, optional
            Padding mask forwarded unchanged to the encoder.

        Returns
        -------
        torch.Tensor
            ``output_activation(output_layer(encoder_output))``; its last
            dimension is ``output_size``.
        """
        # Run the optional pre-encoder first so the causal mask is built from
        # the sequence the encoder actually receives. If the pre-encoder
        # changes the number of time steps, a mask built from the raw input
        # would not match the encoder input.
        if self.custom_emb_module is not None:
            x = self.custom_emb_module(x)

        # Stored on the module (not as a local) to stay backward compatible
        # with code that may inspect ``attn_mask`` after a forward pass.
        if self.causal:
            self.attn_mask = get_lookahead_mask(x)
        else:
            self.attn_mask = None

        encoder_output, _ = self.encoder(
            src=x,
            src_mask=self.attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

        output = self.output_layer(encoder_output)
        output = self.output_activation(output)
        return output