speechbrain.lobes.models.bsq module

Binary spherical quantizer.

Authors
  • Luca Della Libera 2025

Summary

Classes:

BinarySphericalQuantizer

Binary spherical quantizer.

__all__: BinarySphericalQuantizer

Reference

class speechbrain.lobes.models.bsq.BinarySphericalQuantizer(code_dim: int, entropy_loss_weight: float = 0.1, diversity_gamma: float = 1.0)

Bases: Module

Binary spherical quantizer.

This module implements a binary quantizer over the unit hypersphere. Given a continuous input vector x ∈ ℝ^D, with D = code_dim, it:

  1. Projects x onto the unit sphere.

  2. Quantizes each dimension to {-1/sqrt(D), +1/sqrt(D)} based on its sign.

  3. Interprets the resulting sign pattern as a binary code index.

  4. Computes an auxiliary entropy/diversity loss to encourage confident assignments and uniform codebook usage.
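As a rough illustration, steps 1 to 3 can be sketched in PyTorch as below. This is not SpeechBrain's code. The function name `bsq_quantize_sketch` is hypothetical, and packing the index most-significant-bit first is an assumed convention. The straight-through trick is also an assumption: the forward value is the hard code, and the gradient is passed to the normalized input. It is a common choice in trainable quantizers, but the text above does not state it.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def bsq_quantize_sketch(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of BSQ steps 1-3 for x of shape [..., D]."""
    dim = x.shape[-1]
    # Step 1: project each vector onto the unit hypersphere.
    u = F.normalize(x, dim=-1)
    # Step 2: keep only the sign of each coordinate and scale by 1/sqrt(D).
    signs = torch.where(u > 0, torch.ones_like(u), -torch.ones_like(u))
    q_hard = signs / dim**0.5
    # Assumed straight-through estimator: the forward value is q_hard,
    # and gradients flow back to u as if quantization were the identity.
    q = u + (q_hard - u).detach()
    # Step 3: read the signs as bits (+ -> 1, - -> 0) and pack them into an
    # integer, most significant bit first (assumed ordering).
    bits = (signs > 0).long()
    weights = 2 ** torch.arange(dim - 1, -1, -1, device=x.device)
    indices = (bits * weights).sum(dim=-1)
    return q, indices
```

Because every entry has magnitude 1/sqrt(D), the squared norm is D · (1/D) = 1. This is why the quantized vector stays on the unit sphere.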

Parameters:
  • code_dim (int) – Dimensionality of the code / number of bits per code vector. The codebook size is 2 ** code_dim.

  • entropy_loss_weight (float, optional) – Weight for the entropy-based auxiliary loss term.

  • diversity_gamma (float, optional) – Coefficient for the codebook entropy term in the auxiliary loss. Larger values encourage more uniform usage of all codes.
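The roles of the two weights can be summarized with the entropy penalty commonly used in lookup-free and binary spherical quantization. This assumes SpeechBrain follows that standard form. The parameter descriptions suggest it but do not spell it out. In the formula, w = entropy_loss_weight, γ = diversity_gamma, and p(c | x) is a soft assignment of input x over the 2^D codes c:

```latex
\mathcal{L}_{\mathrm{aux}}
  = w \left( \mathbb{E}_{x}\!\left[ H\big(p(\cdot \mid x)\big) \right]
  - \gamma \, H\!\left( \mathbb{E}_{x}\!\left[ p(\cdot \mid x) \right] \right) \right),
\qquad
H(p) = -\sum_{c} p(c) \log p(c)
```

Minimizing the first term pushes each input toward a single confident code. The second term enters with a minus sign, so minimizing the loss maximizes the entropy of average code usage. A larger γ weights that uniformity term more heavily.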

Example

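>>> import torch
>>> from speechbrain.lobes.models.bsq import BinarySphericalQuantizer
>>> code_dim = 13
>>> x = torch.randn(2, 50, code_dim)
>>> quantizer = BinarySphericalQuantizer(code_dim)
>>> quant, indices, aux_loss = quantizer(x)
>>> quant.shape
torch.Size([2, 50, 13])
>>> indices.shape
torch.Size([2, 50])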

bits_to_codes(bits: Tensor) → Tensor

Convert {0, 1} bits to {-1, +1} codes.

Parameters:

bits (torch.Tensor) – Tensor of bits in {0, 1} with shape […, code_dim].

Returns:

Tensor of codes in {-1, +1} with the same shape as bits.

Return type:

torch.Tensor
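The conversion amounts to the affine map code = 2 · bit - 1. Below is a minimal PyTorch equivalent, with a hypothetical inverse for reading codes back as bits. Neither helper name is part of the library.

```python
import torch


def bits_to_codes_sketch(bits: torch.Tensor) -> torch.Tensor:
    # Affine map {0, 1} -> {-1, +1}: 0 -> -1, 1 -> +1.
    return 2 * bits - 1


def codes_to_bits_sketch(codes: torch.Tensor) -> torch.Tensor:
    # Inverse map {-1, +1} -> {0, 1}.
    return (codes + 1) // 2
```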

forward(x: Tensor, inv_temperature: float = 100.0) → Tuple[Tensor, Tensor, Tensor]

Quantize continuous vectors onto the binary spherical codebook.

Parameters:
  • x (torch.Tensor) – Input tensor of shape […, code_dim]. The last dimension must match self.code_dim. It is L2-normalized internally.

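  • inv_temperature (float, optional) – Inverse temperature for the softmax over codebook distances used to compute the entropy-based auxiliary loss. Larger values make the soft assignment sharper, i.e. closer to a one-hot choice of the nearest code.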
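To show where inv_temperature enters, the sketch below computes the auxiliary loss by brute force over all 2 ** code_dim codes. It is a hedged illustration, not the library's implementation. The helper names are hypothetical. The following are all assumptions: using squared Euclidean distance, ordering codes most-significant-bit first, and the exact combination of the two entropy terms (modeled on the entropy_loss_weight and diversity_gamma descriptions).

```python
import torch
import torch.nn.functional as F


def all_codes_sketch(dim: int) -> torch.Tensor:
    """Enumerate all 2**dim code vectors with entries +-1/sqrt(dim)."""
    idx = torch.arange(2**dim).unsqueeze(-1)  # [2**dim, 1]
    # Bit k of each index, most significant bit first (assumed ordering).
    weights = 2 ** torch.arange(dim - 1, -1, -1)  # [dim]
    bits = (idx // weights) % 2  # [2**dim, dim]
    return (2 * bits - 1).float() / dim**0.5


def entropy_aux_loss_sketch(
    x: torch.Tensor,
    inv_temperature: float = 100.0,
    entropy_loss_weight: float = 0.1,
    diversity_gamma: float = 1.0,
) -> torch.Tensor:
    dim = x.shape[-1]
    u = F.normalize(x, dim=-1).reshape(-1, dim)  # [N, D], unit vectors
    codebook = all_codes_sketch(dim).to(u)  # [2**D, D], unit vectors
    # For unit vectors ||u - c||^2 = 2 - 2 u.c, so a softmax over
    # -inv_temperature * distance equals a softmax over these logits
    # (the constant term cancels).
    logits = 2 * inv_temperature * (u @ codebook.T)  # [N, 2**D]
    log_p = F.log_softmax(logits, dim=-1)
    p = log_p.exp()
    # Per-sample entropy: small when each input commits to one code.
    sample_entropy = -(p * log_p).sum(dim=-1).mean()
    # Entropy of the average assignment: large when codes are used uniformly.
    avg_p = p.mean(dim=0)
    codebook_entropy = -(avg_p * avg_p.clamp_min(1e-10).log()).sum()
    return entropy_loss_weight * (sample_entropy - diversity_gamma * codebook_entropy)
```

Enumerating the codebook costs memory proportional to 2 ** code_dim per input. That is manageable for code_dim = 13 (8192 codes), but it grows quickly beyond that. Larger code sizes generally call for a factorized, per-bit approximation instead.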

Returns:

A tuple (quantized, indices, aux_loss) where:

  • quantized: torch.Tensor

    Quantized version of the input with the same shape as x, lying on the unit sphere with entries in {-1/sqrt(code_dim), +1/sqrt(code_dim)}.

  • indices: torch.Tensor

    Integer code indices of shape […], obtained by interpreting the sign pattern of each vector as a binary code.

  • aux_loss: torch.Tensor

    Scalar auxiliary loss combining per-sample entropy and codebook-diversity regularization, scaled by entropy_loss_weight.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
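Each index is just the sign pattern read as a binary number. It can therefore be decoded back into the quantized vector without storing a codebook. The minimal sketch below assumes most-significant-bit-first packing, since the bit order SpeechBrain uses is not stated here. `indices_to_quantized_sketch` is a hypothetical helper.

```python
import torch


def indices_to_quantized_sketch(indices: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical inverse of the index packing: [...] -> [..., dim]."""
    # Assumed MSB-first ordering: weights are 2**(dim-1), ..., 2, 1.
    weights = 2 ** torch.arange(dim - 1, -1, -1, device=indices.device)
    bits = (indices.unsqueeze(-1) // weights) % 2
    # {0, 1} bits -> {-1, +1} codes -> points on the unit sphere.
    return (2 * bits - 1).float() / dim**0.5
```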