"""Library implementing quaternion-valued max and average pooling layers.
Authors
* Drew Wagner 2024
"""
import torch
import speechbrain as sb
[docs]
class QPooling2d(sb.nnet.pooling.Pooling2d):
"""This class implements the quaternion average pooling and max pooling
by magnitude as described in: "Geometric methods of perceptual organisation for
computer vision", Altamirano G.
Arguments
---------
pool_type : str
It is the type of pooling function to use ('avg','max').
kernel_size : int
It is the kernel size that defines the pooling dimension.
For instance, kernel size=3,3 performs a 2D Pooling with a 3x3 kernel.
pool_axis : tuple
It is a list containing the axis that will be considered
during pooling.
ceil_mode : bool
When True, will use ceil instead of floor to compute the output shape.
padding : int
It is the number of padding elements to apply.
dilation : int
Controls the dilation factor of pooling.
stride : int
It is the stride size.
Example
-------
>>> pool = QPooling2d('max',(5,3))
>>> inputs = torch.rand(10, 15, 12)
>>> output=pool(inputs)
>>> output.shape
torch.Size([10, 3, 4])
"""
def __init__(
self,
pool_type,
kernel_size,
pool_axis=(1, 2),
ceil_mode=False,
padding=0,
dilation=1,
stride=None,
):
super().__init__(
pool_type,
kernel_size,
pool_axis=pool_axis,
ceil_mode=ceil_mode,
padding=padding,
dilation=dilation,
stride=stride,
)
if self.pool_type == "max":
self.pool_layer.return_indices = True
[docs]
def forward(self, x):
"""Performs 2d pooling to the input tensor.
Arguments
---------
x : torch.Tensor
It represents a tensor for a mini-batch.
Returns
-------
The pooled tensor.
"""
x_r, x_i, x_j, x_k = torch.chunk(x, 4, dim=-1)
if self.pool_type == "avg":
# Perform average pooling over each of the components of the quaternion
x_r = super().forward(x_r)
x_i = super().forward(x_i)
x_j = super().forward(x_j)
x_k = super().forward(x_k)
elif self.pool_type == "max":
# Compute the magnitude of the quaternion
m = x_r**2 + x_i**2 + x_j**2 + x_k**2
# Add extra two dimension at the last two, and then swap the pool_axis to them
# Example: pool_axis=[1,2]
# [a,b,c,d] => [a,b,c,d,1,1]
# [a,b,c,d,1,1] => [a,1,c,d,b,1]
# [a,1,c,d,b,1] => [a,1,1,d,b,c]
# [a,1,1,d,b,c] => [a,d,b,c]
m = (
m.unsqueeze(-1)
.unsqueeze(-1)
.transpose(-2, self.pool_axis[0])
.transpose(-1, self.pool_axis[1])
.squeeze(self.pool_axis[1])
.squeeze(self.pool_axis[0])
)
# Perform max pooling of the magnitude, returning only the indices
_, idx = self.pool_layer(m)
idx = (
idx.unsqueeze(self.pool_axis[0])
.unsqueeze(self.pool_axis[1])
.transpose(-2, self.pool_axis[0])
.transpose(-1, self.pool_axis[1])
.squeeze(-1)
.squeeze(-1)
)
idx_flat = idx.flatten()
# Select the r, i, j & k components of the quaternion with the max magnitude
x_r = x_r.flatten()[idx_flat].reshape(idx.shape)
x_i = x_i.flatten()[idx_flat].reshape(idx.shape)
x_j = x_j.flatten()[idx_flat].reshape(idx.shape)
x_k = x_k.flatten()[idx_flat].reshape(idx.shape)
return torch.concat((x_r, x_i, x_j, x_k), dim=-1)