"""Library implementing complex-valued convolutional neural networks.
Authors
* Titouan Parcollet 2020
"""
import torch
import torch.nn as nn
import logging
import torch.nn.functional as F
from speechbrain.nnet.CNN import get_padding_elem
from speechbrain.nnet.complex_networks.c_ops import (
unitary_init,
complex_init,
affect_conv_init,
complex_conv_op,
)
logger = logging.getLogger(__name__)
class CConv1d(torch.nn.Module):
    """This class implements complex-valued 1d convolution.

    Arguments
    ---------
    out_channels : int
        Number of output channels. Please note
        that these are complex-valued neurons. If 256
        channels are specified, the output dimension
        will be 512.
    kernel_size : int
        Kernel size of the convolutional filters. Must be odd.
    input_shape : tuple
        Expected shape of the input: (batch, time, channels).
    stride : int, optional
        Stride factor of the convolutional filters (default 1).
    dilation : int, optional
        Dilation factor of the convolutional filters (default 1).
    padding : str, optional
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is same as input shape.
        "causal" results in causal (dilated) convolutions. (default "same")
    groups : int, optional
        This option specifies the convolutional groups. See torch.nn
        documentation for more information (default 1).
    bias : bool, optional
        If True, the additive bias b is adopted (default True).
    padding_mode : str, optional
        This flag specifies the type of padding. See torch.nn documentation
        for more information (default "reflect").
    init_criterion : str, optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weight_init to build the initialization method of
        the complex-valued weights. (default "glorot")
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights. "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle. (default "complex")
        More details in: "Deep Complex Networks", Trabelsi C. et al.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 30])
    >>> cnn_1d = CConv1d(
    ...     input_shape=inp_tensor.shape, out_channels=12, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16, 24])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        init_criterion="glorot",
        weight_init="complex",
    ):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.groups = groups
        self.bias = bias
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.init_criterion = init_criterion
        self.weight_init = weight_init

        # The channel axis holds real and imaginary parts concatenated,
        # so the number of complex input channels is half the raw count.
        self.in_channels = self._check_input(input_shape) // 2

        # Managing the weight initialization and bias by directly setting the
        # correct function.
        (self.k_shape, self.w_shape) = self._get_kernel_and_weight_shape()

        self.real_weight = torch.nn.Parameter(torch.Tensor(*self.w_shape))
        self.imag_weight = torch.nn.Parameter(torch.Tensor(*self.w_shape))

        if self.bias:
            # One bias element per real and per imaginary output channel.
            self.b = torch.nn.Parameter(torch.Tensor(2 * self.out_channels))
            self.b.data.fill_(0)
        else:
            self.b = None

        # Raises KeyError for an unknown weight_init value.
        self.winit = {"complex": complex_init, "unitary": unitary_init}[
            self.weight_init
        ]

        affect_conv_init(
            self.real_weight,
            self.imag_weight,
            self.kernel_size,
            self.winit,
            self.init_criterion,
        )

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor
            (batch, time, channel). Input to convolve.

        Returns
        -------
        torch.Tensor
            Convolved output, shape (batch, time', 2 * out_channels).
        """
        # (batch, channel, time)
        x = x.transpose(1, -1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # Left-pad only, so outputs never depend on future frames.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got %s."
                % (self.padding)
            )

        wx = complex_conv_op(
            x,
            self.real_weight,
            self.imag_weight,
            self.b,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            conv1d=True,
        )

        # Back to (batch, time, channel).
        wx = wx.transpose(1, -1)
        return wx

    def _manage_padding(self, x, kernel_size, dilation, stride):
        """Pads the time axis so its length is unchanged by the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor, (batch, channel, time).
        kernel_size : int
            Kernel size.
        dilation : int
            Dilation.
        stride : int
            Stride.
        """
        # Detecting input shape.
        L_in = x.shape[-1]

        # Time padding.
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding.
        x = F.pad(x, tuple(padding), mode=self.padding_mode)
        return x

    def _check_input(self, input_shape):
        """Checks the input shape and returns the number of input channels.

        Raises ValueError if the input is not 3d, the kernel size is even,
        or the channel dimension is odd (complex tensors concatenate real
        and imaginary parts, so it must be even).
        """
        if len(input_shape) == 3:
            in_channels = input_shape[2]
        else:
            # str() is required here: concatenating the tuple directly
            # would raise a TypeError instead of this ValueError.
            raise ValueError(
                "ComplexConv1d expects 3d inputs. Got " + str(input_shape)
            )

        # Kernel size must be odd.
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

        # Check complex format.
        if in_channels % 2 != 0:
            raise ValueError(
                "Complex Tensors must have a channel dimension divisible "
                "by 2. Got in_channels = %s." % (in_channels)
            )

        return in_channels

    def _get_kernel_and_weight_shape(self):
        """Returns the kernel size and weight shape for convolutional layers."""
        ks = self.kernel_size
        w_shape = (self.out_channels, self.in_channels) + tuple((ks,))
        return ks, w_shape
class CConv2d(nn.Module):
    """This class implements complex-valued 2d convolution.

    Arguments
    ---------
    out_channels : int
        Number of output channels. Please note
        that these are complex-valued neurons. If 256
        channels are specified, the output dimension
        will be 512.
    kernel_size : int or list of int
        Kernel size of the convolutional filters. Each entry must be odd.
        A single int is expanded to both spatial axes.
    input_shape : tuple
        Expected shape of the input: (batch, time, feature, channels)
        or (batch, time, feature).
    stride : int, optional
        Stride factor of the convolutional filters (default 1).
    dilation : int, optional
        Dilation factor of the convolutional filters (default 1).
    padding : str, optional
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is same as input shape.
        "causal" results in causal (dilated) convolutions. (default "same")
    padding_mode : str, optional
        This flag specifies the type of padding (default "reflect").
        See torch.nn documentation for more information.
    groups : int, optional
        This option specifies the convolutional groups (default 1). See torch.nn
        documentation for more information.
    bias : bool, optional
        If True, the additive bias b is adopted (default True).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights (default "glorot").
        It is combined with weights_init to build the initialization method of
        the complex-valued weights.
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default complex). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 30, 30])
    >>> cnn_2d = CConv2d(
    ...     input_shape=inp_tensor.shape, out_channels=12, kernel_size=5
    ... )
    >>> out_tensor = cnn_2d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16, 30, 24])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        init_criterion="glorot",
        weight_init="complex",
    ):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.groups = groups
        self.bias = bias
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.init_criterion = init_criterion
        self.weight_init = weight_init

        # Expand scalar hyperparameters to per-axis lists: k -> [k, k].
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size, self.kernel_size]
        if isinstance(self.dilation, int):
            self.dilation = [self.dilation, self.dilation]
        if isinstance(self.stride, int):
            self.stride = [self.stride, self.stride]

        # The channel axis holds real and imaginary parts concatenated,
        # so the number of complex input channels is half the raw count.
        self.in_channels = self._check_input(input_shape) // 2

        # Managing the weight initialization and bias by directly setting the
        # correct function.
        (self.k_shape, self.w_shape) = self._get_kernel_and_weight_shape()

        self.real_weight = torch.nn.Parameter(torch.Tensor(*self.w_shape))
        self.imag_weight = torch.nn.Parameter(torch.Tensor(*self.w_shape))

        if self.bias:
            # One bias element per real and per imaginary output channel.
            self.b = torch.nn.Parameter(torch.Tensor(2 * self.out_channels))
            self.b.data.fill_(0)
        else:
            self.b = None

        # Raises KeyError for an unknown weight_init value.
        self.winit = {"complex": complex_init, "unitary": unitary_init}[
            self.weight_init
        ]

        affect_conv_init(
            self.real_weight,
            self.imag_weight,
            self.kernel_size,
            self.winit,
            self.init_criterion,
        )

    def forward(self, x, init_params=False):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor
            (batch, time, feature, channels). Input to convolve.
        init_params : bool, optional
            Unused; kept only for backward compatibility with callers of
            the old init-on-first-forward API (default False).

        Returns
        -------
        torch.Tensor
            Convolved output, (batch, time', feature', 2 * out_channels).
        """
        # NOTE(review): self.unsqueeze is set by _check_input for 3d inputs
        # but never consumed here, so 3d inputs are not actually supported
        # end-to-end in this visible code — confirm against callers.

        # (batch, channel, feature, time)
        x = x.transpose(1, -1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # Left-pad the time axis (last dim) only. kernel_size and
            # dilation are per-axis lists here, so index the time entry —
            # the scalar arithmetic used previously raised a TypeError.
            num_pad = (self.kernel_size[-1] - 1) * self.dilation[-1]
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got %s."
                % (self.padding)
            )

        wx = complex_conv_op(
            x,
            self.real_weight,
            self.imag_weight,
            self.b,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            conv1d=False,
        )

        # Back to (batch, time, feature, channel).
        wx = wx.transpose(1, -1)
        return wx

    def _get_kernel_and_weight_shape(self):
        """Returns the kernel size and weight shape for convolutional layers."""
        ks = (self.kernel_size[0], self.kernel_size[1])
        w_shape = (self.out_channels, self.in_channels) + (*ks,)
        return ks, w_shape

    def _manage_padding(self, x, kernel_size, dilation, stride):
        """Pads the time and frequency axes so their lengths are unchanged
        by the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor, (batch, channel, feature, time).
        kernel_size : list of int
            Per-axis kernel sizes.
        dilation : list of int
            Per-axis dilations.
        stride : list of int
            Per-axis strides.
        """
        # Time padding (last axis).
        padding_time = get_padding_elem(
            x.shape[-1], stride[-1], kernel_size[-1], dilation[-1]
        )
        # Frequency padding (second-to-last axis). Uses that axis' own
        # length — the previous code reused the time length, which is only
        # correct for square inputs.
        padding_freq = get_padding_elem(
            x.shape[-2], stride[-2], kernel_size[-2], dilation[-2]
        )
        padding = padding_time + padding_freq

        # F.pad consumes pairs starting from the last dimension.
        x = F.pad(x, tuple(padding), mode=self.padding_mode)
        return x

    def _check_input(self, input_shape):
        """Checks the input shape and returns the number of input channels.

        Raises ValueError if the input is not 3d/4d, any kernel size is
        even, or the channel dimension is odd (complex tensors concatenate
        real and imaginary parts, so it must be even).
        """
        if len(input_shape) == 3:
            self.unsqueeze = True
            in_channels = 1
        elif len(input_shape) == 4:
            in_channels = input_shape[3]
        else:
            # str() is required here: concatenating the tuple directly
            # would raise a TypeError instead of this ValueError.
            raise ValueError(
                "Expected 3d or 4d inputs. Got " + str(input_shape)
            )

        # Kernel size must be odd on both axes.
        if self.kernel_size[0] % 2 == 0 or self.kernel_size[1] % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

        # Check complex format.
        if in_channels % 2 != 0:
            raise ValueError(
                "Complex Tensors must have a channel dimension divisible "
                "by 2. Got in_channels = %s." % (in_channels)
            )

        return in_channels