"""Library implementing activation functions.
Authors
* Mirco Ravanelli 2020
* Jianyuan Zhong 2020
"""
import torch
import logging
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class Softmax(torch.nn.Module):
    """Computes the softmax of a 2d, 3d, or 4d input tensor.

    Arguments
    ---------
    apply_log : bool
        Whether to return the log of the softmax (i.e., use torch.nn.LogSoftmax).
    dim : int
        The dimension along which the softmax is applied.
    reshape : bool
        Whether to apply reshaping (True by default).

    Example
    -------
    >>> classifier = Softmax()
    >>> inputs = torch.rand(10, 50, 40)
    >>> output = classifier(inputs)
    >>> output.shape
    torch.Size([10, 50, 40])
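    >>> # Illustrative: with apply_log=True, log-probabilities are returned
    >>> # (the output shape is unchanged)
    >>> log_classifier = Softmax(apply_log=True)
    >>> log_probs = log_classifier(inputs)
    >>> log_probs.shape
    torch.Size([10, 50, 40])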
"""

    def __init__(self, apply_log=False, dim=-1, reshape=True):
        super().__init__()

        if apply_log:
            self.act = torch.nn.LogSoftmax(dim=dim)
        else:
            self.act = torch.nn.Softmax(dim=dim)

        self.reshape = reshape
    def forward(self, x):
        """Returns the softmax of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        """
        # Reshaping the tensors
        dims = x.shape

        if self.reshape:
            if len(dims) == 3:
                x = x.reshape(dims[0] * dims[1], dims[2])

            if len(dims) == 4:
                x = x.reshape(dims[0] * dims[1], dims[2], dims[3])

        x_act = self.act(x)

        # Retrieving the original shape format
        if self.reshape:
            if len(dims) == 3:
                x_act = x_act.reshape(dims[0], dims[1], dims[2])

            if len(dims) == 4:
                x_act = x_act.reshape(dims[0], dims[1], dims[2], dims[3])

        return x_act


class GumbelSoftmax(torch.nn.Module):
    """Samples from the Gumbel-Softmax distribution and optionally discretizes.

    The softmax is computed along the last dimension.

    Reference: https://arxiv.org/abs/1611.00712, https://arxiv.org/abs/1611.01144

    Arguments
    ---------
    tau : float
        Non-negative scalar temperature.
    hard : bool
        If True, the returned samples are discretized as one-hot vectors,
        but are differentiated as if they were the soft samples in autograd.
    apply_log : bool
        If True, returns the log of the Gumbel-Softmax samples.

    Example
    -------
    >>> x = torch.randn((8, 40, 120))
    >>> act = GumbelSoftmax(0.8, True)
    >>> x = act(x)
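    >>> # Illustrative: the output keeps the input shape; with hard=True, each
    >>> # vector along the last dimension is a one-hot sample
    >>> x.shape
    torch.Size([8, 40, 120])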
"""

    def __init__(self, tau, hard=False, apply_log=False):
        super().__init__()
        self.tau = tau
        self.hard = hard
        self.apply_log = apply_log
    def forward(self, x):
        """Returns the Gumbel softmax of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        """
        if self.apply_log:
            return torch.log(F.gumbel_softmax(x, tau=self.tau, hard=self.hard))
        return F.gumbel_softmax(x, tau=self.tau, hard=self.hard)


class Swish(torch.nn.Module):
    """The class implements the Swish activation function from
    https://arxiv.org/pdf/2005.03191.pdf

    Given input x, Swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)).

    Arguments
    ---------
    beta : float
        Beta value.

    Example
    -------
    >>> x = torch.randn((8, 40, 120))
    >>> act = Swish()
    >>> x = act(x)
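    >>> # Illustrative check: for the default beta=1, Swish coincides with
    >>> # torch.nn.functional.silu
    >>> torch.allclose(Swish()(torch.ones(3)), torch.nn.functional.silu(torch.ones(3)))
    True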
"""

    def __init__(self, beta=1):
        super().__init__()
        self.beta = beta
        self.sigmoid = torch.nn.Sigmoid()
    def forward(self, x):
        """Returns the Swished input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        """
        return x * self.sigmoid(self.beta * x)