Source code for speechbrain.nnet.activations

"""Library implementing activation functions.

Authors
 * Mirco Ravanelli 2020
 * Jianyuan Zhong 2020
"""

import torch
import torch.nn.functional as F

from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


class Softmax(torch.nn.Module):
    """Computes the softmax of a 2d, 3d, or 4d input tensor.

    Arguments
    ---------
    apply_log : bool
        Whether to apply the log after the softmax (i.e., use log-softmax).
    dim : int
        The dimension along which the softmax is applied.
    reshape : bool
        Whether to apply reshaping (True by default).
    dtype : torch.dtype
        The dtype of the output tensor.

    Example
    -------
    >>> classifier = Softmax()
    >>> inputs = torch.rand(10, 50, 40)
    >>> output = classifier(inputs)
    >>> output.shape
    torch.Size([10, 50, 40])
    """

    def __init__(
        self, apply_log=False, dim=-1, reshape=True, dtype=torch.float32
    ):
        super().__init__()

        if apply_log:
            self.act = F.log_softmax
        else:
            self.act = F.softmax

        self.dim = dim
        self.reshape = reshape
        self.dtype = dtype
    def forward(self, x):
        """Returns the softmax of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        x_act : torch.Tensor
            The softmax outputs.
        """
        # Reshaping the tensors
        dims = x.shape

        if self.reshape:
            if len(dims) == 3:
                x = x.reshape(dims[0] * dims[1], dims[2])

            if len(dims) == 4:
                x = x.reshape(dims[0] * dims[1], dims[2], dims[3])

        x_act = self.act(x, dim=self.dim, dtype=self.dtype)

        # Retrieving the original shape format
        if self.reshape:
            if len(dims) == 3:
                x_act = x_act.reshape(dims[0], dims[1], dims[2])

            if len(dims) == 4:
                x_act = x_act.reshape(dims[0], dims[1], dims[2], dims[3])

        return x_act
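
# A minimal usage sketch (illustrative, not part of the original module):
# it shows how ``reshape=True`` flattens a 4-D input so the softmax runs on a
# 3-D view before the original shape is restored, and how ``apply_log=True``
# switches the activation to ``F.log_softmax``. The helper name
# ``_demo_softmax`` is hypothetical.
def _demo_softmax():
    log_classifier = Softmax(apply_log=True, dim=-1)
    inputs = torch.rand(8, 10, 5, 4)  # e.g. batch, time, channel1, channel2
    log_probs = log_classifier(inputs)  # same shape as the input
    assert log_probs.shape == inputs.shape
    # log-softmax outputs exponentiate back to a distribution over dim=-1
    assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(8, 10, 5))
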
class GumbelSoftmax(torch.nn.Module):
    """Samples from the Gumbel-Softmax distribution and optionally discretizes.

    Reference: https://arxiv.org/abs/1611.00712, https://arxiv.org/abs/1611.01144

    Arguments
    ---------
    tau : float
        Non-negative scalar temperature.
    hard : bool
        If True, the returned samples are discretized as one-hot vectors,
        but are differentiated as if they were the soft samples in autograd.
    apply_log : bool
        If True, returns the log of the softmax outputs.

    Example
    -------
    >>> x = torch.randn((8, 40, 120))
    >>> act = GumbelSoftmax(0.8, True)
    >>> x = act(x)
    """

    def __init__(self, tau, hard=False, apply_log=False):
        super().__init__()
        self.tau = tau
        self.hard = hard
        self.apply_log = apply_log
    def forward(self, x):
        """Returns the Gumbel softmax of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        The Gumbel softmax output.
        """
        if self.apply_log:
            return torch.log(F.gumbel_softmax(x, tau=self.tau, hard=self.hard))
        return F.gumbel_softmax(x, tau=self.tau, hard=self.hard)
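
# A minimal usage sketch (illustrative, not part of the original module):
# with ``hard=True`` each vector along the last dimension is a one-hot
# sample, so it sums to exactly 1, while gradients still flow through the
# soft sample (straight-through estimator). The helper name ``_demo_gumbel``
# is hypothetical.
def _demo_gumbel():
    act = GumbelSoftmax(tau=0.8, hard=True)
    logits = torch.randn(8, 40, 120, requires_grad=True)
    one_hot = act(logits)
    # exactly one active entry per vector on the last dimension
    assert torch.all(one_hot.sum(dim=-1) == 1.0)
    # the hard samples remain differentiable w.r.t. the logits
    one_hot.sum().backward()
    assert logits.grad is not None
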
class Swish(torch.nn.Module):
    """The class implements the Swish activation function from
    https://arxiv.org/pdf/2005.03191.pdf given input x.

    Swish(x) = x / (1 + exp(-beta * x))

    Arguments
    ---------
    beta : float
        Beta value.

    Example
    -------
    >>> x = torch.randn((8, 40, 120))
    >>> act = Swish()
    >>> x = act(x)
    """

    def __init__(self, beta: float = 1.0):
        super().__init__()
        self.beta = beta
        self.silu = torch.nn.SiLU()
    def forward(self, x):
        """Returns the Swished input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        The swished output.
        """
        if self.beta != 1:
            # slow path: scale the input by beta before applying SiLU
            x = x * self.beta
        return self.silu(x)
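
# A minimal numerical check (illustrative, not part of the original module):
# because the forward pass scales the input before SiLU, for beta != 1 the
# output equals (beta * x) * sigmoid(beta * x). The helper name
# ``_demo_swish`` is hypothetical.
def _demo_swish():
    beta = 2.0
    act = Swish(beta=beta)
    x = torch.randn(8, 40, 120)
    expected = (beta * x) * torch.sigmoid(beta * x)
    assert torch.allclose(act(x), expected)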