"""
Gumbel Softmax implementation with multiple groups possible.
Authors
* Rudolf A. Braun 2022
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class GumbelVectorQuantizer(nn.Module):
    """Vector quantization using gumbel softmax. Copied from fairseq implementation.

    Arguments
    ---------
    input_dim : int
        Input dimension (channels).
    num_vars : int
        Number of quantized vectors per group.
    temp_tuple : tuple
        Temperature for training. This should be a tuple of 3 elements:
        (start, stop, decay factor).
    groups : int
        Number of groups for vector quantization.
    vq_dim : int
        Dimensionality of the resulting quantized vector.

    Example
    -------
    >>> quantiser = GumbelVectorQuantizer(128, 100, (2.0, 0.25, 0.999995,), 2, 50 )
    >>> inputs = torch.rand(10, 12, 128)
    >>> output = quantiser(inputs)
    >>> output["x"].shape
    torch.Size([10, 12, 50])
    """

    def __init__(self, input_dim, num_vars, temp_tuple, groups, vq_dim):
        super().__init__()
        self.groups = groups
        self.input_dim = input_dim
        self.num_vars = num_vars
        self.vq_dim = vq_dim

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        # Each group owns an equal slice of the output dimension.
        var_dim = vq_dim // groups
        # Codebook: one row per (group, codeword) pair, each of size var_dim.
        self.vars = nn.Parameter(
            torch.FloatTensor(1, groups * num_vars, var_dim)
        )
        nn.init.uniform_(self.vars)

        # Projects inputs to per-group logits over the codewords.
        self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
        nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
        nn.init.zeros_(self.weight_proj.bias)

        assert len(temp_tuple) == 3, temp_tuple
        self.max_temp, self.min_temp, self.temp_decay = temp_tuple
        self.curr_temp = self.max_temp

        # Maximum possible codebook entropy (log of total number of codes),
        # stored as a frozen Parameter so it follows the module across devices.
        self.max_ent = nn.Parameter(
            torch.log(torch.tensor(float(self.num_vars * self.groups))),
            requires_grad=False,
        )

    def update_temp(self, steps):
        """Update the temperature given the current step.

        Applies exponential decay from ``max_temp``, clamped at ``min_temp``.
        """
        self.curr_temp = max(
            self.max_temp * self.temp_decay ** steps, self.min_temp
        )

    def forward(self, x):
        """Forward the latent vector to obtain a quantised output.

        Arguments
        ---------
        x : torch.Tensor
            Input of shape (batch, time, input_dim).

        Returns
        -------
        dict
            Keys: "x" (quantised output of shape (batch, time, vq_dim)),
            "num_vars", "temp", "code_perplexity", "prob_perplex".
        """
        result = {
            "num_vars": self.num_vars * self.groups,
            "temp": self.curr_temp,
        }
        bsz, tsz, fsz = x.shape
        x = x.reshape(-1, fsz)
        x = self.weight_proj(x)
        # One row of logits per (time-step, group).
        x = x.view(bsz * tsz * self.groups, -1)

        # Hard (argmax) one-hot assignment; used at eval time and for stats.
        _, k = x.max(-1)
        hard_x = (
            x.new_zeros(*x.shape)
            .scatter_(-1, k.view(-1, 1), 1.0)
            .view(bsz * tsz, self.groups, -1)
        )

        # Perplexity of the hard code usage (diversity diagnostic).
        hard_probs = torch.mean(hard_x.float(), dim=0)
        result["code_perplexity"] = torch.exp(
            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
        ).sum()

        # Perplexity of the soft (softmax) code distribution.
        avg_probs = torch.softmax(
            x.view(bsz * tsz, self.groups, -1).float(), dim=-1
        ).mean(dim=0)
        result["prob_perplex"] = torch.exp(
            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
        ).sum()

        if self.training:
            # Straight-through Gumbel softmax: hard one-hot on the forward
            # pass, soft gradients on the backward pass.
            x = F.gumbel_softmax(
                x.float(), tau=self.curr_temp, hard=True
            ).type_as(x)
        else:
            x = hard_x

        x = x.view(bsz * tsz, -1)
        # Weighted sum with the one-hot selection picks out the chosen
        # codeword for each group; concatenate groups along the last dim.
        codebook = self.vars  # renamed local to avoid shadowing builtin vars()
        x = x.unsqueeze(-1) * codebook
        x = x.view(bsz * tsz, self.groups, self.num_vars, -1)
        x = x.sum(-2)
        x = x.view(bsz, tsz, -1)
        result["x"] = x
        return result