"""Library implementing complex-valued normalization.
Authors
* Titouan Parcollet 2020
"""
import torch
from torch.nn import Parameter
import numpy as np
from speechbrain.nnet.complex_networks.c_ops import multi_mean
[docs]class CBatchNorm(torch.nn.Module):
"""This class is implements the complex-valued batch-normalization
as introduced by "Deep Complex Networks", Trabelsi C. et al.
Arguments
---------
input_shape : tuple
Expected shape of the input.
input_size : int
Expected size of the input.
dim : int, optional
It defines the axis that should be normalized. It usually correspond to
the channel dimension (default -1).
eps : float, optional
Term used to stabilize operation (default 1e-4).
momentum : float, optional
It defines the momentum as for the real-valued batch-normalization
(default 0.1).
scale : bool, optional,
It defines if scaling should be used or not. It is
equivalent to the real-valued batchnormalization scaling (default True).
center : bool, optional
It defines if centering should be used or not. It is
equivalent to the real-valued batchnormalization centering
(default True).
track_running_stats : bool, optional
Equivalent to the real-valued batchnormalization parameter.
When True, stats are tracked. When False, solely statistics computed
over the batch are used (default True).
Example
-------
>>> inp_tensor = torch.rand([10, 16, 30])
>>> CBN = CBatchNorm(input_shape=inp_tensor.shape)
>>> out_tensor = CBN(inp_tensor)
>>> out_tensor.shape
torch.Size([10, 16, 30])
"""
def __init__(
self,
input_shape=None,
input_size=None,
dim=-1,
eps=1e-4,
momentum=0.1,
scale=True,
center=True,
track_running_stats=True,
):
super().__init__()
self.dim = dim
self.eps = eps
self.momentum = momentum
self.scale = scale
self.center = center
self.track_running_stats = track_running_stats
if input_size is None:
self.num_complex_features = self._check_input(input_shape)
else:
self.num_complex_features = input_size // 2
if self.scale:
self.gamma_rr = Parameter(torch.empty(self.num_complex_features))
self.gamma_ii = Parameter(torch.empty(self.num_complex_features))
self.gamma_ri = Parameter(torch.empty(self.num_complex_features))
else:
self.register_parameter("gamma_rr", None)
self.register_parameter("gamma_ii", None)
self.register_parameter("gamma_ri", None)
if self.center:
self.beta = Parameter(torch.empty(self.num_complex_features * 2))
else:
self.register_parameter("beta", None)
if self.track_running_stats:
self.register_buffer(
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
)
if self.scale:
# We initializing the scaling parameter following the proposal
# of "Deep Complex Networks". Trabelsi C. et al.
self.register_buffer(
"moving_Vrr",
torch.ones(self.num_complex_features) * np.sqrt(1 / 2),
)
self.register_buffer(
"moving_Vii",
torch.ones(self.num_complex_features) * np.sqrt(1 / 2),
)
self.register_buffer(
"moving_Vri", torch.zeros(self.num_complex_features)
)
else:
self.register_parameter("moving_Vrr", None)
self.register_parameter("moving_Vii", None)
self.register_parameter("moving_Vri", None)
if self.center:
self.register_buffer(
"moving_mean", torch.zeros(self.num_complex_features * 2)
)
else:
self.register_parameter("moving_mean", None)
else:
self.register_parameter("moving_Vrr", None)
self.register_parameter("moving_Vii", None)
self.register_parameter("moving_Vri", None)
self.register_parameter("moving_mean", None)
self.register_parameter("num_batches_tracked", None)
self.reset_parameters()
[docs] def reset_running_stats(self):
# Simply reset the running statistics to the initial values
# "Deep Complex Networks" Trabelsi C. et al.
if self.track_running_stats:
if self.center:
self.moving_mean.zero_()
if self.scale:
self.moving_Vrr.fill_(1 / np.sqrt(2))
self.moving_Vii.fill_(1 / np.sqrt(2))
self.moving_Vri.zero_()
self.num_batches_tracked.zero_()
[docs] def reset_parameters(self):
# Simply reset all the parameters.
# "Deep Complex Networks" Trabelsi C. et al.
self.reset_running_stats()
if self.scale:
self.gamma_rr.data.fill_(1 / np.sqrt(2))
self.gamma_ii.data.fill_(1 / np.sqrt(2))
self.gamma_ri.data.zero_()
if self.center:
self.beta.data.zero_()
[docs] def forward(self, input):
"""Returns the normalized input tensor.
Arguments
---------
input : torch.Tensor (batch, time, [channels])
Input to normalize. It can be 2d, 3d, 4d.
"""
exponential_average_factor = 0.0
# Initialize moving parameters
if self.training and self.track_running_stats:
if self.center:
self.moving_mean = self.moving_mean.detach()
if self.scale:
self.moving_Vrr = self.moving_Vrr.detach()
self.moving_Vii = self.moving_Vii.detach()
self.moving_Vri = self.moving_Vri.detach()
self.num_batches_tracked = self.num_batches_tracked.detach()
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
input_shape = input.size()
ndim = input.dim()
reduction_axes = list(range(ndim))
del reduction_axes[self.dim]
input_dim = input_shape[self.dim] // 2
# Get the mean and center the input
mu = multi_mean(input, reduction_axes, True)
input_centred = input - mu
if self.scale:
centred_squared = input_centred ** 2
# Retrieve the real and image parts of the input tensor w.r.t the
# dimension
if self.scale:
(
centred_squared_real,
centred_squared_imag,
) = self._retrieve_real_imag(centred_squared, ndim, input_dim)
if self.center:
centred_real, centred_imag = self._retrieve_real_imag(
input_centred, ndim, input_dim
)
# We compute the mean for each component
if self.scale:
Vrr = (
multi_mean(
centred_squared_real, axes=reduction_axes, keepdim=True
)
+ self.eps
)
Vii = (
multi_mean(
centred_squared_imag, axes=reduction_axes, keepdim=True
)
+ self.eps
)
# Vri contains the real and imaginary covariance
# for each feature map.
Vri = multi_mean(
centred_real * centred_imag, axes=reduction_axes, keepdim=True
)
else:
Vrr = None
Vii = None
Vri = None
# Pick the normalized form corresponding
# to the training phase when we use running stats.
if self.training and self.track_running_stats:
if self.center:
self.moving_mean = (
1 - exponential_average_factor
) * self.moving_mean + exponential_average_factor * mu.view(
self.moving_mean.size()
)
if self.scale:
self.moving_Vrr = (
1 - exponential_average_factor
) * self.moving_Vrr + exponential_average_factor * Vrr.view(
self.moving_Vrr.size()
)
self.moving_Vii = (
1 - exponential_average_factor
) * self.moving_Vii + exponential_average_factor * Vii.view(
self.moving_Vii.size()
)
self.moving_Vri = (
1 - exponential_average_factor
) * self.moving_Vri + exponential_average_factor * Vri.view(
self.moving_Vri.size()
)
if self.training or (not self.track_running_stats):
input_inferred = input_centred if self.center else input
return c_norm(
input_inferred,
Vrr,
Vii,
Vri,
self.beta,
self.gamma_rr,
self.gamma_ri,
self.gamma_ii,
self.scale,
self.center,
layernorm=False,
dim=self.dim,
)
else: # if we are not training or using running_stats
if self.center:
input_inferred = input - self.moving_mean.view(mu.size())
else:
input_inferred = input
return c_norm(
input_inferred,
self.moving_Vrr,
self.moving_Vii,
self.moving_Vri,
self.beta,
self.gamma_rr,
self.gamma_ri,
self.gamma_ii,
self.scale,
self.center,
layernorm=False,
dim=self.dim,
)
def _retrieve_real_imag(self, tensor, ndim, input_dim):
"""
Function used to retrieve the real and imaginary component of a tensor
according to the dimensions
"""
if self.dim == 1 or ndim == 2:
tensor_real = tensor[:, :input_dim]
tensor_imag = tensor[:, input_dim:]
elif self.dim == -1 and ndim == 3:
tensor_real = tensor[:, :, :input_dim]
tensor_imag = tensor[:, :, input_dim:]
elif self.dim == -1 and ndim == 4:
tensor_real = tensor[:, :, :, :input_dim]
tensor_imag = tensor[:, :, :, input_dim:]
else:
msg = "Retrieve_real_imag expects 2d to 4d inputs. Got " + str(
len(tensor)
)
raise ValueError(msg)
return tensor_real, tensor_imag
def _check_input(self, input_shape):
"""
Checks the input and returns the number of complex values.
"""
if input_shape[self.dim] % 2 == 0:
return input_shape[self.dim] // 2
else:
msg = "ComplexBatchNorm dim must be divisible by 2 ! Got " + str(
input_shape[self.dim]
)
raise ValueError(msg)
[docs]class CLayerNorm(torch.nn.Module):
"""This class is used to instantiate the complex
layer-normalization as introduced by "Deep Complex Networks",
Trabelsi C. et al.
Arguments
---------
input_shape : tuple
Expected shape of the input.
input_size : int
Expected size of the input dimension.
dim : int, optional
It defines the axis that should be normalized. It usually correspond to
the channel dimension (default -1).
eps : float, optional
Term used to stabilize operation (default 1e-4).
scale : bool, optional,
It defines if scaling should be used or not. It is
equivalent to the real-valued batchnormalization scaling (default True).
center : bool, optional
It defines if centering should be used or not. It is
equivalent to the real-valued batchnormalization centering
(default True).
Example
-------
>>> inp_tensor = torch.rand([10, 16, 30])
>>> CBN = CLayerNorm(input_shape=inp_tensor.shape)
>>> out_tensor = CBN(inp_tensor)
>>> out_tensor.shape
torch.Size([10, 16, 30])
"""
def __init__(
self,
input_shape=None,
input_size=None,
dim=-1,
eps=1e-4,
scale=True,
center=True,
):
super().__init__()
self.dim = dim
self.eps = eps
self.scale = scale
self.center = center
if input_size is None:
self.num_complex_features = self._check_input(input_shape)
else:
self.num_complex_features = input_size // 2
if self.scale:
self.gamma_rr = Parameter(torch.empty(self.num_complex_features))
self.gamma_ii = Parameter(torch.empty(self.num_complex_features))
self.gamma_ri = Parameter(torch.empty(self.num_complex_features))
else:
self.register_parameter("gamma_rr", None)
self.register_parameter("gamma_ii", None)
self.register_parameter("gamma_ri", None)
if self.center:
self.beta = Parameter(torch.empty(self.num_complex_features * 2))
else:
self.register_parameter("beta", None)
self.reset_parameters()
[docs] def reset_parameters(self):
# Simply reset all the parameters.
# "Deep Complex Networks" Trabelsi C. et al.
if self.scale:
self.gamma_rr.data.fill_(1 / np.sqrt(2))
self.gamma_ii.data.fill_(1 / np.sqrt(2))
self.gamma_ri.data.zero_()
if self.center:
self.beta.data.zero_()
[docs] def forward(self, input):
input_shape = input.size()
ndim = input.dim()
reduction_axes = list(range(ndim))
del reduction_axes[self.dim]
del reduction_axes[0]
input_dim = input_shape[self.dim] // 2
# Get the mean and center
mu = multi_mean(input, reduction_axes, True)
if self.center:
input_centred = input - mu
else:
input_centred = input
centred_squared = input_centred ** 2
if self.dim == 1 or ndim == 2:
centred_squared_real = centred_squared[:, :input_dim]
centred_squared_imag = centred_squared[:, input_dim:]
centred_real = input_centred[:, :input_dim]
centred_imag = input_centred[:, input_dim:]
elif self.dim == -1 and ndim == 3:
centred_squared_real = centred_squared[:, :, :input_dim]
centred_squared_imag = centred_squared[:, :, input_dim:]
centred_real = input_centred[:, :, :input_dim]
centred_imag = input_centred[:, :, input_dim:]
elif self.dim == -1 and ndim == 4:
centred_squared_real = centred_squared[:, :, :, :input_dim]
centred_squared_imag = centred_squared[:, :, :, input_dim:]
centred_real = input_centred[:, :, :, :input_dim]
centred_imag = input_centred[:, :, :, input_dim:]
else:
centred_squared_real = centred_squared[:, :, :, :, :input_dim]
centred_squared_imag = centred_squared[:, :, :, :, input_dim:]
centred_real = input_centred[:, :, :, :, :input_dim]
centred_imag = input_centred[:, :, :, :, input_dim:]
if self.scale:
Vrr = (
multi_mean(
centred_squared_real, axes=reduction_axes, keepdim=True
)
+ self.eps
)
Vii = (
multi_mean(
centred_squared_imag, axes=reduction_axes, keepdim=True
)
+ self.eps
)
Vri = multi_mean(
centred_real * centred_imag, axes=reduction_axes, keepdim=True
)
else:
Vrr = None
Vii = None
Vri = None
return c_norm(
input_centred,
Vrr,
Vii,
Vri,
self.beta,
self.gamma_rr,
self.gamma_ri,
self.gamma_ii,
self.scale,
self.center,
dim=self.dim,
layernorm=True,
)
def _check_input(self, input_shape):
"""Checks the input and returns the number of complex values.
"""
if input_shape[self.dim] % 2 == 0:
return input_shape[self.dim] // 2
else:
msg = "ComplexBatchNorm dim must be dividble by 2 ! Got " + str(
input_shape[self.dim]
)
raise ValueError(msg)
[docs]def c_norm(
input_centred,
Vrr,
Vii,
Vri,
beta,
gamma_rr,
gamma_ri,
gamma_ii,
scale=True,
center=True,
layernorm=False,
dim=-1,
):
"""This function is used to apply the complex normalization
as introduced by "Deep Complex Networks", Trabelsi C. et al.
Arguments
---------
input_centred : torch.Tensor
It is the tensor to be normalized. The features
dimension is divided by 2 with the first half
corresponding to the real-parts and the second half
to the imaginary parts.
Vrr : torch.Tensor
It is a tensor that contains the covariance between real-parts.
Vii : torch.Tensor
It is a tensor that contains the covariance between imaginary-parts.
Vri : torch.Tensor
It is a tensor that contains the covariance between real-parts and
imaginary-parts.
beta : torch.Tensor
It is a tensor corresponding to the beta parameter on the real-valued
batch-normalization, but in the complex-valued space.
gamma_rr : torch.Tensor
It is a tensor that contains the gamma between real-parts.
gamma_ii : torch.Tensor
It is a tensor that contains the gamma between imaginary-parts.
gamma_ri : torch.Tensor
It is a tensor that contains the gamma between real-parts and
imaginary-parts.
scale : bool, optional
It defines if scaling should be used or not. It is
equivalent to the real-valued batchnormalization
scaling (default True).
center : bool, optional,
It defines if centering should be used or not. It is
equivalent to the real-valued batchnormalization centering
(default True).
layernorm : bool, optional
It defines is c_standardization is called from a layernorm or a
batchnorm layer (default False).
dim : int, optional
It defines the axis that should be considered as the complex-valued
axis (divided by 2 to get r and i) (default -1).
"""
ndim = input_centred.dim()
input_dim = input_centred.size(dim) // 2
if scale:
gamma_broadcast_shape = [1] * ndim
gamma_broadcast_shape[dim] = input_dim
if center:
broadcast_beta_shape = [1] * ndim
broadcast_beta_shape[dim] = input_dim * 2
if scale:
standardized_output = c_standardization(
input_centred, Vrr, Vii, Vri, layernorm, dim=dim
)
# Now we perform the scaling and Shifting of the normalized x using
# the scaling parameter
# [ gamma_rr gamma_ri ]
# Gamma = [ gamma_ri gamma_ii ]
# and the shifting parameter
# Beta = [beta_real beta_imag].T
# where:
# x_real_BN = gamma_rr * x_real_normed +
# gamma_ri * x_imag_normed + beta_real
# x_imag_BN = gamma_ri * x_real_normed +
# gamma_ii * x_imag_normed + beta_imag
broadcast_gamma_rr = gamma_rr.view(gamma_broadcast_shape)
broadcast_gamma_ri = gamma_ri.view(gamma_broadcast_shape)
broadcast_gamma_ii = gamma_ii.view(gamma_broadcast_shape)
cat_gamma_4_real = torch.cat(
[broadcast_gamma_rr, broadcast_gamma_ii], dim=dim
)
cat_gamma_4_imag = torch.cat(
[broadcast_gamma_ri, broadcast_gamma_ri], dim=dim
)
if dim == 0:
centred_real = standardized_output[:input_dim]
centred_imag = standardized_output[input_dim:]
elif dim == 1 or (dim == -1 and ndim == 2):
centred_real = standardized_output[:, :input_dim]
centred_imag = standardized_output[:, input_dim:]
elif dim == -1 and ndim == 3:
centred_real = standardized_output[:, :, :input_dim]
centred_imag = standardized_output[:, :, input_dim:]
elif dim == -1 and ndim == 4:
centred_real = standardized_output[:, :, :, :input_dim]
centred_imag = standardized_output[:, :, :, input_dim:]
else:
centred_real = standardized_output[:, :, :, :, :input_dim]
centred_imag = standardized_output[:, :, :, :, input_dim:]
rolled_standardized_output = torch.cat(
[centred_imag, centred_real], dim=dim
)
if center:
broadcast_beta = beta.view(broadcast_beta_shape)
a = cat_gamma_4_real * standardized_output
b = cat_gamma_4_imag * rolled_standardized_output
return a + b + broadcast_beta
else:
return (
cat_gamma_4_real * standardized_output
+ cat_gamma_4_imag * rolled_standardized_output
)
else:
if center:
broadcast_beta = beta.view(broadcast_beta_shape)
return input_centred + broadcast_beta
else:
return input_centred
[docs]def c_standardization(input_centred, Vrr, Vii, Vri, layernorm=False, dim=-1):
"""This function is used to standardize a centred tensor of
complex numbers (mean of the set must be 0).
Arguments
---------
input_centred : torch.Tensor
It is the tensor to be normalized. The features
dimension is divided by 2 with the first half
corresponding to the real-parts and the second half
to the imaginary parts.
Vrr : torch.Tensor
It is a tensor that contains the covariance between real-parts.
Vii : torch.Tensor
It is a tensor that contains the covariance between imaginary-parts.
Vri : torch.Tensor
It is a tensor that contains the covariance between real-parts and
imaginary-parts.
layernorm : bool, optional
It defines is c_standardization is called from a layernorm or a
batchnorm layer (default False).
dim : int, optional
It defines the axis that should be considered as the complex-valued
axis (divided by 2 to get r and i) (default -1).
"""
ndim = input_centred.dim()
input_dim = input_centred.size(dim) // 2
variances_broadcast = [1] * ndim
variances_broadcast[dim] = input_dim
if layernorm:
variances_broadcast[0] = input_centred.size(0)
# We require the covariance matrix's inverse square root. That requires
# square rooting, followed by inversion (During the computation of square
# root we compute the determinant we'll need for inversion as well).
# tau = Vrr + Vii = Trace. Guaranteed >=0 because Positive-definite matrix
tau = Vrr + Vii
# delta = (Vrr * Vii) - (Vri ** 2) = Determinant
delta = (Vrr * Vii) - (Vri ** 2)
s = delta.sqrt()
t = (tau + 2 * s).sqrt()
# The square root matrix could now be explicitly formed as
# [ Vrr+s Vri ]
# (1/t) [ Vir Vii+s ]
# https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
# but we don't need to do this immediately since we can also simultaneously
# invert. We can do this because we've already computed the determinant of
# the square root matrix, and can thus invert it using the analytical
# solution for 2x2 matrices
# [ A B ] [ D -B ]
# inv( [ C D ] ) = (1/det) [ -C A ]
# http://mathworld.wolfram.com/MatrixInverse.html
# Thus giving us
# [ Vii+s -Vri ]
# (1/s)(1/t)[ -Vir Vrr+s ]
# So we proceed as follows:
inverse_st = 1.0 / (s * t)
Wrr = (Vii + s) * inverse_st
Wii = (Vrr + s) * inverse_st
Wri = -Vri * inverse_st
# And we have computed the inverse square root matrix W = sqrt(V)!
# Normalization. We multiply, x_normalized = W.x.
# The returned result will be a complex standardized input
# where the real and imaginary parts are obtained as follows:
# x_real_normed = Wrr * x_real_centred + Wri * x_imag_centred
# x_imag_normed = Wri * x_real_centred + Wii * x_imag_centred
broadcast_Wrr = Wrr.view(variances_broadcast)
broadcast_Wri = Wri.view(variances_broadcast)
broadcast_Wii = Wii.view(variances_broadcast)
cat_W_4_real = torch.cat([broadcast_Wrr, broadcast_Wii], dim=dim)
cat_W_4_imag = torch.cat([broadcast_Wri, broadcast_Wri], dim=dim)
if dim == 0:
centred_real = input_centred[:input_dim]
centred_imag = input_centred[input_dim:]
elif dim == 1 or (dim == -1 and ndim == 2):
centred_real = input_centred[:, :input_dim]
centred_imag = input_centred[:, input_dim:]
elif dim == -1 and ndim == 3:
centred_real = input_centred[:, :, :input_dim]
centred_imag = input_centred[:, :, input_dim:]
elif dim == -1 and ndim == 4:
centred_real = input_centred[:, :, :, :input_dim]
centred_imag = input_centred[:, :, :, input_dim:]
else:
centred_real = input_centred[:, :, :, :, :input_dim]
centred_imag = input_centred[:, :, :, :, input_dim:]
rolled_input = torch.cat([centred_imag, centred_real], dim=dim)
output = cat_W_4_real * input_centred + cat_W_4_imag * rolled_input
# Wrr * x_real_centered | Wii * x_imag_centered
# + Wri * x_imag_centered | Wri * x_real_centered
# -----------------------------------------------
# = output
return output