"""Library implementing dropout.
Authors
* Mirco Ravanelli 2020
"""
import torch # noqa: F401
import torch.nn as nn
from speechbrain.utils.logger import get_logger
# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)
[docs]
class Dropout2d(nn.Module):
    """This function implements dropout 2d. It randomly puts zeros on
    entire channels.

    The input is laid out time-first, i.e. (batch, time, channel) or
    (batch, time, channel1, channel2). Internally the first channel axis is
    moved to dim 1 and time to the last dim before ``torch.nn.Dropout2d`` is
    applied, so a dropped channel is zeroed for every time step (and, for
    4d inputs, for every index of channel2).

    Arguments
    ---------
    drop_rate : float
        It is the dropout factor (between 0 and 1).
    inplace : bool
        If True, it uses inplace operations.

    Example
    -------
    >>> drop = Dropout2d(drop_rate=0.5)
    >>> inputs = torch.rand(10, 50, 40)
    >>> output=drop(inputs)
    >>> output.shape
    torch.Size([10, 50, 40])
    """

    def __init__(self, drop_rate, inplace=False):
        super().__init__()
        self.drop_rate = drop_rate
        self.inplace = inplace
        self.drop = nn.Dropout2d(p=self.drop_rate, inplace=self.inplace)

    def forward(self, x):
        """Applies dropout 2d to the input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel1, channel2)
            Input to apply dropout to. 4d tensors are expected; 3d tensors
            of shape (batch, time, channel) are handled as well.

        Returns
        -------
        x_drop : torch.Tensor
            The tensor with channels zeroed out, in the same layout and
            shape as the input.
        """
        # nn.Dropout2d treats dim 1 as the channel axis, so bring channel1
        # to dim 1 and push time to the last position.
        x = x.transpose(1, 2).transpose(2, -1)

        # A 3d input is now (batch, channel, time). torch's dropout2d only
        # accepts that through a warning-emitting legacy path whose meaning
        # is announced to change, so add a trailing singleton axis and make
        # it a regular 4d (batch, channel, time, 1) call instead. The mask
        # is still drawn per (batch, channel), so the result is equivalent.
        added_axis = x.dim() == 3
        if added_axis:
            x = x.unsqueeze(-1)

        x_drop = self.drop(x)

        if added_axis:
            x_drop = x_drop.squeeze(-1)

        # Undo the two transpositions to restore the time-first layout.
        x_drop = x_drop.transpose(-1, 1).transpose(2, -1)
        return x_drop