speechbrain.nnet.linear module
Library implementing linear transformations.
- Authors
Mirco Ravanelli 2020
Davide Borra 2021
Summary
Classes:
Linear – Computes a linear transformation y = wx + b.
LinearWithConstraint – Computes a linear transformation y = wx + b with a kernel max-norm constraint.
Reference
- class speechbrain.nnet.linear.Linear(n_neurons, input_shape=None, input_size=None, bias=True, combine_dims=False)
Bases: Module
Computes a linear transformation y = wx + b.
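The formula can be illustrated without any dependencies. The sketch below computes y = wx + b for a single input vector, storing w as n_neurons rows of length input_size (the same layout torch.nn.Linear uses for its weight). The helper `linear_sketch` is hypothetical and is not SpeechBrain's implementation.

```python
from typing import List, Optional


def linear_sketch(x: List[float], w: List[List[float]],
                  b: Optional[List[float]] = None) -> List[float]:
    """Hypothetical helper: y_j = sum_i w[j][i] * x[i] + b[j]."""
    # Each row of w holds the weights of one output neuron.
    y = [sum(w_ji * x_i for w_ji, x_i in zip(row, x)) for row in w]
    if b is not None:
        # The additive bias is applied only when given (cf. bias=True).
        y = [y_j + b_j for y_j, b_j in zip(y, b)]
    return y


# Two output neurons (n_neurons=2), three input features (input_size=3).
w = [[1.0, 0.0, 2.0],
     [0.5, -1.0, 0.0]]
b = [1.0, -2.0]
print(linear_sketch([1.0, 2.0, 3.0], w, b))  # [8.0, -3.5]
```

Dropping b (the analogue of bias=False) returns just wx, here [7.0, -1.5].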
- Parameters:
n_neurons (int) – Number of output neurons (i.e., the dimensionality of the output).
input_shape (tuple) – Shape of the input tensor. Provide either input_shape or input_size.
input_size (int) – Number of input features (size of the last dimension of the input tensor).
bias (bool) – If True, the additive bias b is used.
combine_dims (bool) – If True and the input is 4D, the 3rd and 4th dimensions of the input are merged into one before the transformation.
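The effect of combine_dims and of the two ways of giving the input size can be sketched as plain shape arithmetic. This assumes, as the parameter descriptions suggest, that a 4D input of shape (batch, time, d3, d4) is treated as (batch, time, d3 * d4), and that only the last (feature) dimension is changed by the transformation. Both helper names, and raising ValueError when neither input_shape nor input_size is given, are assumptions of this sketch.

```python
from typing import Optional, Tuple


def infer_input_size(input_shape: Optional[Tuple[int, ...]] = None,
                     input_size: Optional[int] = None,
                     combine_dims: bool = False) -> int:
    """Hypothetical helper: number of input features seen by the layer."""
    if input_size is not None:
        return input_size
    if input_shape is None:
        raise ValueError("Expected one of input_shape or input_size")
    if combine_dims and len(input_shape) == 4:
        # 3rd and 4th dimensions are merged into a single feature axis.
        return input_shape[2] * input_shape[3]
    # Otherwise the features live in the last dimension.
    return input_shape[-1]


def output_shape(input_shape: Tuple[int, ...], n_neurons: int,
                 combine_dims: bool = False) -> Tuple[int, ...]:
    """Hypothetical helper: shape of y = wx + b for a given input shape."""
    if combine_dims and len(input_shape) == 4:
        input_shape = (input_shape[0], input_shape[1],
                       input_shape[2] * input_shape[3])
    # Leading dimensions are kept; the feature dimension becomes n_neurons.
    return tuple(input_shape[:-1]) + (n_neurons,)


print(infer_input_size((10, 50, 40)))                         # 40
print(infer_input_size((8, 50, 20, 2), combine_dims=True))    # 40
print(output_shape((8, 50, 20, 2), 100, combine_dims=True))   # (8, 50, 100)
```

With combine_dims=False, output_shape((10, 50, 40), 100) gives (10, 50, 100), matching the shape in the Linear example.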
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> inputs = torch.rand(10, 50, 40)
>>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
>>> output = lin_t(inputs)
>>> output.shape
torch.Size([10, 50, 100])
- forward(x)
Returns the linear transformation of the input tensor.
- Parameters:
x (torch.Tensor) – Input to transform linearly.
- class speechbrain.nnet.linear.LinearWithConstraint(*args, max_norm=1, **kwargs)
Bases: Linear
Computes a linear transformation y = wx + b with a kernel max-norm constraint. This corresponds to setting an upper bound on the norm of the kernel (the weights w).
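A minimal sketch of a max-norm constraint in plain Python: any weight vector whose L2 norm exceeds max_norm is rescaled so that its norm equals max_norm, while vectors already within the bound are left untouched. Treating each row of w (the weights of one output neuron) as the constrained vector, and the helper name `renorm_rows`, are assumptions for illustration; PyTorch's torch.renorm performs this kind of rescaling along a chosen dimension.

```python
import math
from typing import List


def renorm_rows(w: List[List[float]], max_norm: float = 1.0) -> List[List[float]]:
    """Hypothetical helper: cap the L2 norm of each row of w at max_norm."""
    out = []
    for row in w:
        norm = math.sqrt(sum(v * v for v in row))
        if norm > max_norm:
            # Scale the row down so that its norm becomes max_norm.
            scale = max_norm / norm
            out.append([v * scale for v in row])
        else:
            # Rows already within the bound are copied unchanged.
            out.append(list(row))
    return out


w = [[3.0, 4.0],   # norm 5.0, rescaled to norm 1.0
     [0.3, 0.4]]   # norm 0.5, kept as is
print(renorm_rows(w, max_norm=1.0))
```

Rescaling rather than clipping individual entries keeps the direction of each weight vector and only shrinks its length.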
- Parameters:
n_neurons (int) – Number of output neurons (i.e., the dimensionality of the output).
input_shape (tuple) – Shape of the input tensor. Provide either input_shape or input_size.
input_size (int) – Number of input features (size of the last dimension of the input tensor).
bias (bool) – If True, the additive bias b is used.
combine_dims (bool) – If True and the input is 4D, the 3rd and 4th dimensions of the input are merged into one before the transformation.
max_norm (float) – Maximum allowed norm of the kernel (default: 1).
Example
>>> import torch
>>> from speechbrain.nnet.linear import LinearWithConstraint
>>> inputs = torch.rand(100,)
>>> max_norm = 1.
>>> lin_t_constrained = LinearWithConstraint(input_size=inputs.shape[0], n_neurons=2, max_norm=max_norm)
>>> output = lin_t_constrained(inputs)
>>> torch.any(torch.norm(lin_t_constrained.w.weight.data, p=2, dim=0) > max_norm)
tensor(False)
- forward(x)
Returns the linear transformation of the input tensor.
- Parameters:
x (torch.Tensor) – Input to transform linearly.