speechbrain.nnet.quantisers module

Gumbel-softmax vector quantisation, with support for multiple codebook groups.

Authors
  • Rudolf A. Braun 2022

Summary

Classes:

GumbelVectorQuantizer

Vector quantization using Gumbel-softmax.

RandomProjectionQuantizer

Vector quantization using a random projection and a randomly initialised codebook; this is useful for models such as BEST-RQ.

Reference

class speechbrain.nnet.quantisers.GumbelVectorQuantizer(input_dim, num_vars, temp_tuple, groups, vq_dim)

Bases: Module

Vector quantization using Gumbel-softmax. Copied from the fairseq implementation.

Parameters:
  • input_dim (int) – Input dimension (channels).

  • num_vars (int) – Number of quantized vectors per group.

  • temp_tuple (tuple of float) – Temperature schedule for training, given as a tuple of 3 elements: (start, stop, decay factor).

  • groups (int) – Number of groups for vector quantization.

  • vq_dim (int) – Dimensionality of the resulting quantized vector.
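The following is a minimal PyTorch sketch of how grouped Gumbel-softmax quantisation can work, assuming the fairseq-style design this class is said to be copied from. A linear layer produces num_vars logits per group, torch.nn.functional.gumbel_softmax with hard=True picks one code per group in the forward pass while letting soft gradients through in the backward pass, and the chosen group codes are concatenated into a vq_dim vector. The class GroupedGumbelQuantizerSketch is hypothetical and is not SpeechBrain's implementation. In particular, it returns a tensor rather than a dictionary, and it takes the temperature as an argument instead of tracking a schedule.

```python
import torch
import torch.nn.functional as F


class GroupedGumbelQuantizerSketch(torch.nn.Module):
    """Hypothetical grouped Gumbel-softmax quantiser (illustration only)."""

    def __init__(self, input_dim, num_vars, groups, vq_dim):
        super().__init__()
        assert vq_dim % groups == 0, "vq_dim must be divisible by groups"
        self.num_vars, self.groups = num_vars, groups
        # One logit per (group, code) pair.
        self.weight_proj = torch.nn.Linear(input_dim, groups * num_vars)
        # Each group owns num_vars codes of size vq_dim // groups.
        self.codebook = torch.nn.Parameter(
            torch.randn(groups, num_vars, vq_dim // groups)
        )

    def forward(self, x, temp=2.0):
        B, T, _ = x.shape
        logits = self.weight_proj(x).view(B, T, self.groups, self.num_vars)
        if self.training:
            # Hard one-hot forward pass, soft (straight-through) gradients.
            onehot = F.gumbel_softmax(logits, tau=temp, hard=True, dim=-1)
        else:
            # At inference, deterministically pick the most likely code.
            idx = logits.argmax(dim=-1)
            onehot = F.one_hot(idx, self.num_vars).type_as(logits)
        # (B, T, G, V) x (G, V, D/G) -> (B, T, G, D/G): one code per group.
        codes = torch.einsum("btgv,gvd->btgd", onehot, self.codebook)
        # Concatenate the G group codes into a single vq_dim vector.
        return codes.reshape(B, T, -1)


torch.manual_seed(0)
gq = GroupedGumbelQuantizerSketch(128, 100, 2, 50)
out = gq(torch.rand(10, 12, 128))
print(out.shape)  # torch.Size([10, 12, 50])
```

Splitting vq_dim across groups lets G small codebooks of num_vars entries each express num_vars ** G distinct combined codes.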

Example

>>> import torch
>>> from speechbrain.nnet.quantisers import GumbelVectorQuantizer
>>> quantiser = GumbelVectorQuantizer(
...     input_dim=128,
...     num_vars=100,
...     temp_tuple=(2.0, 0.25, 0.999995),
...     groups=2,
...     vq_dim=50,
... )
>>> inputs = torch.rand(10, 12, 128)
>>> output = quantiser(inputs)
>>> output["x"].shape
torch.Size([10, 12, 50])

update_temp(steps)

Update the Gumbel-softmax temperature for the given training step.
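The schedule itself is not spelled out here. Assuming it follows the fairseq quantiser this class was copied from, the temperature decays exponentially from start and is clamped at stop, with the three values taken from temp_tuple. The helper gumbel_temperature below is hypothetical and only illustrates that assumed rule, using the example's values as defaults.

```python
def gumbel_temperature(steps, start=2.0, stop=0.25, decay=0.999995):
    """Hypothetical helper: exponential decay from `start`, floored at `stop`."""
    return max(start * decay**steps, stop)


for step in (0, 100_000, 10_000_000):
    # Temperature shrinks with the step count but never drops below `stop`.
    print(step, gumbel_temperature(step))
```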

forward(x)

Quantise the input latent vectors x. Returns a dictionary whose "x" entry holds the quantised vectors, each of size vq_dim.

class speechbrain.nnet.quantisers.RandomProjectionQuantizer(input_dim, cb_dim, cb_vocab)

Bases: Module

Vector quantization using a random projection and a randomly initialised codebook; this is useful for models such as BEST-RQ.

The output is the indices of the closest code in the codebook for each time step of the input.

Reference: https://arxiv.org/pdf/2202.01855

Parameters:
  • input_dim (int) – Input dimension (channels).

  • cb_dim (int) – Dimension of each code vector in the codebook.

  • cb_vocab (int) – Number of codes in the codebook.
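The following minimal PyTorch sketch shows the BEST-RQ-style idea behind this class. A frozen random matrix projects each frame to cb_dim, the projection is L2-normalised, and the index of the nearest entry in a frozen, normalised random codebook becomes the target for that time step. The class RandomProjectionQuantizerSketch, its Xavier-uniform initialisation, and its Gaussian codebook are assumptions chosen for illustration and are not necessarily what SpeechBrain does. Registering both tensors as buffers keeps them out of the optimiser, so the targets stay fixed throughout training.

```python
import torch
import torch.nn.functional as F


class RandomProjectionQuantizerSketch(torch.nn.Module):
    """Hypothetical frozen random-projection quantiser (illustration only)."""

    def __init__(self, input_dim, cb_dim, cb_vocab):
        super().__init__()
        # Buffers are saved with the module but never trained.
        proj = torch.empty(input_dim, cb_dim)
        torch.nn.init.xavier_uniform_(proj)
        self.register_buffer("proj", proj)
        # Codebook: cb_vocab random Gaussian vectors, L2-normalised.
        self.register_buffer(
            "codebook", F.normalize(torch.randn(cb_vocab, cb_dim), dim=-1)
        )

    @torch.no_grad()
    def forward(self, x):
        # (B, T, input_dim) -> (B, T, cb_dim), normalised per frame.
        z = F.normalize(x @ self.proj, dim=-1)
        # Distance from every frame to every code: (B, T, cb_vocab).
        dist = torch.linalg.vector_norm(z.unsqueeze(-2) - self.codebook, dim=-1)
        # Index of the nearest code for each time step: (B, T).
        return dist.argmin(dim=-1)


torch.manual_seed(0)
rq = RandomProjectionQuantizerSketch(16, 16, 32)
targets = rq(torch.rand(10, 12, 16))
print(targets.shape)  # torch.Size([10, 12])
```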

Example

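>>> import torch
>>> from speechbrain.nnet.quantisers import RandomProjectionQuantizer
>>> quantiser = RandomProjectionQuantizer(16, 16, 32)
>>> inputs = torch.rand(10, 12, 16)
>>> output = quantiser(inputs)
>>> output.shape
torch.Size([10, 12])

forward(x)

Project the input and return, for each time step, the index of the closest code in the codebook.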