speechbrain.lobes.models.GatedNN module
Gated Neural Network variant of VanillaNN for simple feed-forward tests.
Summary
Classes:
GatedNN | A simple stacked Gated Neural Network for feed-forward modeling.
GatedNNBlock | Single gated feed-forward block used in GatedNN.
Reference
- class speechbrain.lobes.models.GatedNN.GatedNNBlock(n_neurons, input_shape=None, input_size=None, activation=<class 'torch.nn.modules.activation.GELU'>, bias=False, combine_dims=False)
Bases: Module
Single gated feed-forward block used in GatedNN.
This block applies two parallel linear projections to the input and combines them with an element-wise product after passing one branch through a non-linear activation. A final linear layer projects the gated representation back to the original input dimensionality.
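The computation described above can be sketched functionally in PyTorch. This is a minimal illustration, not SpeechBrain's implementation: the function `gated_ffn` and the weight names `w_gate`, `w_value`, `w_out` are hypothetical, biases are omitted (as with the default `bias=False`), and which branch receives the activation is an assumption.

```python
import torch
import torch.nn.functional as F


def gated_ffn(x, w_gate, w_value, w_out, activation=F.gelu):
    """Hypothetical functional sketch of one gated feed-forward block (no bias).

    x:       [..., input_size]
    w_gate:  [n_neurons, input_size]  (branch passed through the activation)
    w_value: [n_neurons, input_size]  (plain linear branch)
    w_out:   [input_size, n_neurons]  (projection back to input_size)
    """
    # Two parallel projections into the hidden (gated) space.
    gate = activation(F.linear(x, w_gate))
    value = F.linear(x, w_value)
    # Element-wise gating, then project back to the input dimensionality.
    return F.linear(gate * value, w_out)


input_size, n_neurons = 60, 256
x = torch.rand(10, 120, input_size)
w_gate = torch.randn(n_neurons, input_size)
w_value = torch.randn(n_neurons, input_size)
w_out = torch.randn(input_size, n_neurons)
y = gated_ffn(x, w_gate, w_value, w_out)
print(y.shape)  # torch.Size([10, 120, 60])
```

Because the hidden size `n_neurons` only appears between the two projections, the output keeps the input's last dimension regardless of how wide the gated representation is.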
- Parameters:
n_neurons (int) – Number of neurons in the hidden (gated) representation.
input_shape (tuple or None) – Shape of the input tensor. Used to infer input_size when it is not given.
input_size (int or None) – Flattened size of the last (or spatially combined) input dimension. One of input_shape or input_size must be provided.
activation (torch.nn.Module or callable) – Activation class used in the gated branch (default: torch.nn.GELU).
bias (bool) – If True, use bias terms in the linear layers.
combine_dims (bool) – If True and the input is 4D, combines the last two dimensions before applying the linear layers.
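The interaction between `input_shape`, `input_size`, and `combine_dims` can be illustrated with two small helpers. Both `infer_input_size` and `combine_last_two_dims` are hypothetical names written for this sketch; they assume a 4D input laid out as [batch, time, channels, features], and whether the real block restores the 4D shape after the final projection is not stated here.

```python
import torch


def infer_input_size(input_shape, combine_dims=False):
    """Hypothetical helper: feature size the linear layers would receive."""
    if combine_dims and len(input_shape) == 4:
        # Merge the last two dimensions, e.g. [B, T, C, F] -> C * F features.
        return input_shape[-2] * input_shape[-1]
    # Otherwise only the last dimension is projected.
    return input_shape[-1]


def combine_last_two_dims(x):
    """Hypothetical helper: [B, T, C, F] -> [B, T, C * F]; other ranks unchanged."""
    if x.dim() != 4:
        return x
    batch, time, channels, features = x.shape
    return x.reshape(batch, time, channels * features)


x = torch.rand(8, 50, 4, 20)
print(infer_input_size(x.shape, combine_dims=True))  # 80
print(infer_input_size(x.shape))                     # 20
print(combine_last_two_dims(x).shape)                # torch.Size([8, 50, 80])
```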
- forward(x)
Returns the output of the GatedNNBlock.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
x – Output tensor.
- Return type:
torch.Tensor
- class speechbrain.lobes.models.GatedNN.GatedNN(input_shape, activation=<class 'torch.nn.modules.activation.GELU'>, blocks=2, neurons=512, bias=False, combine_dims=False)
Bases: Sequential
A simple stacked Gated Neural Network for feed-forward modeling.
This model stacks multiple GatedNNBlock modules on top of each other, keeping the same input and output dimensionality while increasing representational power through gated non-linear transformations.
- Parameters:
input_shape (tuple) – Expected shape of the input tensors.
activation (torch.nn.Module or callable) – Activation class used inside each gated block (default: torch.nn.GELU).
blocks (int) – Number of stacked gated blocks.
neurons (int) – Number of neurons in the hidden (gated) representation of each block.
bias (bool) – If True, use bias terms in the linear layers.
combine_dims (bool) – If True and the input is 4D, combines the last two dimensions before applying the linear layers in each block.
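Since every gated block maps its input back to the same dimensionality, blocks can be chained directly. The sketch below uses plain `torch.nn.Sequential` rather than SpeechBrain's own `Sequential` container, and `GatedBlockSketch` / `build_stack` are hypothetical stand-ins assuming exactly three bias-free linear layers per block; under that assumption each block holds 3 × input_size × neurons parameters.

```python
import torch
import torch.nn as nn


class GatedBlockSketch(nn.Module):
    """Hypothetical stand-in for GatedNNBlock (not SpeechBrain's code)."""

    def __init__(self, input_size, n_neurons, activation=nn.GELU, bias=False):
        super().__init__()
        self.gate_proj = nn.Linear(input_size, n_neurons, bias=bias)
        self.value_proj = nn.Linear(input_size, n_neurons, bias=bias)
        self.act = activation()  # activation is passed as a class
        self.out_proj = nn.Linear(n_neurons, input_size, bias=bias)

    def forward(self, x):
        # Activated branch times linear branch, projected back to input_size.
        return self.out_proj(self.act(self.gate_proj(x)) * self.value_proj(x))


def build_stack(input_size, blocks=2, neurons=512, bias=False):
    # Each block maps input_size -> input_size, so they chain without adapters.
    return nn.Sequential(
        *[GatedBlockSketch(input_size, neurons, bias=bias) for _ in range(blocks)]
    )


model = build_stack(input_size=60, blocks=2, neurons=512)
outputs = model(torch.rand(10, 120, 60))
n_params = sum(p.numel() for p in model.parameters())
print(outputs.shape)  # torch.Size([10, 120, 60])
print(n_params)       # 184320 = 2 blocks * 3 * 60 * 512
```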
Example
>>> import torch
>>> from speechbrain.lobes.models.GatedNN import GatedNN
>>> inputs = torch.rand([10, 120, 60])
>>> model = GatedNN(input_shape=inputs.shape, blocks=2, neurons=512)
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([10, 120, 60])