Source code for speechbrain.nnet.quaternion_networks.q_pooling

"""Library implementing quaternion-valued max and average pooling layers.

Authors
 * Drew Wagner 2024
"""

import torch

import speechbrain as sb


class QPooling2d(sb.nnet.pooling.Pooling2d):
    """This class implements the quaternion average pooling and max pooling
    by magnitude as described in: "Geometric methods of perceptual
    organisation for computer vision", Altamirano G.

    The last dimension of the input is split into four equal chunks holding
    the r, i, j and k components. With 'avg', each component is pooled
    independently. With 'max', the quaternion with the largest (squared)
    magnitude inside each pooling window is selected, so the four components
    of every output element come from the same input position.

    Arguments
    ---------
    pool_type : str
        It is the type of pooling function to use ('avg','max').
    kernel_size : int
        It is the kernel size that defines the pooling dimension.
        For instance, kernel size=3,3 performs a 2D Pooling with a 3x3 kernel.
    pool_axis : tuple
        It is a list containing the axis that will be considered
        during pooling.
    ceil_mode : bool
        When True, will use ceil instead of floor to compute the output shape.
    padding : int
        It is the number of padding elements to apply.
    dilation : int
        Controls the dilation factor of pooling.
    stride : int
        It is the stride size.

    Example
    -------
    >>> pool = QPooling2d('max',(5,3))
    >>> inputs = torch.rand(10, 15, 12)
    >>> output=pool(inputs)
    >>> output.shape
    torch.Size([10, 3, 4])
    """

    def __init__(
        self,
        pool_type,
        kernel_size,
        pool_axis=(1, 2),
        ceil_mode=False,
        padding=0,
        dilation=1,
        stride=None,
    ):
        super().__init__(
            pool_type,
            kernel_size,
            pool_axis=pool_axis,
            ceil_mode=ceil_mode,
            padding=padding,
            dilation=dilation,
            stride=stride,
        )
        # Max pooling by magnitude needs the argmax positions rather than the
        # pooled magnitudes, so ask the pooling layer to return the indices.
        if self.pool_type == "max":
            self.pool_layer.return_indices = True

    def _to_pool_layout(self, x):
        """Moves the two pooled axes to the last two dimensions."""
        # Add extra two dimension at the last two, and then swap the
        # pool_axis to them
        # Example: pool_axis=[1,2]
        # [a,b,c,d] => [a,b,c,d,1,1]
        # [a,b,c,d,1,1] => [a,1,c,d,b,1]
        # [a,1,c,d,b,1] => [a,1,1,d,b,c]
        # [a,1,1,d,b,c] => [a,d,b,c]
        return (
            x.unsqueeze(-1)
            .unsqueeze(-1)
            .transpose(-2, self.pool_axis[0])
            .transpose(-1, self.pool_axis[1])
            .squeeze(self.pool_axis[1])
            .squeeze(self.pool_axis[0])
        )

    def _from_pool_layout(self, x):
        """Moves the last two dimensions back to the pooled axes."""
        # Inverse of `_to_pool_layout`
        # Example: pool_axis=[1,2]
        # [a,d,b,c] => [a,1,1,d,b,c]
        # [a,1,1,d,b,c] => [a,b,c,d,1,1]
        # [a,b,c,d,1,1] => [a,b,c,d]
        return (
            x.unsqueeze(self.pool_axis[0])
            .unsqueeze(self.pool_axis[1])
            .transpose(-2, self.pool_axis[0])
            .transpose(-1, self.pool_axis[1])
            .squeeze(-1)
            .squeeze(-1)
        )

    def forward(self, x):
        """Performs 2d pooling to the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            It represents a tensor for a mini-batch. Its last dimension is
            split into four chunks (r, i, j, k components).

        Returns
        -------
        The pooled tensor, with the four pooled components concatenated
        along the last dimension.
        """
        x_r, x_i, x_j, x_k = torch.chunk(x, 4, dim=-1)

        if self.pool_type == "avg":
            # Perform average pooling over each of the components of the
            # quaternion. The bound method is taken first because zero-arg
            # super() is not usable inside a generator on older Pythons.
            avg_pool = super().forward
            x_r, x_i, x_j, x_k = (
                avg_pool(x_r),
                avg_pool(x_i),
                avg_pool(x_j),
                avg_pool(x_k),
            )
        elif self.pool_type == "max":
            # Squared magnitude of the quaternion; the square root is skipped
            # because it does not change which element is the largest.
            sq_norm = x_r**2 + x_i**2 + x_j**2 + x_k**2

            # Perform max pooling of the magnitude, keeping only the indices.
            _, idx = self.pool_layer(self._to_pool_layout(sq_norm))

            # The returned indices are flat offsets inside each pooled 2D
            # plane (not offsets into the whole tensor), so the selection
            # has to be done plane by plane along the flattened pooled axes.
            idx_flat = idx.flatten(start_dim=-2)

            def select(component):
                """Picks the elements of `component` at the argmax positions."""
                planes = self._to_pool_layout(component).flatten(start_dim=-2)
                picked = torch.gather(planes, -1, idx_flat).reshape(idx.shape)
                return self._from_pool_layout(picked)

            # Select the r, i, j & k components of the quaternion with the
            # max magnitude
            x_r = select(x_r)
            x_i = select(x_i)
            x_j = select(x_j)
            x_k = select(x_k)

        return torch.concat((x_r, x_i, x_j, x_k), dim=-1)