# Source code for speechbrain.nnet.quaternion_networks.q_ops

"""This library implements different operations needed by quaternion-
valued architectures.
This work is inspired by:
"Quaternion neural networks" - Parcollet T.
"Quaternion recurrent neural networks" - Parcollet T. et al.
"Quaternion convolutional neural networks for end-to-end automatic speech
recognition" - Parcollet T. et al.
"Deep quaternion networks" - Gaudet Chase J. et al.

Authors
 * Titouan Parcollet 2020
"""

import math

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import chi
from torch.autograd import Variable

class QuaternionLinearCustomBackward(torch.autograd.Function):
    """This class redefines the backpropagation of a quaternion linear layer
    (not a spinor layer). By doing so, we can save up to 4x memory, but it
    is also 2x slower than 'quaternion_linear_op'. It should be used within
    speechbrain.nnet.quaternion_networks.linear.QuaternionLinear.
    """

    @staticmethod
    def forward(ctx, input, r_weight, i_weight, j_weight, k_weight, bias):
        """Applies a quaternion linear transformation to the incoming data.

        It is important to notice that the forward phase of a QNN is defined
        as W * Inputs (with * equal to the Hamilton product). The constructed
        cat_kernels_4_quaternion is a modified version of the quaternion
        representation so that torch.mm(Input, W) is equivalent to
        W * Inputs.

        Arguments
        ---------
        ctx : PyTorch context object
            Used to save the context necessary to perform a backwards pass.
        input : torch.Tensor
            Quaternion input tensor to be transformed.
            Shape: [batch*time, X].
        r_weight : torch.Parameter
            Real part of the quaternion weight matrix of this layer.
        i_weight : torch.Parameter
            First imaginary part of the quaternion weight matrix of this layer.
        j_weight : torch.Parameter
            Second imaginary part of the quaternion weight matrix of this layer.
        k_weight : torch.Parameter
            Third imaginary part of the quaternion weight matrix of this layer.
        bias : torch.Parameter
            Bias added to the output; only applied when it requires grad.

        Returns
        -------
        The linearly transformed quaternions
        """
        ctx.save_for_backward(
            input, r_weight, i_weight, j_weight, k_weight, bias
        )

        # Real-valued expansion of the Hamilton product: column block b
        # produces output component b (r, i, j, k).
        cat_kernels_4_r = torch.cat(
            [r_weight, -i_weight, -j_weight, -k_weight], dim=0
        )
        cat_kernels_4_i = torch.cat(
            [i_weight, r_weight, -k_weight, j_weight], dim=0
        )
        cat_kernels_4_j = torch.cat(
            [j_weight, k_weight, r_weight, -i_weight], dim=0
        )
        cat_kernels_4_k = torch.cat(
            [k_weight, -j_weight, i_weight, r_weight], dim=0
        )
        cat_kernels_4_quaternion = torch.cat(
            [
                cat_kernels_4_r,
                cat_kernels_4_i,
                cat_kernels_4_j,
                cat_kernels_4_k,
            ],
            dim=1,
        )

        if bias.requires_grad:
            return torch.addmm(bias, input, cat_kernels_4_quaternion)
        return torch.mm(input, cat_kernels_4_quaternion)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward phase of the forward call defined above.

        This implementation follows the quaternion backpropagation of a
        quaternion layer that can be found in "Quaternion neural networks" -
        Parcollet T. Page 48.

        Arguments
        ---------
        ctx : Pytorch context object
            Contains saved weights and bias
        grad_output : torch.Tensor
            Gradient of the loss with respect to the output of ``forward``.

        Returns
        -------
        Gradients with respect to (input, r_weight, i_weight, j_weight,
        k_weight, bias); an entry is None when that gradient is not needed.
        """
        input, r_weight, i_weight, j_weight, k_weight, bias = ctx.saved_tensors
        grad_input = grad_weight_r = grad_weight_i = grad_weight_j = (
            grad_weight_k
        ) = grad_bias = None

        # Transpose of the expanded weight matrix used in forward:
        # dL/dinput = dL/doutput @ W^T.
        input_r = torch.cat([r_weight, -i_weight, -j_weight, -k_weight], dim=0)
        input_i = torch.cat([i_weight, r_weight, -k_weight, j_weight], dim=0)
        input_j = torch.cat([j_weight, k_weight, r_weight, -i_weight], dim=0)
        input_k = torch.cat([k_weight, -j_weight, i_weight, r_weight], dim=0)
        cat_kernels_4_quaternion_T = Variable(
            torch.cat([input_r, input_i, input_j, input_k], dim=1).permute(
                1, 0
            ),
            requires_grad=False,
        )

        # Quaternion expansion of the saved input.
        nb_hidden = input.size()[-1]
        r = input.narrow(1, 0, nb_hidden // 4)
        i = input.narrow(1, nb_hidden // 4, nb_hidden // 4)
        j = input.narrow(1, nb_hidden // 2, nb_hidden // 4)
        k = input.narrow(1, nb_hidden - nb_hidden // 4, nb_hidden // 4)
        input_r = torch.cat([r, -i, -j, -k], dim=0)
        input_i = torch.cat([i, r, -k, j], dim=0)
        input_j = torch.cat([j, k, r, -i], dim=0)
        input_k = torch.cat([k, -j, i, r], dim=0)
        input_mat = Variable(
            torch.cat([input_r, input_i, input_j, input_k], dim=1),
            requires_grad=False,
        )

        # Quaternion expansion of the incoming gradient.
        nb_hidden = grad_output.size()[-1]
        r = grad_output.narrow(1, 0, nb_hidden // 4)
        i = grad_output.narrow(1, nb_hidden // 4, nb_hidden // 4)
        j = grad_output.narrow(1, nb_hidden // 2, nb_hidden // 4)
        k = grad_output.narrow(1, nb_hidden - nb_hidden // 4, nb_hidden // 4)
        input_r = torch.cat([r, i, j, k], dim=1)
        input_i = torch.cat([-i, r, k, -j], dim=1)
        input_j = torch.cat([-j, -k, r, i], dim=1)
        input_k = torch.cat([-k, j, -i, r], dim=1)
        grad_mat = torch.cat([input_r, input_i, input_j, input_k], dim=0)

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(cat_kernels_4_quaternion_T)
        # NOTE: all four weight gradients are gated on r_weight's flag.
        if ctx.needs_input_grad[1]:
            grad_weight = grad_mat.permute(1, 0).mm(input_mat).permute(1, 0)

            # The first row block of grad_weight holds the r/i/j/k
            # gradients side by side.
            unit_size_x = r_weight.size(0)
            unit_size_y = r_weight.size(1)
            first_rows = grad_weight.narrow(0, 0, unit_size_x)
            grad_weight_r = first_rows.narrow(1, 0, unit_size_y)
            grad_weight_i = first_rows.narrow(1, unit_size_y, unit_size_y)
            grad_weight_j = first_rows.narrow(1, unit_size_y * 2, unit_size_y)
            grad_weight_k = first_rows.narrow(1, unit_size_y * 3, unit_size_y)
        if ctx.needs_input_grad[5]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return (
            grad_input,
            grad_weight_r,
            grad_weight_i,
            grad_weight_j,
            grad_weight_k,
            grad_bias,
        )
def quaternion_linear_op(input, r_weight, i_weight, j_weight, k_weight, bias):
    """Applies a quaternion linear transformation to the incoming data.

    It is important to notice that the forward phase of a QNN is defined
    as W * Inputs (with * equal to the Hamilton product). The constructed
    cat_kernels_4_quaternion is a modified version of the quaternion
    representation so that torch.mm(Input, W) is equivalent to W * Inputs.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter
        Bias added to the output; only applied when it requires grad.

    Returns
    -------
    The linearly transformed quaternions
    """
    # Real-valued expansion of the Hamilton product: column block b
    # produces output component b (r, i, j, k).
    cat_kernels_4_r = torch.cat(
        [r_weight, -i_weight, -j_weight, -k_weight], dim=0
    )
    cat_kernels_4_i = torch.cat(
        [i_weight, r_weight, -k_weight, j_weight], dim=0
    )
    cat_kernels_4_j = torch.cat(
        [j_weight, k_weight, r_weight, -i_weight], dim=0
    )
    cat_kernels_4_k = torch.cat(
        [k_weight, -j_weight, i_weight, r_weight], dim=0
    )
    cat_kernels_4_quaternion = torch.cat(
        [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k],
        dim=1,
    )

    # If the input is already [batch*time, N]
    if input.dim() == 2:
        if bias.requires_grad:
            return torch.addmm(bias, input, cat_kernels_4_quaternion)
        return torch.mm(input, cat_kernels_4_quaternion)

    # Higher-rank input: matmul broadcasts over the leading dimensions.
    output = torch.matmul(input, cat_kernels_4_quaternion)
    if bias.requires_grad:
        return output + bias
    return output
def quaternion_linear_rotation_op(
    input, r_weight, i_weight, j_weight, k_weight, bias, scale, zero_kernel
):
    """Applies a quaternion rotation transformation to the incoming data.

    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    Works for unitary and non-unitary weights (they will be normalized).
    The initial size of the input must be a multiple of 4 with the real part
    equal to zero. Rotations only affect the vector part of a quaternion.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter
        Bias added to the output; only applied when it requires grad.
    scale : torch.Parameter
        In the context of a spinor neural network, multiple rotations of
        the input vector x are performed and summed. Hence, the norm of
        the output vector always increases with the number of layers, making
        the neural network unstable with deep configurations. The scale
        parameters are learnable parameters that act like gates by
        multiplying the output vector with a small trainable parameter.
        It is only applied when it requires grad.
    zero_kernel : torch.Parameter
        The zero kernel is simply a tensor of zeros with require grad = False.
        Its shape is equivalent to a quaternion component shape. In fact,
        it is only needed to make the dimensions match when using the
        rotation matrix.

    Returns
    -------
    The linearly rotated quaternions
    """
    # First we normalise the quaternion weights. Only unit quaternions are
    # valid rotations.
    square_r = r_weight * r_weight
    square_i = i_weight * i_weight
    square_j = j_weight * j_weight
    square_k = k_weight * k_weight

    norm = torch.sqrt(square_r + square_i + square_j + square_k) + 0.0001

    r_n_weight = r_weight / norm
    i_n_weight = i_weight / norm
    j_n_weight = j_weight / norm
    k_n_weight = k_weight / norm

    # See https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for
    # the rest of the equations.
    norm_factor = 2.0

    square_i = norm_factor * (i_n_weight * i_n_weight)
    square_j = norm_factor * (j_n_weight * j_n_weight)
    square_k = norm_factor * (k_n_weight * k_n_weight)

    ri = norm_factor * r_n_weight * i_n_weight
    rj = norm_factor * r_n_weight * j_n_weight
    rk = norm_factor * r_n_weight * k_n_weight

    ij = norm_factor * i_n_weight * j_n_weight
    ik = norm_factor * i_n_weight * k_n_weight

    jk = norm_factor * j_n_weight * k_n_weight

    # A trainable scale gates every rotation entry (but not the zero column
    # of the real part); a frozen scale is ignored.
    if scale.requires_grad:

        def gate(entry):
            return scale * entry

    else:

        def gate(entry):
            return entry

    rot_kernel_1 = torch.cat(
        [
            zero_kernel,
            gate(1.0 - (square_j + square_k)),
            gate(ij - rk),
            gate(ik + rj),
        ],
        dim=1,
    )
    rot_kernel_2 = torch.cat(
        [
            zero_kernel,
            gate(ij + rk),
            gate(1.0 - (square_i + square_k)),
            gate(jk - ri),
        ],
        dim=1,
    )
    rot_kernel_3 = torch.cat(
        [
            zero_kernel,
            gate(ik - rj),
            gate(jk + ri),
            gate(1.0 - (square_i + square_j)),
        ],
        dim=1,
    )

    # The real input component is ignored (all-zero first row block) and the
    # real output component is zero (zero_kernel in every first column).
    zero_kernel2 = torch.cat(
        [zero_kernel, zero_kernel, zero_kernel, zero_kernel], dim=1
    )
    global_rot_kernel = torch.cat(
        [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0
    )

    if input.dim() == 2:
        if bias.requires_grad:
            return torch.addmm(bias, input, global_rot_kernel)
        return torch.mm(input, global_rot_kernel)

    output = torch.matmul(input, global_rot_kernel)
    if bias.requires_grad:
        return output + bias
    return output
def quaternion_conv_rotation_op(
    input,
    r_weight,
    i_weight,
    j_weight,
    k_weight,
    bias,
    scale,
    zero_kernel,
    stride: int,
    padding: int,
    groups: int,
    dilation: int,
    conv1d: bool,
):
    """Applies a quaternion rotation transformation to the incoming data.

    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    Works for unitary and non-unitary weights (they will be normalized).
    The initial size of the input must be a multiple of 4 with the real part
    equal to zero. Rotations only affect the vector part of a quaternion.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter
        Bias forwarded to the convolution (may be None).
    scale : torch.Parameter
        In the context of a spinor neural network, multiple rotations of
        the input vector x are performed and summed. Hence, the norm of
        the output vector always increases with the number of layers, making
        the neural network unstable with deep configurations. The scale
        parameters are learnable parameters that act like gates by
        multiplying the output vector with a small trainable parameter.
        It is only applied when it requires grad.
    zero_kernel : torch.Parameter
        The zero kernel is simply a tensor of zeros with require grad = False.
        Its shape is equivalent to a quaternion component shape. In fact,
        it is only needed to make the dimensions match when using the
        rotation matrix.
    stride : int
        Stride factor of the convolutional filters.
    padding : int
        Amount of padding. See torch.nn documentation for more information.
    groups : int
        This option specifies the convolutional groups. See torch.nn
        documentation for more information.
    dilation : int
        Dilation factor of the convolutional filters.
    conv1d : bool
        If true, a 1D convolution operation will be applied. Otherwise, a 2D
        convolution is called.

    Returns
    -------
    The rotated quaternion inputs
    """
    square_r = r_weight * r_weight
    square_i = i_weight * i_weight
    square_j = j_weight * j_weight
    square_k = k_weight * k_weight

    # Small epsilon (inside the sqrt here) keeps the division finite.
    norm = torch.sqrt(square_r + square_i + square_j + square_k + 0.0001)

    r_n_weight = r_weight / norm
    i_n_weight = i_weight / norm
    j_n_weight = j_weight / norm
    k_n_weight = k_weight / norm

    # See https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for
    # the rest of the equations.
    norm_factor = 2.0

    square_i = norm_factor * (i_n_weight * i_n_weight)
    square_j = norm_factor * (j_n_weight * j_n_weight)
    square_k = norm_factor * (k_n_weight * k_n_weight)

    ri = norm_factor * r_n_weight * i_n_weight
    rj = norm_factor * r_n_weight * j_n_weight
    rk = norm_factor * r_n_weight * k_n_weight

    ij = norm_factor * i_n_weight * j_n_weight
    ik = norm_factor * i_n_weight * k_n_weight

    jk = norm_factor * j_n_weight * k_n_weight

    # A trainable scale gates every rotation entry (but not the zero
    # real-part kernel); a frozen scale is ignored.
    if scale.requires_grad:

        def gate(entry):
            return scale * entry

    else:

        def gate(entry):
            return entry

    # Conv weights are laid out (out_channels, in_channels, ...): dim=1
    # concatenates input components, dim=0 output components.
    rot_kernel_1 = torch.cat(
        [
            zero_kernel,
            gate(1.0 - (square_j + square_k)),
            gate(ij - rk),
            gate(ik + rj),
        ],
        dim=1,
    )
    rot_kernel_2 = torch.cat(
        [
            zero_kernel,
            gate(ij + rk),
            gate(1.0 - (square_i + square_k)),
            gate(jk - ri),
        ],
        dim=1,
    )
    rot_kernel_3 = torch.cat(
        [
            zero_kernel,
            gate(ik - rj),
            gate(jk + ri),
            gate(1.0 - (square_i + square_j)),
        ],
        dim=1,
    )

    zero_kernel2 = torch.cat(
        [zero_kernel, zero_kernel, zero_kernel, zero_kernel], dim=1
    )
    global_rot_kernel = torch.cat(
        [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0
    )

    conv = F.conv1d if conv1d else F.conv2d
    return conv(
        input=input,
        weight=global_rot_kernel,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
def quaternion_conv_op(
    input,
    r_weight,
    i_weight,
    j_weight,
    k_weight,
    bias,
    stride: int,
    padding: int,
    groups: int,
    dilation: int,
    conv1d: bool,
):
    """Applies a quaternion convolution transformation to the incoming data.

    It is important to notice that the forward phase of a QCNN is defined
    as W * Inputs (with * equal to the Hamilton product). The constructed
    cat_kernels_4_quaternion is a modified version of the quaternion
    representation so that convolving the input with it is equivalent to
    W * Inputs.

    Arguments
    ---------
    input : torch.Tensor
        Quaternion input tensor to be transformed.
    r_weight : torch.Parameter
        Real part of the quaternion weight matrix of this layer.
    i_weight : torch.Parameter
        First imaginary part of the quaternion weight matrix of this layer.
    j_weight : torch.Parameter
        Second imaginary part of the quaternion weight matrix of this layer.
    k_weight : torch.Parameter
        Third imaginary part of the quaternion weight matrix of this layer.
    bias : torch.Parameter
        Bias forwarded to the convolution (may be None).
    stride : int
        Stride factor of the convolutional filters.
    padding : int
        Amount of padding. See torch.nn documentation for more information.
    groups : int
        This option specifies the convolutional groups. See torch.nn
        documentation for more information.
    dilation : int
        Dilation factor of the convolutional filters.
    conv1d : bool
        If true, a 1D convolution operation will be applied. Otherwise, a 2D
        convolution is called.

    Returns
    -------
    The convolved quaternion inputs
    """
    # Conv weights are laid out (out_channels, in_channels, ...): each
    # output component (dim=0 block) mixes the four input components
    # (dim=1 blocks) following the Hamilton product.
    cat_kernels_4_r = torch.cat(
        [r_weight, -i_weight, -j_weight, -k_weight], dim=1
    )
    cat_kernels_4_i = torch.cat(
        [i_weight, r_weight, -k_weight, j_weight], dim=1
    )
    cat_kernels_4_j = torch.cat(
        [j_weight, k_weight, r_weight, -i_weight], dim=1
    )
    cat_kernels_4_k = torch.cat(
        [k_weight, -j_weight, i_weight, r_weight], dim=1
    )
    cat_kernels_4_quaternion = torch.cat(
        [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k],
        dim=0,
    )

    conv = F.conv1d if conv1d else F.conv2d
    return conv(
        input=input,
        weight=cat_kernels_4_quaternion,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
def quaternion_init(
    in_features, out_features, kernel_size=None, criterion="glorot"
):
    """Returns a matrix of quaternion numbers initialized with the method
    described in "Quaternion Recurrent Neural Network " - Parcollet T.

    Arguments
    ---------
    in_features : int
        Number of real values of the input layer (quaternion // 4).
    out_features : int
        Number of real values of the output layer (quaternion // 4).
    kernel_size : int
        Kernel_size for convolutional layers (ex: (3,3)).
    criterion : str
        (glorot, he). Any value other than "glorot" uses the He scale.

    Returns
    -------
    Tuple (weight_r, weight_i, weight_j, weight_k). Each has shape
    (in_features, out_features) when kernel_size is None, otherwise
    (out_features, in_features, *kernel_size).
    """
    # We set the numpy seed equal to the torch seed for reproducibility
    # Indeed we use numpy and scipy here. We need % (2**31-1) or, if the
    # seed hasn't been set by the user in the YAML file, torch will generate
    # a double that would be too big for numpy.
    np.random.seed(seed=torch.initial_seed() % (2**31 - 1))

    if kernel_size is not None:
        receptive_field = np.prod(kernel_size)
        fan_in = in_features * receptive_field
        fan_out = out_features * receptive_field
    else:
        fan_in = in_features
        fan_out = out_features

    if criterion == "glorot":
        s = 1.0 / np.sqrt(2 * (fan_in + fan_out))
    else:
        s = 1.0 / np.sqrt(2 * fan_in)

    # Generating randoms and purely imaginary quaternions :
    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif isinstance(kernel_size, int):
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features, *kernel_size)

    modulus = torch.from_numpy(chi.rvs(4, loc=0, scale=s, size=kernel_shape))
    number_of_weights = int(np.prod(kernel_shape))
    v_i = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_j = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_k = torch.FloatTensor(number_of_weights).uniform_(-1, 1)

    # Purely imaginary quaternions unitary (vectorised over all weights;
    # the epsilon avoids division by zero).
    norm = torch.sqrt(v_i**2 + v_j**2 + v_k**2) + 0.0001
    v_i = (v_i / norm).reshape(kernel_shape)
    v_j = (v_j / norm).reshape(kernel_shape)
    v_k = (v_k / norm).reshape(kernel_shape)

    # torch.rand is kept (then overwritten) so the torch RNG stream, and thus
    # the generated weights for a given seed, stay unchanged.
    phase = torch.rand(kernel_shape).uniform_(-math.pi, math.pi)

    weight_r = modulus * torch.cos(phase)
    weight_i = modulus * v_i * torch.sin(phase)
    weight_j = modulus * v_j * torch.sin(phase)
    weight_k = modulus * v_k * torch.sin(phase)

    return (weight_r, weight_i, weight_j, weight_k)
def unitary_init(in_features, out_features, kernel_size=None, criterion="he"):
    """Returns a matrix of unitary quaternion numbers.

    Arguments
    ---------
    in_features : int
        Number of real values of the input layer (quaternion // 4).
    out_features : int
        Number of real values of the output layer (quaternion // 4).
    kernel_size : int
        Kernel_size for convolutional layers (ex: (3,3)).
    criterion : str
        (glorot, he). Accepted for interface compatibility; unused here.

    Returns
    -------
    Tuple (v_r, v_i, v_j, v_k) of unit quaternion components. Each has shape
    (in_features, out_features) when kernel_size is None, otherwise
    (out_features, in_features, *kernel_size).
    """
    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif isinstance(kernel_size, int):
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features, *kernel_size)

    number_of_weights = int(np.prod(kernel_shape))
    v_r = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_i = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_j = torch.FloatTensor(number_of_weights).uniform_(-1, 1)
    v_k = torch.FloatTensor(number_of_weights).uniform_(-1, 1)

    # Unitary quaternion (vectorised over all weights; the epsilon avoids
    # division by zero).
    norm = torch.sqrt(v_r**2 + v_i**2 + v_j**2 + v_k**2) + 0.0001

    v_r = (v_r / norm).reshape(kernel_shape)
    v_i = (v_i / norm).reshape(kernel_shape)
    v_j = (v_j / norm).reshape(kernel_shape)
    v_k = (v_k / norm).reshape(kernel_shape)

    return (v_r, v_i, v_j, v_k)
def affect_init(
    r_weight, i_weight, j_weight, k_weight, init_func, init_criterion
):
    """Applies the weight initialization function given to the parameters.

    The weights are overwritten in place through ``.data``; the freshly
    initialized values are cast to each parameter's existing type.

    Arguments
    ---------
    r_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    i_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    j_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    k_weight : torch.Parameters
        (nb_quaternion_in, nb_quaternion_out)
    init_func : function
        (unitary_init, quaternion_init). Called positionally as
        init_func(in, out, None, init_criterion).
    init_criterion : str
        (glorot, he)
    """
    r, i, j, k = init_func(
        r_weight.size(0), r_weight.size(1), None, init_criterion
    )

    r_weight.data = r.type_as(r_weight.data)
    i_weight.data = i.type_as(i_weight.data)
    j_weight.data = j.type_as(j_weight.data)
    k_weight.data = k.type_as(k_weight.data)
def affect_conv_init(
    r_weight,
    i_weight,
    j_weight,
    k_weight,
    kernel_size,
    init_func,
    init_criterion,
):
    """Applies the weight initialization function given to the parameters.
    This is specifically written for convolutional layers.

    The weights are overwritten in place through ``.data``; the freshly
    initialized values are cast to each parameter's existing type.

    Arguments
    ---------
    r_weight : torch.Parameters
        Conv weight; size(0) is read as out_channels and size(1) as
        in_channels.
    i_weight : torch.Parameters
        Same shape as r_weight.
    j_weight : torch.Parameters
        Same shape as r_weight.
    k_weight : torch.Parameters
        Same shape as r_weight.
    kernel_size : int
        Kernel size.
    init_func : function
        (unitary_init, quaternion_init)
    init_criterion : str
        (glorot, he)
    """
    in_channels = r_weight.size(1)
    out_channels = r_weight.size(0)

    r, i, j, k = init_func(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        criterion=init_criterion,
    )

    r_weight.data = r.type_as(r_weight.data)
    i_weight.data = i.type_as(i_weight.data)
    j_weight.data = j.type_as(j_weight.data)
    k_weight.data = k.type_as(k_weight.data)
def check_quaternion_input(input_shape):
    """Check the quaternion-valued shape for a linear layer.

    Arguments
    ---------
    input_shape : tuple
        Expected shape of the input.

    Raises
    ------
    Exception
        If the shape has an unsupported number of dimensions, or if its last
        dimension is not divisible by 4.
    """
    if len(input_shape) not in {1, 2, 3}:
        # Report the rank of the given shape. The old code called .dim() on
        # the builtin `input`, which crashed with an AttributeError instead.
        raise Exception(
            "Quaternion linear accepts only input of dimension 2 or 3."
            " input.dim = " + str(len(input_shape))
        )

    nb_hidden = input_shape[-1]

    if nb_hidden % 4 != 0:
        raise Exception(
            "Quaternion torch.Tensors must have dimensions divisible by 4."
            " input.size()[1] = " + str(nb_hidden)
        )
def renorm_quaternion_weights_inplace(
    r_weight, i_weight, j_weight, k_weight, max_norm
):
    """Renorms the magnitude of the quaternion-valued weights.

    The element-wise quaternion magnitude is renormed with torch.renorm
    (p=2 over each slice along dim 0). All four components are then
    rescaled in place through ``.data`` by the same factor.

    Arguments
    ---------
    r_weight : torch.Parameter
    i_weight : torch.Parameter
    j_weight : torch.Parameter
    k_weight : torch.Parameter
    max_norm : float
        The maximum norm of the magnitude of the quaternion weights
    """
    weight_magnitude = torch.sqrt(
        r_weight.data**2
        + i_weight.data**2
        + j_weight.data**2
        + k_weight.data**2
    )
    renormed_weight_magnitude = torch.renorm(
        weight_magnitude, p=2, dim=0, maxnorm=max_norm
    )
    # A zero quaternion would give 0/0 = NaN and poison the weights; leave
    # such entries untouched (factor 1, they stay zero).
    factor = torch.where(
        weight_magnitude > 0,
        renormed_weight_magnitude / weight_magnitude,
        torch.ones_like(weight_magnitude),
    )

    r_weight.data *= factor
    i_weight.data *= factor
    j_weight.data *= factor
    k_weight.data *= factor