# Source code for speechbrain.nnet.linear

"""Library implementing linear transformation.

Authors
 * Mirco Ravanelli 2020
 * Davide Borra 2021
"""

import torch
import logging
import torch.nn as nn

# Module-level logger; not referenced by the Linear class defined below.
logger = logging.getLogger(__name__)


class Linear(torch.nn.Module):
    """Computes a linear transformation y = wx + b.

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e., the dimensionality of the
        output).
    input_shape : tuple
        It is the shape of the input tensor. Only consulted when
        ``input_size`` is None: the last dimension gives the input size, or,
        when ``combine_dims`` is True and the shape is 4D, the product of the
        3rd and 4th dimensions.
    input_size : int
        Size of the input tensor. Takes precedence over ``input_shape``.
    bias : bool
        If True, the additive bias b is adopted.
    max_norm : float
        Weight max-norm. If not None, at every forward call each output
        neuron's weight vector is rescaled so that its L2 norm does not
        exceed this value.
    combine_dims : bool
        If True and the input is 4D, combine 3rd and 4th dimensions of input.

    Raises
    ------
    ValueError
        If neither ``input_shape`` nor ``input_size`` is provided.

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
    """

    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        max_norm=None,
        combine_dims=False,
    ):
        super().__init__()
        self.max_norm = max_norm
        self.combine_dims = combine_dims

        if input_shape is None and input_size is None:
            raise ValueError("Expected one of input_shape or input_size")

        if input_size is None:
            input_size = input_shape[-1]
            # forward() flattens a 4D input into (d0, d1, d2 * d3) when
            # combine_dims is set, so the layer must expect d2 * d3 features.
            if len(input_shape) == 4 and self.combine_dims:
                input_size = input_shape[2] * input_shape[3]

        # Weights are initialized following pytorch approach
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Returns the linear transformation of input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly. Its last dimension must match the
            layer's input size; a 4D input is first reshaped to
            ``(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])`` when
            ``combine_dims`` is True.

        Returns
        -------
        torch.Tensor
            The transformed tensor, whose last dimension is ``n_neurons``.
        """
        if x.ndim == 4 and self.combine_dims:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        if self.max_norm is not None:
            # nn.Linear stores weight as (out_features, in_features); renorm
            # over dim=0 caps the L2 norm of each output neuron's weight row.
            # Writing through .data keeps the constraint out of autograd.
            self.w.weight.data = torch.renorm(
                self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
            )

        wx = self.w(x)
        return wx