Source code for speechbrain.nnet.dropout

"""Library implementing dropout.

Authors
 * Mirco Ravanelli 2020
"""
import torch  # noqa: F401
import logging
import torch.nn as nn

logger = logging.getLogger(__name__)


[docs] class Dropout2d(nn.Module): """This function implements dropout 2d. It randomly put zeros on entire channels. Arguments --------- dropout_rate : float It is the dropout factor (between 0 and 1). inplace : bool If True, it uses inplace operations. Example ------- >>> drop = Dropout2d(drop_rate=0.5) >>> inputs = torch.rand(10, 50, 40) >>> output=drop(inputs) >>> output.shape torch.Size([10, 50, 40]) """ def __init__( self, drop_rate, inplace=False, ): super().__init__() self.drop_rate = drop_rate self.inplace = inplace self.drop = nn.Dropout2d(p=self.drop_rate, inplace=self.inplace)
[docs] def forward(self, x): """Applies dropout 2d to the input tensor. Arguments --------- x : torch.Tensor (batch, time, channel1, channel2) input to normalize. 4d tensors are expected. """ # time must be the last x = x.transpose(1, 2).transpose(2, -1) x_drop = self.drop(x) x_drop = x_drop.transpose(-1, 1).transpose(2, -1) return x_drop