speechbrain.nnet.activations module

Library implementing activation functions.

Authors
  • Mirco Ravanelli 2020

  • Jianyuan Zhong 2020

Summary

Classes:

GumbelSoftmax

Samples from the Gumbel-Softmax distribution and optionally discretizes.

Softmax

Computes the softmax of a 2d, 3d, or 4d input tensor.

Swish

The class implements the Swish activation function from https://arxiv.org/pdf/2005.03191.pdf

Reference

class speechbrain.nnet.activations.Softmax(apply_log=False, dim=-1)

Bases: Module

Computes the softmax of a 2d, 3d, or 4d input tensor.

Parameters:
  • apply_log (bool) – If True, the log of the softmax (log-softmax) is returned instead of the softmax.

  • dim (int) – The dimension along which softmax is applied.

Example

>>> import torch
>>> from speechbrain.nnet.activations import Softmax
>>> classifier = Softmax()
>>> inputs = torch.rand(10, 50, 40)
>>> output = classifier(inputs)
>>> output.shape
torch.Size([10, 50, 40])

forward(x)

Returns the softmax of the input tensor.

Parameters:

x (torch.Tensor) – Input tensor.
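
As a rough illustration of what the two modes compute, the following pure-Python sketch evaluates softmax and log-softmax over a flat list of numbers. It is not the library's implementation: the helper names `softmax` and `log_softmax` are hypothetical, and it assumes that `apply_log=True` corresponds to taking the log of the softmax output.

```python
import math


def softmax(values):
    """Softmax over a flat list of floats (illustrative helper, not SpeechBrain code)."""
    # Subtracting the maximum keeps exp() from overflowing; the result is unchanged.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(values):
    """Log of the softmax, computed without forming the probabilities first."""
    m = max(values)
    log_total = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_total for v in values]


logits = [1.0, 2.0, 3.0]
probs = softmax(logits)          # non-negative values that sum to 1
log_probs = log_softmax(logits)  # element-wise log of probs
```

Exponentiating each entry of `log_probs` recovers `probs`, which is the relationship the sketch assumes between the two settings of `apply_log`.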

training: bool

class speechbrain.nnet.activations.GumbelSoftmax(tau, hard=False, apply_log=False)

Bases: Module

Samples from the Gumbel-Softmax distribution and optionally discretizes.

Reference: https://arxiv.org/abs/1611.00712, https://arxiv.org/abs/1611.01144

Parameters:
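  • tau (float) – Positive scalar temperature. Lower values push the samples toward one-hot vectors; higher values make them more uniform.

  • hard (bool) – If True, the returned samples are discretized as one-hot vectors, but gradients are computed in autograd as if the soft samples had been returned.

  • apply_log (bool) – If True, the log of the samples is returned (default: False). The constructor takes no dim argument; the softmax is computed along the last dimension.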

Example

>>> import torch
>>> from speechbrain.nnet.activations import GumbelSoftmax
>>> x = torch.randn((8, 40, 120))
>>> act = GumbelSoftmax(tau=0.8, hard=True)
>>> x = act(x)

forward(x)

Returns the Gumbel softmax of the input tensor.

Parameters:

x (torch.Tensor) – Input tensor.
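
The sampling procedure described in the referenced papers can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's code: `gumbel_softmax_sample` is a hypothetical helper that works on a flat list of logits, and it does not model the straight-through gradient that `hard=True` implies in autograd.

```python
import math
import random


def gumbel_softmax_sample(logits, tau, hard=False, rng=random):
    """Draw one Gumbel-Softmax sample from a flat list of logits (illustrative only)."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1).
        # Clamp U away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        # The temperature tau rescales the perturbed logits.
        noisy.append((logit + gumbel) / tau)

    # Numerically stable softmax over the perturbed logits.
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft

    # Discretize: one-hot vector at the position of the largest soft value.
    k = max(range(len(soft)), key=soft.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(soft))]


random.seed(0)
soft_sample = gumbel_softmax_sample([0.5, 1.0, -0.3], tau=0.8)
hard_sample = gumbel_softmax_sample([0.5, 1.0, -0.3], tau=0.8, hard=True)
```

Here `soft_sample` is a probability vector and `hard_sample` is a one-hot vector; which index is selected depends on the random draw.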

training: bool

class speechbrain.nnet.activations.Swish(beta=1)

Bases: Module

The class implements the Swish activation function from https://arxiv.org/pdf/2005.03191.pdf

Given an input x, Swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)).

Parameters:

beta (float) – Scaling factor applied to x inside the sigmoid (default: 1).

Example

>>> import torch
>>> from speechbrain.nnet.activations import Swish
>>> x = torch.randn((8, 40, 120))
>>> act = Swish()
>>> x = act(x)

forward(x)

Returns the Swish activation of the input tensor.

Parameters:

x (torch.Tensor) – Input tensor.
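
For reference, the Swish formula x * sigmoid(beta * x) can be checked on scalars with a short pure-Python sketch. The helpers `sigmoid` and `swish` below are hypothetical names used only for illustration; they are not the library's implementation, which operates on tensors.

```python
import math


def sigmoid(z):
    """Logistic sigmoid, written so that exp() never receives a large positive argument."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def swish(x, beta=1.0):
    """Swish(x) = x * sigmoid(beta * x), equivalently x / (1 + exp(-beta * x))."""
    return x * sigmoid(beta * x)


at_zero = swish(0.0)           # 0.0, since the input multiplies the sigmoid
at_one = swish(1.0)            # sigmoid(1), roughly 0.731
large_positive = swish(50.0)   # close to the input itself
large_negative = swish(-50.0)  # close to zero
```

For large positive inputs the function approaches the identity, and for large negative inputs it approaches zero. Larger values of beta make the transition between the two regimes sharper.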

training: bool