Source code for speechbrain.nnet.quaternion_networks.q_linear

"""Library implementing quaternion-valued linear transformation.

Authors
 * Titouan Parcollet 2020
"""

import logging

import torch
from speechbrain.nnet.quaternion_networks.q_ops import (
    affect_init,
    unitary_init,
    quaternion_init,
    quaternion_linear_op,
    check_quaternion_input,
    quaternion_linear_rotation_op,
    QuaternionLinearCustomBackward,
)

logger = logging.getLogger(__name__)


class QLinear(torch.nn.Module):
    """This class implements a fully connected quaternion-valued linear layer:
    y = Wx + b. y, W, x and b are thus quaternion numbers. A quaternion number
    is written as: r + xi + yj + zk. A tensor of quaternion numbers
    x = [batch, 32] can be understood as [batch, 0:8] = R, [batch, 8:16] = Xi,
    [batch, 16:24] = Yj, and [batch, 24:32] = Zk. Thus the feature dimension is
    split in four (and must be divisible by 4).

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e., the dimensionality of the
        output). Please note that these are quaternion-valued neurons. If 256
        neurons are specified, the output dimension will be 1024.
    input_shape : tuple
        Expected size of the input.
    bias : bool
        If True, the additive bias b is adopted.
    init_criterion : str, optional
        (glorot, he). This parameter controls the initialization criterion of
        the weights. It is combined with weight_init to build the
        initialization method of the quaternion-valued weights
        (default "glorot").
    weight_init : str, optional
        (quaternion, unitary). This parameter defines the initialization
        procedure of the quaternion-valued weights. "quaternion" will generate
        quaternion-valued weights following the init_criterion and the
        quaternion polar form. "unitary" will normalize the weights to be unit
        quaternions (default "quaternion"). More details in:
        "Quaternion recurrent neural networks", Parcollet T.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing the memory consumption
        by a factor of 3 to 4 at the cost of being roughly 2x slower. This
        only works with spinor = False (default True).
    spinor : bool, optional
        When True, the layer will be turned into a spinor layer. More
        precisely, W*x will be turned into W*x*W^-1. The input x will be
        rotated by W as in a spinor neural network. However, x MUST be a
        quaternion with the real part equal to zero (0 + xi + yj + zk).
        Indeed, the rotation operation only acts on the vector part. Note
        that W will always be normalized before the rotation to preserve the
        quaternion algebra (default False). More details in:
        "Quaternion neural networks", Parcollet T.
    vector_scale : bool, optional
        The vector_scale is only used when spinor = True. In the context of a
        spinor neural network, multiple rotations of the input vector x are
        performed and summed. Hence, the norm of the output vector always
        increases with the number of layers, making the neural network
        unstable with deep configurations. The vector_scale parameters are
        learnable parameters that act like gates by multiplying the output
        vector with a small trainable parameter (default False).

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin = QLinear(n_neurons=100, input_shape=inputs.shape, weight_init='unitary')
    >>> output = lin(inputs)
    >>> output.shape
    torch.Size([10, 50, 400])
    """

    def __init__(
        self,
        n_neurons,
        input_shape,
        bias=True,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
        spinor=False,
        vector_scale=False,
    ):
        super().__init__()
        self.n_neurons = n_neurons
        self.bias = bias
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd
        self.spinor = spinor
        self.vector_scale = vector_scale

        # When initialising with speechbrain the input_shape is an integer!
        # We need to transform it into a list so it works with all the
        # quaternion ops.
        if isinstance(input_shape, int):
            input_shape = [1, input_shape]

        # Check the quaternion-valued form of the input
        check_quaternion_input(input_shape)

        # Computing the quaternion dimensionality of the input
        self.in_features = input_shape[-1] // 4
        self.out_features = self.n_neurons

        # Defining the weights
        self.r_weight = torch.nn.Parameter(
            torch.Tensor(self.in_features, self.out_features)
        )
        self.i_weight = torch.nn.Parameter(
            torch.Tensor(self.in_features, self.out_features)
        )
        self.j_weight = torch.nn.Parameter(
            torch.Tensor(self.in_features, self.out_features)
        )
        self.k_weight = torch.nn.Parameter(
            torch.Tensor(self.in_features, self.out_features)
        )

        # Spinor specific parameters
        if self.spinor:
            self.zero_kernel = torch.nn.Parameter(
                torch.zeros(self.r_weight.shape), requires_grad=False
            )
        else:
            self.zero_kernel = torch.Tensor(
                self.r_weight.shape
            ).requires_grad_(False)

        if self.spinor and self.vector_scale:
            self.scale_param = torch.nn.Parameter(
                torch.Tensor(self.in_features, self.out_features)
            )
            torch.nn.init.xavier_uniform_(self.scale_param.data)
        else:
            self.scale_param = torch.Tensor(
                self.in_features, self.out_features
            ).requires_grad_(False)

        if self.bias:
            self.b = torch.nn.Parameter(torch.Tensor(4 * n_neurons))
            self.b.data.fill_(0)
        else:
            self.b = torch.Tensor(4 * n_neurons).requires_grad_(False)

        # Managing the weight initialization and bias
        self.winit = {"quaternion": quaternion_init, "unitary": unitary_init}[
            self.weight_init
        ]

        # Initialise the weights
        affect_init(
            self.r_weight,
            self.i_weight,
            self.j_weight,
            self.k_weight,
            self.winit,
            init_criterion,
        )

    @torch.jit.ignore
    def forward(self, x):
        """Returns the linear transformation of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.
        """
        if self.autograd:
            if self.spinor:
                out = quaternion_linear_rotation_op(
                    x,
                    self.r_weight,
                    self.i_weight,
                    self.j_weight,
                    self.k_weight,
                    self.b,
                    self.scale_param,
                    self.zero_kernel,
                )
            else:
                out = quaternion_linear_op(
                    x,
                    self.r_weight,
                    self.i_weight,
                    self.j_weight,
                    self.k_weight,
                    self.b,
                )
        else:
            # The custom backward needs an input with at most 2 dimensions!
            input_dim = x.dim()
            if input_dim == 3:
                batch, time, fea = x.size()
                x = x.view(batch * time, fea)

            out = QuaternionLinearCustomBackward.apply(
                x,
                self.r_weight,
                self.i_weight,
                self.j_weight,
                self.k_weight,
                self.b,
            )

            if input_dim == 3:
                out = out.view(batch, time, out.size(-1))

        return out
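# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library source). It assumes the
# QLinear class defined above and shows the spinor and custom-backward
# configurations described in the docstring; the shapes and hyperparameters
# below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Spinor mode rotates the vector part of the input, so the real
    # components (the first quarter of the feature dimension) must be zero.
    x = torch.rand(8, 20, 40)
    x[..., 0:10] = 0.0

    rot = QLinear(
        n_neurons=25,
        input_shape=x.shape,
        spinor=True,
        vector_scale=True,  # learnable gates keep the output norm in check
    )
    print(rot(x).shape)  # torch.Size([8, 20, 100])

    # Memory-saving custom backward: autograd=False (requires spinor=False).
    lin = QLinear(n_neurons=25, input_shape=x.shape, autograd=False)
    print(lin(x).shape)  # torch.Size([8, 20, 100])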