"""Library implementing normalization.
Authors
* Mirco Ravanelli 2020
"""
import torch
import torch.nn as nn
class BatchNorm1d(nn.Module):
    """Applies 1d batch normalization to the input tensor.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    affine : bool
        When set to True, the affine parameters are learned.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    combine_batch_time : bool
        When true, it combines batch an time axis.
    skip_transpose : bool
        When True, the input is assumed to be channel-first already
        (batch, channels, time) and no transposition is applied.

    Example
    -------
    >>> input = torch.randn(100, 10)
    >>> norm = BatchNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=False,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        # Fail fast with a clear error (consistent with BatchNorm2d and the
        # InstanceNorm classes) instead of an opaque TypeError from indexing
        # None below when neither argument is provided.
        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")

        if input_size is None and skip_transpose:
            # Channel-first layout: channels live on axis 1.
            input_size = input_shape[1]
        elif input_size is None:
            # Channel-last layout: channels live on the last axis.
            input_size = input_shape[-1]

        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            input to normalize. 2d or 3d tensors are expected in input
            4d tensors can be used when combine_dims=True.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            # Fold batch and time into a single leading axis so that
            # nn.BatchNorm1d sees one large batch.
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(
                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
                )
        elif not self.skip_transpose:
            # (batch, time, channels) -> (batch, channels, time)
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        # Restore the original layout.
        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n
class BatchNorm2d(nn.Module):
    """2d batch normalization wrapper operating on channel-last tensors.

    Internally transposes (batch, time, channel1, channel2) inputs to the
    channel-first layout expected by ``torch.nn.BatchNorm2d`` and transposes
    the result back.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    affine : bool
        When set to True, the affine parameters are learned.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.

    Example
    -------
    >>> input = torch.randn(100, 10, 5, 20)
    >>> norm = BatchNorm2d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10, 5, 20])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()

        # At least one way of specifying the channel count is required.
        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")

        # The number of channels is the last axis of the expected shape.
        num_features = input_size if input_size is not None else input_shape[-1]

        self.norm = nn.BatchNorm2d(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel1, channel2)
            input to normalize. 4d tensors are expected.
        """
        # Swap axes 1 and -1 so channels come first, normalize, swap back.
        channel_first = x.transpose(-1, 1)
        normalized = self.norm(channel_first)
        return normalized.transpose(1, -1)
class LayerNorm(nn.Module):
    """Thin wrapper around ``torch.nn.LayerNorm``.

    When ``input_shape`` is given, normalization is applied over every axis
    after the first two (batch and time).

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    elementwise_affine : bool
        If True, this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> input = torch.randn(100, 101, 128)
    >>> norm = LayerNorm(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 101, 128])
    """

    def __init__(
        self,
        input_size=None,
        input_shape=None,
        eps=1e-05,
        elementwise_affine=True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # input_shape takes precedence: normalize over all non-batch,
        # non-time axes.
        normalized_shape = (
            input_shape[2:] if input_shape is not None else input_size
        )

        self.norm = torch.nn.LayerNorm(
            normalized_shape,
            eps=self.eps,
            elementwise_affine=self.elementwise_affine,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            input to normalize. 3d or 4d tensors are expected.
        """
        return self.norm(x)
class InstanceNorm1d(nn.Module):
    """1d instance normalization for channel-last tensors.

    Wraps ``torch.nn.InstanceNorm1d``, transposing (batch, time, channels)
    inputs to channel-first layout and back.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    affine : bool
        A boolean value that when set to True, this module has learnable
        affine parameters, initialized the same way as done for
        batch normalization. Default: False.

    Example
    -------
    >>> input = torch.randn(100, 10, 20)
    >>> norm = InstanceNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10, 20])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        track_running_stats=True,
        affine=False,
    ):
        super().__init__()

        # One of the two shape arguments must be supplied.
        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")

        # Channels are the trailing axis of the expected input shape.
        num_features = input_size if input_size is not None else input_shape[-1]

        self.norm = nn.InstanceNorm1d(
            num_features,
            eps=eps,
            momentum=momentum,
            track_running_stats=track_running_stats,
            affine=affine,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            input to normalize. 3d tensors are expected.
        """
        # Channel-first for the torch module, then back to channel-last.
        channel_first = x.transpose(-1, 1)
        normalized = self.norm(channel_first)
        return normalized.transpose(1, -1)
class InstanceNorm2d(nn.Module):
    """2d instance normalization for channel-last tensors.

    Wraps ``torch.nn.InstanceNorm2d``, transposing
    (batch, time, channel1, channel2) inputs to channel-first layout and back.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    affine : bool
        A boolean value that when set to True, this module has learnable
        affine parameters, initialized the same way as done for
        batch normalization. Default: False.

    Example
    -------
    >>> input = torch.randn(100, 10, 20, 2)
    >>> norm = InstanceNorm2d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10, 20, 2])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        track_running_stats=True,
        affine=False,
    ):
        super().__init__()

        # One of the two shape arguments must be supplied.
        if input_shape is None and input_size is None:
            raise ValueError("Expected input_shape or input_size as input")

        # Channels are the trailing axis of the expected input shape.
        num_features = input_size if input_size is not None else input_shape[-1]

        self.norm = nn.InstanceNorm2d(
            num_features,
            eps=eps,
            momentum=momentum,
            track_running_stats=track_running_stats,
            affine=affine,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel1, channel2)
            input to normalize. 4d tensors are expected.
        """
        # Channel-first for the torch module, then back to channel-last.
        channel_first = x.transpose(-1, 1)
        normalized = self.norm(channel_first)
        return normalized.transpose(1, -1)