"""This library implements different operations needed by quaternion-
valued architectures.
This work is inspired by:
"Quaternion neural networks" - Parcollet T.
"Quaternion recurrent neural networks" - Parcollet T. et al.
"Quaternion convolutional neural networks for end-to-end automatic speech
recognition" - Parcollet T. et al.
"Deep quaternion networks" - Gaudet Chase J. et al.
Authors
* Titouan Parcollet 2020
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import chi
from torch.autograd import Variable
[docs]
class QuaternionLinearCustomBackward(torch.autograd.Function):
    """This class redefines the backpropagation of a quaternion linear layer
    (not a spinor layer). By doing so, we can save up to 4x memory, but it
    is also 2x slower than 'quaternion_linear_op'. It should be used
    within speechbrain.nnet.quaternion_networks.linear.QuaternionLinear.
    """

    @staticmethod
    def forward(ctx, input, r_weight, i_weight, j_weight, k_weight, bias):
        """
        Applies a quaternion linear transformation to the incoming data:
        It is important to notice that the forward phase of a QNN is defined
        as W * Inputs (with * equal to the Hamilton product). The constructed
        cat_kernels_4_quaternion is a modified version of the quaternion
        representation so when we do torch.mm(Input,W) it's equivalent
        to W * Inputs.

        Arguments
        ---------
        ctx : PyTorch context object
            Used to save the context necessary to perform a backwards pass.
        input : torch.Tensor
            Quaternion input tensor to be transformed. Shape: [batch*time, X].
        r_weight : torch.Parameter
            Real part of the quaternion weight matrix of this layer.
        i_weight : torch.Parameter
            First imaginary part of the quaternion weight matrix of this layer.
        j_weight : torch.Parameter
            Second imaginary part of the quaternion weight matrix of this layer.
        k_weight : torch.Parameter
            Third imaginary part of the quaternion weight matrix of this layer.
        bias : torch.Parameter or None
            Added to the output only when it is not None and requires grad;
            otherwise no bias is applied.

        Returns
        -------
        The linearly transformed quaternions
        """
        # None entries are allowed by save_for_backward.
        ctx.save_for_backward(
            input, r_weight, i_weight, j_weight, k_weight, bias
        )
        cat_kernels_4_r = torch.cat(
            [r_weight, -i_weight, -j_weight, -k_weight], dim=0
        )
        cat_kernels_4_i = torch.cat(
            [i_weight, r_weight, -k_weight, j_weight], dim=0
        )
        cat_kernels_4_j = torch.cat(
            [j_weight, k_weight, r_weight, -i_weight], dim=0
        )
        cat_kernels_4_k = torch.cat(
            [k_weight, -j_weight, i_weight, r_weight], dim=0
        )
        cat_kernels_4_quaternion = torch.cat(
            [
                cat_kernels_4_r,
                cat_kernels_4_i,
                cat_kernels_4_j,
                cat_kernels_4_k,
            ],
            dim=1,
        )
        # A bias that does not require grad is the "no bias" convention.
        if bias is not None and bias.requires_grad:
            return torch.addmm(bias, input, cat_kernels_4_quaternion)
        return torch.mm(input, cat_kernels_4_quaternion)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """
        Run the backward phase of the forward call defined above. This
        implementation follows the quaternion backpropagation of a quaternion
        layer that can be found in "Quaternion neural networks" - Parcollet T.
        Page 48.

        Arguments
        ---------
        ctx : Pytorch context object
            Contains saved weights and bias
        grad_output : torch.Tensor
            Gradient of the loss w.r.t. the output of the forward pass.

        Returns
        -------
        The corresponding gradients of this op (None for every input that
        does not need a gradient).
        """
        input, r_weight, i_weight, j_weight, k_weight, _ = ctx.saved_tensors
        needs_grad = ctx.needs_input_grad
        grad_input = grad_bias = None
        grad_weight_r = grad_weight_i = grad_weight_j = grad_weight_k = None

        if needs_grad[0]:
            # dL/dinput = dL/doutput @ W^T, with W rebuilt exactly as in
            # forward. Detached, as the previous Variable(...) wrapper did.
            kernel_r = torch.cat(
                [r_weight, -i_weight, -j_weight, -k_weight], dim=0
            )
            kernel_i = torch.cat(
                [i_weight, r_weight, -k_weight, j_weight], dim=0
            )
            kernel_j = torch.cat(
                [j_weight, k_weight, r_weight, -i_weight], dim=0
            )
            kernel_k = torch.cat(
                [k_weight, -j_weight, i_weight, r_weight], dim=0
            )
            cat_kernels_4_quaternion_T = (
                torch.cat([kernel_r, kernel_i, kernel_j, kernel_k], dim=1)
                .t()
                .detach()
            )
            grad_input = grad_output.mm(cat_kernels_4_quaternion_T)

        # Any one of the four components may be trainable on its own, so
        # check all of them (not only r_weight).
        if any(needs_grad[1:5]):
            # Hamilton-product arrangement of the input quaternions.
            nb_hidden = input.size(-1)
            quarter = nb_hidden // 4
            r = input.narrow(1, 0, quarter)
            i = input.narrow(1, quarter, quarter)
            j = input.narrow(1, nb_hidden // 2, quarter)
            k = input.narrow(1, nb_hidden - quarter, quarter)
            input_mat = torch.cat(
                [
                    torch.cat([r, -i, -j, -k], dim=0),
                    torch.cat([i, r, -k, j], dim=0),
                    torch.cat([j, k, r, -i], dim=0),
                    torch.cat([k, -j, i, r], dim=0),
                ],
                dim=1,
            ).detach()

            # Matching arrangement of the output gradient.
            nb_out = grad_output.size(-1)
            out_quarter = nb_out // 4
            g_r = grad_output.narrow(1, 0, out_quarter)
            g_i = grad_output.narrow(1, out_quarter, out_quarter)
            g_j = grad_output.narrow(1, nb_out // 2, out_quarter)
            g_k = grad_output.narrow(1, nb_out - out_quarter, out_quarter)
            grad_mat = torch.cat(
                [
                    torch.cat([g_r, g_i, g_j, g_k], dim=1),
                    torch.cat([-g_i, g_r, g_k, -g_j], dim=1),
                    torch.cat([-g_j, -g_k, g_r, g_i], dim=1),
                    torch.cat([-g_k, g_j, -g_i, g_r], dim=1),
                ],
                dim=0,
            )

            grad_weight = grad_mat.t().mm(input_mat).t()
            # The first row-block holds the r, i, j, k gradients side by side.
            unit_size_x = r_weight.size(0)
            unit_size_y = r_weight.size(1)
            first_rows = grad_weight.narrow(0, 0, unit_size_x)
            component_grads = [
                first_rows.narrow(1, c * unit_size_y, unit_size_y)
                for c in range(4)
            ]
            grad_weight_r, grad_weight_i, grad_weight_j, grad_weight_k = (
                grad if needed else None
                for grad, needed in zip(component_grads, needs_grad[1:5])
            )

        if needs_grad[5]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return (
            grad_input,
            grad_weight_r,
            grad_weight_i,
            grad_weight_j,
            grad_weight_k,
            grad_bias,
        )
[docs]
def quaternion_linear_op(input, r_weight, i_weight, j_weight, k_weight, bias):
    """
    Applies a quaternion linear transformation to the incoming data:
    It is important to notice that the forward phase of a QNN is defined
    as W * Inputs (with * equal to the Hamilton product). The constructed
    cat_kernels_4_quaternion is a modified version of the quaternion
    representation so when we do torch.mm(Input,W) it's equivalent
    to W * Inputs.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed. Its last dimension holds
        the r, i, j, k blocks concatenated.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter or None
        Added to the output only when it is not None and requires grad. A
        bias with requires_grad=False (or None) means "no bias".

    Returns
    -------
    The linearly transformed quaternions
    """
    cat_kernels_4_r = torch.cat(
        [r_weight, -i_weight, -j_weight, -k_weight], dim=0
    )
    cat_kernels_4_i = torch.cat(
        [i_weight, r_weight, -k_weight, j_weight], dim=0
    )
    cat_kernels_4_j = torch.cat(
        [j_weight, k_weight, r_weight, -i_weight], dim=0
    )
    cat_kernels_4_k = torch.cat(
        [k_weight, -j_weight, i_weight, r_weight], dim=0
    )
    cat_kernels_4_quaternion = torch.cat(
        [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k],
        dim=1,
    )
    use_bias = bias is not None and bias.requires_grad

    # If the input is already [batch*time, N] a fused (add)mm is used.
    if input.dim() == 2:
        if use_bias:
            return torch.addmm(bias, input, cat_kernels_4_quaternion)
        return torch.mm(input, cat_kernels_4_quaternion)

    output = torch.matmul(input, cat_kernels_4_quaternion)
    return output + bias if use_bias else output
[docs]
def quaternion_linear_rotation_op(
    input, r_weight, i_weight, j_weight, k_weight, bias, scale, zero_kernel
):
    """
    Applies a quaternion rotation transformation to the incoming data:
    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    Works for unitary and non-unitary weights (they will be normalized).
    The initial size of the input must be a multiple of 4 with the real part
    equal to zero. Rotations only affect the vector part of a quaternion.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter or None
        Added to the output only when it is not None and requires grad.
    scale : torch.Parameter or None
        In the context of a spinor neural network, multiple rotations of
        the input vector x are performed and summed. Hence, the norm of
        the output vector always increases with the number of layers, making
        the neural network unstable with deep configurations. The scale
        parameters are learnable parameters that act like gates by multiplying
        the output vector with a small trainable parameter. Only applied when
        it is not None and requires grad.
    zero_kernel : torch.Parameter
        The zero kernel is simply a tensor of zeros with require grad = False.
        Its shape is equivalent to a quaternion component shape. In fact,
        it is only needed to make the dimensions match when using the rotation
        matrix : https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation

    Returns
    -------
    The linearly rotated quaternions
    """
    # Only unit quaternions are valid rotations, so normalise the weights.
    # NOTE: here the epsilon is added after the square root (the conv variant
    # adds it inside); kept unchanged to preserve existing outputs.
    norm = (
        torch.sqrt(
            r_weight * r_weight
            + i_weight * i_weight
            + j_weight * j_weight
            + k_weight * k_weight
        )
        + 0.0001
    )
    r_n_weight = r_weight / norm
    i_n_weight = i_weight / norm
    j_n_weight = j_weight / norm
    k_n_weight = k_weight / norm

    # Entries of the 3x3 rotation matrix of a unit quaternion, see
    # https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    norm_factor = 2.0
    square_i = norm_factor * (i_n_weight * i_n_weight)
    square_j = norm_factor * (j_n_weight * j_n_weight)
    square_k = norm_factor * (k_n_weight * k_n_weight)
    ri = norm_factor * r_n_weight * i_n_weight
    rj = norm_factor * r_n_weight * j_n_weight
    rk = norm_factor * r_n_weight * k_n_weight
    ij = norm_factor * i_n_weight * j_n_weight
    ik = norm_factor * i_n_weight * k_n_weight
    jk = norm_factor * j_n_weight * k_n_weight

    rotation_rows = [
        [1.0 - (square_j + square_k), ij - rk, ik + rj],
        [ij + rk, 1.0 - (square_i + square_k), jk - ri],
        [ik - rj, jk + ri, 1.0 - (square_i + square_j)],
    ]
    if scale is not None and scale.requires_grad:
        rotation_rows = [
            [scale * entry for entry in row] for row in rotation_rows
        ]

    # The real component gets a zero row and column: rotations only act on
    # the vector part of the quaternion.
    rot_kernels = [
        torch.cat([zero_kernel, *row], dim=1) for row in rotation_rows
    ]
    zero_kernel2 = torch.cat([zero_kernel] * 4, dim=1)
    global_rot_kernel = torch.cat([zero_kernel2, *rot_kernels], dim=0)

    use_bias = bias is not None and bias.requires_grad
    if input.dim() == 2:
        if use_bias:
            return torch.addmm(bias, input, global_rot_kernel)
        return torch.mm(input, global_rot_kernel)

    output = torch.matmul(input, global_rot_kernel)
    return output + bias if use_bias else output
[docs]
def quaternion_conv_rotation_op(
    input,
    r_weight,
    i_weight,
    j_weight,
    k_weight,
    bias,
    scale,
    zero_kernel,
    stride: int,
    padding: int,
    groups: int,
    dilation: int,
    conv1d: bool,
):
    """
    Applies a quaternion rotation transformation to the incoming data:
    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    Works for unitary and non-unitary weights (they will be normalized).
    The initial size of the input must be a multiple of 4 with the real part
    equal to zero. Rotations only affect the vector part of a quaternion.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter or None
        Passed straight to the convolution (None means no bias).
    scale : torch.Parameter or None
        In the context of a spinor neural network, multiple rotations of
        the input vector x are performed and summed. Hence, the norm of
        the output vector always increases with the number of layers, making
        the neural network unstable with deep configurations. The scale
        parameters are learnable parameters that act like gates by multiplying
        the output vector with a small trainable parameter. Only applied when
        it is not None and requires grad.
    zero_kernel : torch.Parameter
        The zero kernel is simply a tensor of zeros with require grad = False.
        Its shape is equivalent to a quaternion component shape. In fact,
        it is only needed to make the dimensions match when using the rotation
        matrix : https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    stride : int
        Stride factor of the convolutional filters.
    padding : int
        Amount of padding. See torch.nn documentation for more information.
    groups : int
        This option specifies the convolutional groups. See torch.nn
        documentation for more information.
    dilation : int
        Dilation factor of the convolutional filters.
    conv1d : bool
        If true, a 1D convolution operation will be applied. Otherwise, a 2D
        convolution is called.

    Returns
    -------
    The rotated quaternion inputs
    """
    # Normalise to unit quaternions (epsilon inside the sqrt keeps the
    # gradient finite for all-zero weights).
    norm = torch.sqrt(
        r_weight * r_weight
        + i_weight * i_weight
        + j_weight * j_weight
        + k_weight * k_weight
        + 0.0001
    )
    r_n_weight = r_weight / norm
    i_n_weight = i_weight / norm
    j_n_weight = j_weight / norm
    k_n_weight = k_weight / norm

    # Entries of the 3x3 rotation matrix of a unit quaternion.
    norm_factor = 2.0
    square_i = norm_factor * (i_n_weight * i_n_weight)
    square_j = norm_factor * (j_n_weight * j_n_weight)
    square_k = norm_factor * (k_n_weight * k_n_weight)
    ri = norm_factor * r_n_weight * i_n_weight
    rj = norm_factor * r_n_weight * j_n_weight
    rk = norm_factor * r_n_weight * k_n_weight
    ij = norm_factor * i_n_weight * j_n_weight
    ik = norm_factor * i_n_weight * k_n_weight
    jk = norm_factor * j_n_weight * k_n_weight

    rotation_rows = [
        [1.0 - (square_j + square_k), ij - rk, ik + rj],
        [ij + rk, 1.0 - (square_i + square_k), jk - ri],
        [ik - rj, jk + ri, 1.0 - (square_i + square_j)],
    ]
    if scale is not None and scale.requires_grad:
        rotation_rows = [
            [scale * entry for entry in row] for row in rotation_rows
        ]

    # Kernels are (out, in, ...): rows concatenate along the input-channel
    # dim, the real output channel gets an all-zero block.
    rot_kernels = [
        torch.cat([zero_kernel, *row], dim=1) for row in rotation_rows
    ]
    zero_kernel2 = torch.cat([zero_kernel] * 4, dim=1)
    global_rot_kernel = torch.cat([zero_kernel2, *rot_kernels], dim=0)

    conv = F.conv1d if conv1d else F.conv2d
    return conv(
        input=input,
        weight=global_rot_kernel,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
[docs]
def quaternion_conv_op(
    input,
    r_weight,
    i_weight,
    j_weight,
    k_weight,
    bias,
    stride: int,
    padding: int,
    groups: int,
    dilation: int,
    conv1d: bool,
):
    """
    Convolve quaternion-valued inputs with quaternion-valued filters.

    The forward phase of a QCNN is W * Inputs, where * is the Hamilton
    product. The four weight components are laid out as a real-valued
    kernel whose blocks reproduce that product. A standard real convolution
    with this kernel is then equivalent to the quaternion convolution.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter
        Bias handed unchanged to the convolution.
    stride : int
        Stride factor of the convolutional filters.
    padding : int
        Amount of padding. See torch.nn documentation for more information.
    groups : int
        This option specifies the convolutional groups. See torch.nn
        documentation for more information.
    dilation : int
        Dilation factor of the convolutional filters.
    conv1d : bool
        When True a 1D convolution is used, otherwise a 2D one.

    Returns
    -------
    The convolved quaternion inputs
    """
    # One entry per output-channel block (r, i, j, k). Each tuple lists the
    # signed component used for the input-channel blocks (r, i, j, k).
    hamilton_blocks = (
        (r_weight, -i_weight, -j_weight, -k_weight),
        (i_weight, r_weight, -k_weight, j_weight),
        (j_weight, k_weight, r_weight, -i_weight),
        (k_weight, -j_weight, i_weight, r_weight),
    )
    real_kernel = torch.cat(
        [torch.cat(block, dim=1) for block in hamilton_blocks], dim=0
    )

    convolve = F.conv1d if conv1d else F.conv2d
    return convolve(
        input=input,
        weight=real_kernel,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
[docs]
def quaternion_init(
    in_features, out_features, kernel_size=None, criterion="glorot"
):
    """Returns a matrix of quaternion numbers initialized with the method
    described in "Quaternion Recurrent Neural Network " - Parcollet T.

    Arguments
    ---------
    in_features : int
        Number of real values of the input layer (quaternion // 4).
    out_features : int
        Number of real values of the output layer (quaternion // 4).
    kernel_size : int or tuple of int
        Kernel_size for convolutional layers (ex: (3,3)). None for linear
        layers.
    criterion : str
        (glorot, he). Any value other than "glorot" uses the he scaling.

    Returns
    -------
    Tuple (weight_r, weight_i, weight_j, weight_k) of initialized quaternion
    components, shaped (in_features, out_features) when kernel_size is None,
    else (out_features, in_features, *kernel_size). The modulus is drawn with
    numpy, so the returned tensors are float64.
    """
    # numpy and scipy are used below, so seed numpy from torch's seed for
    # reproducibility. The modulo (2**31 - 1) is needed because, if the user
    # has not set a seed (e.g. in the YAML file), torch generates one that is
    # too big for numpy.
    np.random.seed(seed=torch.initial_seed() % (2**31 - 1))

    if kernel_size is not None:
        receptive_field = np.prod(kernel_size)
        fan_in = in_features * receptive_field
        fan_out = out_features * receptive_field
    else:
        fan_in = in_features
        fan_out = out_features

    if criterion == "glorot":
        s = 1.0 / np.sqrt(2 * (fan_in + fan_out))
    else:
        s = 1.0 / np.sqrt(2 * fan_in)

    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif isinstance(kernel_size, (int, np.integer)):
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features, *kernel_size)

    # Generating randoms and purely imaginary quaternions:
    modulus = torch.from_numpy(chi.rvs(4, loc=0, scale=s, size=kernel_shape))
    number_of_weights = int(np.prod(kernel_shape))
    v_i = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_j = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_k = torch.FloatTensor(number_of_weights).uniform_(-1, 1)

    # Make each (v_i, v_j, v_k) a (near-)unit purely imaginary quaternion.
    # Vectorised: one tensor op instead of a Python loop over every weight.
    norm = torch.sqrt(v_i**2 + v_j**2 + v_k**2) + 0.0001
    v_i = (v_i / norm).reshape(kernel_shape)
    v_j = (v_j / norm).reshape(kernel_shape)
    v_k = (v_k / norm).reshape(kernel_shape)

    # torch.rand's values are overwritten by uniform_; the draw is kept so
    # the random stream (and thus seeded initialisations) stays unchanged.
    phase = torch.rand(kernel_shape).uniform_(-math.pi, math.pi)
    sin_phase = torch.sin(phase)

    weight_r = modulus * torch.cos(phase)
    weight_i = modulus * v_i * sin_phase
    weight_j = modulus * v_j * sin_phase
    weight_k = modulus * v_k * sin_phase
    return (weight_r, weight_i, weight_j, weight_k)
[docs]
def unitary_init(in_features, out_features, kernel_size=None, criterion="he"):
    """Returns a matrix of unitary quaternion numbers.

    Arguments
    ---------
    in_features : int
        Number of real values of the input layer (quaternion // 4).
    out_features : int
        Number of real values of the output layer (quaternion // 4).
    kernel_size : int or tuple of int
        Kernel_size for convolutional layers (ex: (3,3)). None for linear
        layers.
    criterion : str
        (glorot, he). Accepted for signature compatibility with
        quaternion_init (both are used as init_func). It is not used here.

    Returns
    -------
    Tuple (v_r, v_i, v_j, v_k) of float32 tensors forming (near-)unit
    quaternions, shaped (in_features, out_features) when kernel_size is None,
    else (out_features, in_features, *kernel_size).
    """
    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif isinstance(kernel_size, (int, np.integer)):
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features, *kernel_size)

    number_of_weights = int(np.prod(kernel_shape))
    v_r = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_i = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_j = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_k = torch.FloatTensor(number_of_weights).uniform_(-1, 1)

    # Normalise every quaternion to unit length in a single vectorised pass
    # (previously a Python loop over each weight). The epsilon guards
    # against division by zero.
    norm = torch.sqrt(v_r**2 + v_i**2 + v_j**2 + v_k**2) + 0.0001
    v_r = (v_r / norm).reshape(kernel_shape)
    v_i = (v_i / norm).reshape(kernel_shape)
    v_j = (v_j / norm).reshape(kernel_shape)
    v_k = (v_k / norm).reshape(kernel_shape)
    return (v_r, v_i, v_j, v_k)
[docs]
def affect_init(
    r_weight, i_weight, j_weight, k_weight, init_func, init_criterion
):
    """Overwrite the four quaternion weight components with fresh values.

    ``init_func`` is called exactly once, positionally, as
    ``init_func(r_weight.size(0), r_weight.size(1), None, init_criterion)``.
    It must return four tensors. Each one is converted with ``type_as`` to
    match the corresponding parameter and assigned to that parameter's
    ``.data``.

    Arguments
    ---------
    r_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    i_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    j_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    k_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    init_func : function
        (unitary_init, quaternion_init)
    init_criterion : str
        (glorot, he)
    """
    n_rows, n_cols = r_weight.size(0), r_weight.size(1)
    # Unpack first so a wrong number of returned tensors still raises.
    fresh_r, fresh_i, fresh_j, fresh_k = init_func(
        n_rows, n_cols, None, init_criterion
    )
    targets = (r_weight, i_weight, j_weight, k_weight)
    sources = (fresh_r, fresh_i, fresh_j, fresh_k)
    for target, source in zip(targets, sources):
        target.data = source.type_as(target.data)
[docs]
def affect_conv_init(
    r_weight,
    i_weight,
    j_weight,
    k_weight,
    kernel_size,
    init_func,
    init_criterion,
):
    """Overwrite the four quaternion convolution kernels with fresh values.

    Convolution weights are laid out as (out_channels, in_channels, ...).
    ``init_func`` is therefore called once as
    ``init_func(in_channels, out_channels, kernel_size=kernel_size,
    criterion=init_criterion)``. Its four returned tensors are converted with
    ``type_as`` and assigned to the ``.data`` of the matching parameters.

    Arguments
    ---------
    r_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    i_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    j_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    k_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    kernel_size : int
        Kernel size.
    init_func : function
        (unitary_init, quaternion_init)
    init_criterion : str
        (glorot, he)
    """
    out_channels, in_channels = r_weight.size(0), r_weight.size(1)
    fresh_r, fresh_i, fresh_j, fresh_k = init_func(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        criterion=init_criterion,
    )
    for target, source in zip(
        (r_weight, i_weight, j_weight, k_weight),
        (fresh_r, fresh_i, fresh_j, fresh_k),
    ):
        target.data = source.type_as(target.data)
[docs]
def renorm_quaternion_weights_inplace(
    r_weight, i_weight, j_weight, k_weight, max_norm
):
    """Renorms the magnitude of the quaternion-valued weights in place.

    The per-quaternion magnitudes are renormed with ``torch.renorm`` (p=2,
    dim=0, maxnorm=max_norm). All four components are then rescaled by the
    same per-element factor, which leaves each quaternion's direction intact.

    Arguments
    ---------
    r_weight : torch.Parameter
    i_weight : torch.Parameter
    j_weight : torch.Parameter
    k_weight : torch.Parameter
    max_norm : float
        The maximum norm of the magnitude of the quaternion weights
    """
    weight_magnitude = torch.sqrt(
        r_weight.data**2
        + i_weight.data**2
        + j_weight.data**2
        + k_weight.data**2
    )
    renormed_weight_magnitude = torch.renorm(
        weight_magnitude, p=2, dim=0, maxnorm=max_norm
    )
    # An all-zero quaternion would give 0/0 = NaN and poison the weights.
    # Its components are zero anyway, so leave it untouched (factor 1).
    factor = torch.where(
        weight_magnitude > 0,
        renormed_weight_magnitude / weight_magnitude,
        torch.ones_like(weight_magnitude),
    )
    for weight in (r_weight, i_weight, j_weight, k_weight):
        weight.data *= factor