# Source code for speechbrain.nnet.quaternion_networks.q_normalization

"""Library implementing quaternion-valued normalization.

Authors
 * Titouan Parcollet 2020
"""

import torch
from torch.nn import Parameter


class QBatchNorm(torch.nn.Module):
    """This class implements the simplest form of a quaternion batchnorm as
    described in: "Quaternion Convolutional Neural Network for Color Image
    Classification and Forensics", Qilin Y. et al.

    Arguments
    ---------
    input_size : int
        Expected size of the dimension to be normalized.
    dim : int, optional
        It defines the axis that should be normalized. It usually corresponds
        to the channel dimension (default -1).
    gamma_init : float, optional
        Initial value used for the gamma scaling parameter (default 1.0).
    beta_param : bool, optional
        When set to True, the beta parameter of the BN is applied
        (default True).
    momentum : float, optional
        It defines the momentum as for the real-valued batch normalization
        (default 0.1).
    eps : float, optional
        Term used to stabilize operation (default 1e-4).
    track_running_stats : bool, optional
        Equivalent to the real-valued batch normalization parameter.
        When True, stats are tracked. When False, solely statistics computed
        over the batch are used (default True).

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40])
    >>> QBN = QBatchNorm(input_size=40)
    >>> out_tensor = QBN(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40])
    """

    def __init__(
        self,
        input_size,
        dim=-1,
        gamma_init=1.0,
        beta_param=True,
        momentum=0.1,
        eps=1e-4,
        track_running_stats=True,
    ):
        super(QBatchNorm, self).__init__()

        # Each quaternion is made of 4 real components (r, i, j, k), so the
        # number of quaternion channels is a quarter of the input size.
        self.num_features = input_size // 4
        self.gamma_init = gamma_init
        self.beta_param = beta_param
        self.momentum = momentum
        self.dim = dim
        self.eps = eps
        self.track_running_stats = track_running_stats

        # gamma scales per quaternion channel; beta shifts every real unit.
        self.gamma = Parameter(
            torch.full([self.num_features], self.gamma_init)
        )
        self.beta = Parameter(
            torch.zeros(self.num_features * 4), requires_grad=self.beta_param
        )

        # Instantiate moving statistics
        if track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros(self.num_features * 4)
            )
            self.register_buffer("running_var", torch.ones(self.num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
            self.register_parameter("num_batches_tracked", None)
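    # Layout note (inferred from the chunking in forward below): the input is
    # expected to store the four quaternion components concatenated along
    # `dim`, i.e. [r_1..r_n, i_1..i_n, j_1..j_n, k_1..k_n] with
    # n = input_size // 4, so that torch.chunk(x, 4, dim) recovers the
    # r, i, j, k parts. A single variance is kept per quaternion channel
    # (`num_features` values), while the mean and beta span all 4 * n units.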
    def forward(self, input):
        """Returns the normalized input tensor.

        Arguments
        ---------
        input : torch.Tensor (batch, time, [channels])
            Input to normalize. It can be 2d, 3d, 4d.
        """
        exponential_average_factor = 0.0

        # Entering training mode
        if self.training:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:
                    # Use cumulative moving average
                    exponential_average_factor = (
                        1.0 / self.num_batches_tracked.item()
                    )
                else:
                    # Use exponential moving average
                    exponential_average_factor = self.momentum

            # Get mean along batch axis
            mu = torch.mean(input, dim=0)

            # Get variance along batch axis: a single variance |q - mu|^2,
            # averaged over the batch, is shared by the four components of
            # each quaternion.
            delta = input - mu
            delta_r, delta_i, delta_j, delta_k = torch.chunk(
                delta, 4, dim=self.dim
            )
            quat_variance = torch.mean(
                (delta_r ** 2 + delta_i ** 2 + delta_j ** 2 + delta_k ** 2),
                dim=0,
            )

            denominator = torch.sqrt(quat_variance + self.eps)

            # (x - mu) / sqrt(var + eps): the centered `delta` is divided
            # here (not the raw input) so that the code matches the formula.
            out = delta / torch.cat(
                [denominator, denominator, denominator, denominator],
                dim=self.dim,
            )

            # Update the running stats
            if self.track_running_stats:
                self.running_mean = (
                    1 - exponential_average_factor
                ) * self.running_mean + exponential_average_factor * mu.view(
                    self.running_mean.size()
                )
                self.running_var = (
                    1 - exponential_average_factor
                ) * self.running_var + exponential_average_factor * quat_variance.view(
                    self.running_var.size()
                )
        else:
            # Inference: normalize with the tracked statistics, using
            # sqrt(running_var + eps) so that evaluation applies the same
            # scaling as training.
            q_std = torch.sqrt(self.running_var + self.eps)
            denominator = torch.cat(
                [q_std, q_std, q_std, q_std], dim=self.dim
            )
            out = (input - self.running_mean) / denominator

        # gamma * ((x - mu) / sqrt(var + eps)) + beta
        q_gamma = torch.cat(
            [self.gamma, self.gamma, self.gamma, self.gamma], dim=self.dim
        )
        out = (q_gamma * out) + self.beta

        return out
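
# A minimal usage sketch (an illustration, not part of the library module):
# it exercises both the training path, which normalizes with batch statistics
# and updates the running buffers, and the eval path, which reuses the
# tracked statistics. The shapes and the fixed seed are arbitrary choices.
if __name__ == "__main__":
    torch.manual_seed(0)

    # 40 real units = 10 quaternion channels of 4 components each.
    qbn = QBatchNorm(input_size=40)

    # Training mode: batch statistics are used and running stats updated.
    qbn.train()
    out = qbn(torch.rand([10, 40]))
    print(out.shape)                       # torch.Size([10, 40])
    print(qbn.num_batches_tracked.item())  # 1

    # Eval mode: the tracked running statistics are used instead.
    qbn.eval()
    out = qbn(torch.rand([10, 40]))
    print(out.shape)                       # torch.Size([10, 40])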