speechbrain.lobes.models.GatedNN module

Gated Neural Network variant of VanillaNN for simple feed-forward tests.

Authors

  • Adel Moumen 2025

Summary

Classes:

GatedNN

A simple stacked Gated Neural Network for feed-forward modeling.

GatedNNBlock

Single gated feed-forward block used in GatedNN.

Reference

class speechbrain.lobes.models.GatedNN.GatedNNBlock(n_neurons, input_shape=None, input_size=None, activation=<class 'torch.nn.modules.activation.GELU'>, bias=False, combine_dims=False)

Bases: Module

Single gated feed-forward block used in GatedNN.

This block applies two parallel linear projections to the input and combines them with an element-wise product after passing one branch through a non-linear activation. A final linear layer projects the gated representation back to the original input dimensionality.
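The gating described above can be sketched in plain PyTorch. This is a minimal illustration, not the SpeechBrain implementation: the class name GatedBlockSketch and the attribute names gate_proj, value_proj, and out_proj are hypothetical, and the real block may differ in layer types and initialization.

```python
import torch
import torch.nn as nn


class GatedBlockSketch(nn.Module):
    """Hypothetical stand-in for a gated feed-forward block (assumed layout)."""

    def __init__(self, input_size, n_neurons, activation=nn.GELU, bias=False):
        super().__init__()
        # Two parallel projections from the input size to the hidden size.
        self.gate_proj = nn.Linear(input_size, n_neurons, bias=bias)
        self.value_proj = nn.Linear(input_size, n_neurons, bias=bias)
        # Final projection back to the original input dimensionality.
        self.out_proj = nn.Linear(n_neurons, input_size, bias=bias)
        # The activation is passed as a class and instantiated here.
        self.activation = activation()

    def forward(self, x):
        # Only one branch goes through the non-linearity; the two branches
        # are then combined with an element-wise product.
        gated = self.activation(self.gate_proj(x)) * self.value_proj(x)
        return self.out_proj(gated)


block = GatedBlockSketch(input_size=60, n_neurons=512)
x = torch.rand(10, 120, 60)
y = block(x)  # same shape as x, since out_proj maps back to 60 features
```

Because the last layer maps the hidden representation back to the input size, the output has the same shape as the input, which is what makes such blocks stackable.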

Parameters:
  • n_neurons (int) – Number of neurons in the hidden (gated) representation.

  • input_shape (tuple or None) – Shape of the input tensor. Used to infer input_size when not given.

  • input_size (int or None) – Size of the last input dimension, or of the last two dimensions combined when combine_dims is True. At least one of input_shape or input_size must be provided.

  • activation (torch.nn.Module or callable) – Activation class used in the gated branch (default: torch.nn.GELU).

  • bias (bool) – If True, use bias terms in the linear layers.

  • combine_dims (bool) – If True and the input is 4D, combines the last two dimensions before applying the linear layers.
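How input_shape and combine_dims interact can be illustrated with two small helpers. Both infer_input_size and maybe_combine_dims are hypothetical names introduced for this sketch; it assumes the feature size is taken from the last dimension, or from the product of the last two dimensions for a 4D input with combine_dims=True.

```python
import torch


def infer_input_size(input_shape, combine_dims=False):
    """Hypothetical helper: feature size seen by the linear layers."""
    if combine_dims and len(input_shape) == 4:
        # Assumed behavior: the last two dimensions are merged into one.
        return input_shape[2] * input_shape[3]
    return input_shape[-1]


def maybe_combine_dims(x, combine_dims=False):
    """Hypothetical helper: merge the last two dims of a 4D tensor."""
    if combine_dims and x.dim() == 4:
        return x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
    return x


x4d = torch.rand(8, 100, 40, 2)
size = infer_input_size(x4d.shape, combine_dims=True)
flat = maybe_combine_dims(x4d, combine_dims=True)
size_3d = infer_input_size((10, 120, 60))
```

Under these assumptions a (8, 100, 40, 2) input is treated as (8, 100, 80) by the linear layers, while a 3D input is left unchanged.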

forward(x)

Applies the gated block to the input and returns the result.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

x – Output tensor.

Return type:

torch.Tensor

class speechbrain.lobes.models.GatedNN.GatedNN(input_shape, activation=<class 'torch.nn.modules.activation.GELU'>, blocks=2, neurons=512, bias=False, combine_dims=False)

Bases: Sequential

A simple stacked Gated Neural Network for feed-forward modeling.

This model stacks multiple GatedNNBlock modules on top of each other, keeping the same input and output dimensionality while increasing representational power through gated non-linear transformations.

Parameters:
  • input_shape (tuple) – Expected shape of the input tensors.

  • activation (torch.nn.Module or callable) – Activation class used inside each gated block (default: torch.nn.GELU).

  • blocks (int) – Number of stacked gated blocks.

  • neurons (int) – Number of neurons in the hidden (gated) representation of each block.

  • bias (bool) – If True, use bias terms in the linear layers.

  • combine_dims (bool) – If True and the input is 4D, combines the last two dimensions before applying the linear layers in each block.

Example

>>> import torch
>>> from speechbrain.lobes.models.GatedNN import GatedNN
>>> inputs = torch.rand([10, 120, 60])
>>> model = GatedNN(input_shape=inputs.shape, blocks=2, neurons=512)
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([10, 120, 60])
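The stacking behind this example can be approximated without SpeechBrain. The sketch below is an assumption-laden illustration: TinyGatedBlock and build_stack are hypothetical names, the block uses three bias-free torch.nn.Linear layers with a GELU gate, and the actual GatedNN may be built differently.

```python
import torch
import torch.nn as nn


class TinyGatedBlock(nn.Module):
    """Hypothetical shape-preserving gated block used only for this sketch."""

    def __init__(self, size, neurons):
        super().__init__()
        self.gate = nn.Linear(size, neurons, bias=False)
        self.value = nn.Linear(size, neurons, bias=False)
        self.out = nn.Linear(neurons, size, bias=False)

    def forward(self, x):
        # GELU-activated gate multiplied element-wise with the value branch.
        return self.out(nn.functional.gelu(self.gate(x)) * self.value(x))


def build_stack(input_shape, blocks=2, neurons=512):
    """Hypothetical builder: chain `blocks` identical gated blocks."""
    size = input_shape[-1]
    return nn.Sequential(*[TinyGatedBlock(size, neurons) for _ in range(blocks)])


inputs = torch.rand(10, 120, 60)
stack = build_stack(inputs.shape, blocks=2, neurons=512)
outputs = stack(inputs)

# Each block holds three 60 x 512 weight matrices in this sketch.
n_params = sum(p.numel() for p in stack.parameters())
```

Since every block maps 60 features back to 60 features, any number of blocks can be chained without changing the tensor shape. In this sketch the two-block stack holds 2 x 3 x 60 x 512 = 184320 weights; the real model's parameter count may differ.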