speechbrain.nnet.complex_networks.c_linear module

Library implementing complex-valued linear transformation.

Authors
  • Titouan Parcollet 2020

Summary

Classes:

CLinear

This class implements a fully connected complex-valued linear layer: y = Wx + b.

Reference

class speechbrain.nnet.complex_networks.c_linear.CLinear(n_neurons, input_shape, bias=True, init_criterion='glorot', weight_init='complex')

Bases: torch.nn.modules.module.Module

This class implements a fully connected complex-valued linear layer: y = Wx + b, where y, W, x and b are all complex-valued. A complex number is written as a + bi, with real part a and imaginary part b. The feature dimension of the input is split in half, so it must be divisible by 2: for a real tensor x of shape [batch, 32], the slice [batch, 0:16] holds the real parts R and the slice [batch, 16:32] holds the imaginary parts I, which gives 16 complex features.
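The half-split layout can be illustrated on a single feature vector with plain Python. The helpers `split_complex` and `merge_complex` are hypothetical names used only for illustration (they are not part of SpeechBrain); they assume, as described above, that the first half of the features holds the real parts and the second half the imaginary parts.

```python
def split_complex(features):
    """Interpret a real feature vector laid out as [R | I] as complex numbers."""
    if len(features) % 2 != 0:
        raise ValueError("feature dimension must be divisible by 2")
    half = len(features) // 2
    real, imag = features[:half], features[half:]
    return [complex(r, i) for r, i in zip(real, imag)]


def merge_complex(values):
    """Pack complex numbers back into the real [R | I] layout."""
    return [v.real for v in values] + [v.imag for v in values]


# 32 real features -> 16 complex features:
# indices 0..15 hold the real parts, indices 16..31 the imaginary parts.
x = [float(k) for k in range(32)]
z = split_complex(x)
print(len(z))  # 16
print(z[0])    # 16j  (real part 0.0, imaginary part 16.0)
```

Round-tripping through `merge_complex` recovers the original vector, which is why an odd feature dimension cannot be interpreted this way.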

Parameters
  • n_neurons (int) – The number of output neurons (i.e., the dimensionality of the output). Note that these are complex-valued neurons: if 256 neurons are specified, the output dimension will be 512 (the real and imaginary parts of 256 complex outputs).

  • input_shape (tuple) – Expected shape of the input tensor. Its last (feature) dimension must be divisible by 2.

  • bias (bool) – If True, the additive bias b is used.

  • init_criterion (str, optional) – (glorot, he). This parameter controls the initialization criterion of the weights. It is combined with weight_init to build the initialization method of the complex-valued weights (default “glorot”).

  • weight_init (str, optional) – (complex, unitary). This parameter defines the initialization procedure of the complex-valued weights (default “complex”). “complex” generates random complex-valued weights following init_criterion and the complex polar form. “unitary” normalizes the weights so that they lie on the unit circle. More details in: “Deep Complex Networks”, Trabelsi C. et al.
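The two initialization schemes can be sketched for a single weight with the standard library. This is a minimal illustration in the spirit of Trabelsi et al., not SpeechBrain's routine: the function names are hypothetical, and the library's tensor-level code may differ in details such as how fan-in and fan-out are computed. For “complex”, the magnitude is drawn from a Rayleigh distribution and the phase uniformly in [-pi, pi]; since E|w|^2 = 2 * sigma^2, choosing sigma = 1/sqrt(fan_in + fan_out) gives the Glorot variance 2/(fan_in + fan_out), and sigma = 1/sqrt(fan_in) gives the He variance 2/fan_in. For “unitary”, a random complex number is divided by its modulus.

```python
import cmath
import math
import random


def rayleigh(sigma, rng):
    """Inverse-CDF sample of Rayleigh(sigma): F(r) = 1 - exp(-r^2 / (2 sigma^2))."""
    u = rng.random()  # u in [0, 1), so 1 - u is in (0, 1]
    return sigma * math.sqrt(-2.0 * math.log(1.0 - u))


def complex_init(fan_in, fan_out, criterion="glorot", rng=random):
    """Hypothetical sketch: one complex weight in polar form (Rayleigh magnitude, uniform phase)."""
    if criterion == "glorot":
        sigma = 1.0 / math.sqrt(fan_in + fan_out)
    elif criterion == "he":
        sigma = 1.0 / math.sqrt(fan_in)
    else:
        raise ValueError(f"unknown criterion: {criterion}")
    magnitude = rayleigh(sigma, rng)
    phase = rng.uniform(-math.pi, math.pi)
    return cmath.rect(magnitude, phase)


def unitary_init(rng=random):
    """Hypothetical sketch: a random complex weight normalized onto the unit circle."""
    while True:
        w = complex(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
        if w != 0:
            return w / abs(w)


rng = random.Random(0)
weights = [complex_init(20, 100, "glorot", rng) for _ in range(5)]
units = [unitary_init(rng) for _ in range(5)]
```

Sampling in polar form keeps the phase independent of the magnitude, so the variance criterion only has to constrain the Rayleigh scale sigma.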

Example

>>> import torch
>>> from speechbrain.nnet.complex_networks.c_linear import CLinear
>>> inputs = torch.rand(10, 50, 40)
>>> lin = CLinear(n_neurons=100, input_shape=inputs.shape)
>>> output = lin(inputs)
>>> output.shape
torch.Size([10, 50, 200])
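The shape in the example follows from simple arithmetic: 40 real input features are 20 complex features, and 100 complex output neurons occupy 200 real output features. The hypothetical helper below (not a SpeechBrain function) reproduces that arithmetic and also counts the real numbers needed to represent a complex map from 20 to 100 dimensions; the library's actual parameter tensors may be organized differently.

```python
def clinear_shapes(input_dim, n_neurons, bias=True):
    """Hypothetical helper: real output size and real parameter count of a complex linear layer."""
    if input_dim % 2 != 0:
        raise ValueError("input feature dimension must be divisible by 2")
    complex_in = input_dim // 2  # e.g. 40 real features -> 20 complex inputs
    output_dim = 2 * n_neurons   # real and imaginary halves of the output
    # A complex map C^complex_in -> C^n_neurons has complex_in * n_neurons
    # complex weights, each stored as two real numbers.
    weight_params = 2 * complex_in * n_neurons
    bias_params = 2 * n_neurons if bias else 0
    return output_dim, weight_params + bias_params


print(clinear_shapes(40, 100))  # (200, 4200)
```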
forward(x)

Returns the complex-valued linear transformation of the input tensor.

Parameters

x (torch.Tensor) – Input to transform linearly.
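The complex product Wx can be computed with real arithmetic only, by expanding (Wr + i Wi)(xr + i xi) = (Wr xr - Wi xi) + i (Wr xi + Wi xr). The function below is a single-vector sketch of that arithmetic on the [R | I] layout. The name `complex_linear` and the nested-list weight format are assumptions for illustration; SpeechBrain's forward works on batched tensors, and this is not the library code.

```python
def complex_linear(x, w_real, w_imag, b_real=None, b_imag=None):
    """Sketch of y = Wx + b on the [R | I] layout using only real arithmetic.

    x: 2*m real features, first half real parts, second half imaginary parts.
    w_real, w_imag: n x m nested lists holding Re(W) and Im(W).
    Returns 2*n real features in the same [R | I] layout.
    """
    m = len(x) // 2
    x_r, x_i = x[:m], x[m:]
    n = len(w_real)
    if b_real is None:
        b_real = [0.0] * n
    if b_imag is None:
        b_imag = [0.0] * n
    y_r, y_i = [], []
    for k in range(n):
        # Real part: Wr xr - Wi xi; imaginary part: Wr xi + Wi xr.
        re = sum(w_real[k][j] * x_r[j] - w_imag[k][j] * x_i[j] for j in range(m))
        im = sum(w_real[k][j] * x_i[j] + w_imag[k][j] * x_r[j] for j in range(m))
        y_r.append(re + b_real[k])
        y_i.append(im + b_imag[k])
    return y_r + y_i


# Complex inputs 1+3j and 2+4j; one output neuron with W = [1, 1j].
x = [1.0, 2.0, 3.0, 4.0]
w_real = [[1.0, 0.0]]
w_imag = [[0.0, 1.0]]
print(complex_linear(x, w_real, w_imag))  # [-3.0, 5.0], i.e. -3+5j
```

The result agrees with native complex arithmetic: 1 * (1+3j) + 1j * (2+4j) = -3+5j.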
