speechbrain.lobes.models.bsq module
Binary spherical quantizer.
Authors:
- Luca Della Libera 2025
Summary
Classes:
- BinarySphericalQuantizer: Binary spherical quantizer.
__all__: BinarySphericalQuantizer
Reference
- class speechbrain.lobes.models.bsq.BinarySphericalQuantizer(code_dim: int, entropy_loss_weight: float = 0.1, diversity_gamma: float = 1.0)
Bases: torch.nn.Module
Binary spherical quantizer.
This module implements a binary quantizer over the unit hypersphere. Given a continuous input vector x ∈ R^D, it:
1. Projects x onto the unit sphere (L2 normalization).
2. Quantizes each dimension to {-1/sqrt(D), +1/sqrt(D)} based on its sign.
3. Interprets the resulting sign pattern as a binary code index.
4. Computes an auxiliary entropy/diversity loss to encourage confident assignments and uniform codebook usage.
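Steps 1 to 3 can be sketched for a single vector in plain Python. This is a minimal illustration, not the module's implementation: the helper name bsq_quantize is hypothetical, an exact zero is assumed to count as negative, and the sign pattern is assumed to be read most-significant-bit first (the actual bit order used by the class is an assumption here).

```python
import math


def bsq_quantize(x, eps=1e-12):
    """Quantize one vector (list of floats) on the binary sphere.

    Hypothetical helper: returns (quantized, index), where quantized has
    entries of +/-1/sqrt(D) and index is the sign pattern read as bits.
    """
    d = len(x)
    # 1. Project x onto the unit sphere (eps guards against an all-zero input).
    norm = max(math.sqrt(sum(v * v for v in x)), eps)
    u = [v / norm for v in x]
    # 2. Keep only the sign of each dimension; the 1/sqrt(D) scale keeps ||q|| = 1.
    #    Assumption: an exact zero is treated as negative.
    scale = 1.0 / math.sqrt(d)
    quantized = [scale if v > 0 else -scale for v in u]
    # 3. Read the signs as bits (positive -> 1, otherwise 0),
    #    most significant bit first (assumed bit order).
    index = 0
    for v in u:
        index = (index << 1) | (1 if v > 0 else 0)
    return quantized, index


q, idx = bsq_quantize([0.5, -2.0, 3.0])
print(idx)  # 5, i.e. bits 1, 0, 1
```

Because the quantized vector depends only on the signs, the codebook is implicit: it is the set of all 2 ** D sign patterns scaled by 1/sqrt(D), so nothing has to be learned or stored for it.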
- Parameters:
code_dim (int) – Dimensionality of the code, i.e. the number of bits per code vector. The codebook size is 2 ** code_dim.
entropy_loss_weight (float, optional) – Weight applied to the entropy-based auxiliary loss.
diversity_gamma (float, optional) – Coefficient for the codebook entropy term in the auxiliary loss. Larger values encourage more uniform usage of all codes.
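How entropy_loss_weight, diversity_gamma and the forward argument inv_temperature could interact is sketched below by brute force. The formula is an assumption modeled on a common form used in LFQ/BSQ-style quantizers, not taken from this module: code probabilities are a softmax of -inv_temperature times the squared distance to each code, and the loss is entropy_loss_weight * (mean per-sample entropy - diversity_gamma * entropy of the averaged distribution). The helper bsq_aux_loss is hypothetical, and enumerating all 2 ** code_dim codes is only feasible for tiny code_dim (code_dim = 13 already gives 8192 codes); the real class may compute the entropies differently.

```python
import itertools
import math


def bsq_aux_loss(vectors, entropy_loss_weight=0.1, diversity_gamma=1.0,
                 inv_temperature=100.0):
    """Brute-force entropy/diversity loss over an explicit codebook (illustrative only)."""
    d = len(vectors[0])
    scale = 1.0 / math.sqrt(d)
    # Enumerate all 2**d codes with entries +/-1/sqrt(d); only feasible for small d.
    codebook = [[scale * s for s in signs]
                for signs in itertools.product((-1.0, 1.0), repeat=d)]

    def entropy(p):
        return -sum(pi * math.log(pi) for pi in p if pi > 0)

    probs = []
    for x in vectors:
        norm = max(math.sqrt(sum(v * v for v in x)), 1e-12)
        u = [v / norm for v in x]
        # Softmax over negative squared distances, sharpened by inv_temperature.
        logits = [-inv_temperature * sum((a - b) ** 2 for a, b in zip(u, c))
                  for c in codebook]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(logit - top) for logit in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])

    # Low per-sample entropy: each vector is confidently assigned to one code.
    sample_entropy = sum(entropy(p) for p in probs) / len(probs)
    # High entropy of the averaged distribution: codes are used uniformly.
    avg = [sum(col) / len(probs) for col in zip(*probs)]
    codebook_entropy = entropy(avg)
    return entropy_loss_weight * (sample_entropy - diversity_gamma * codebook_entropy)


diverse = bsq_aux_loss([[1.0, 1.0], [-1.0, -1.0]])   # two different codes used
collapsed = bsq_aux_loss([[1.0, 1.0], [2.0, 2.0]])   # both map to the same code
```

Under these assumptions, inputs spread over several codes get a lower (more negative) loss than inputs that collapse onto one code, which is the pressure toward uniform codebook usage that diversity_gamma scales.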
Example
>>> import torch
>>> from speechbrain.lobes.models.bsq import BinarySphericalQuantizer
>>> code_dim = 13
>>> x = torch.randn(2, 50, code_dim)
>>> quantizer = BinarySphericalQuantizer(code_dim)
>>> quant, indices, aux_loss = quantizer(x)
- bits_to_codes(bits: Tensor) → Tensor
Convert {0, 1} bits to {-1, +1} codes.
- Parameters:
bits (torch.Tensor) – Tensor of bits in {0, 1} with shape […, code_dim].
- Returns:
Tensor of codes in {-1, +1} with the same shape as bits.
- Return type:
torch.Tensor
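The bit-to-code conversion is an elementwise affine map: 2 * b - 1 sends 0 to -1 and 1 to +1 (for a tensor, the same expression works elementwise). The list-based functions below are hypothetical stand-ins for illustration; that bits_to_codes uses exactly this expression internally is an assumption, although any correct implementation must produce the same values.

```python
def bits_to_codes_sketch(bits):
    """List-based stand-in for bits_to_codes: maps each bit b to 2 * b - 1."""
    return [2 * b - 1 for b in bits]


def codes_to_bits_sketch(codes):
    """Inverse mapping (hypothetical): positive code -> 1, negative code -> 0."""
    return [1 if c > 0 else 0 for c in codes]


print(bits_to_codes_sketch([1, 0, 1]))  # [1, -1, 1]
```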
- forward(x: Tensor, inv_temperature: float = 100.0) → Tuple[Tensor, Tensor, Tensor]
Quantize continuous vectors on the binary sphere.
- Parameters:
x (torch.Tensor) – Input tensor of shape […, code_dim]. The last dimension must match self.code_dim. It is L2-normalized internally.
inv_temperature (float, optional) – Inverse temperature for the softmax over codebook distances used to compute the entropy-based auxiliary loss.
- Returns:
A tuple (quantized, indices, aux_loss) where:
- quantized: torch.Tensor
Quantized version of the input with the same shape as x, lying on the unit sphere with entries approximately in {-1/sqrt(D), +1/sqrt(D)}.
- indices: torch.Tensor
Integer code indices of shape […], obtained by interpreting the sign pattern of each vector as a binary code.
- aux_loss: torch.Tensor
Scalar auxiliary loss combining per-sample entropy and codebook-diversity regularization, scaled by entropy_loss_weight.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
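Since each index is just a sign pattern, it can be turned back into its code vector without any stored codebook. The helper index_to_quantized below is hypothetical (not part of the documented API) and assumes most-significant-bit-first order, with bit 1 meaning a positive sign; if the module uses the opposite bit order, the bit list would simply be reversed.

```python
import math


def index_to_quantized(index, code_dim):
    """Rebuild the quantized vector for an integer code index (hypothetical helper).

    Assumes MSB-first bit order: the most significant bit maps to dimension 0.
    """
    if not 0 <= index < 2 ** code_dim:
        raise ValueError(f"index must be in [0, {2 ** code_dim})")
    scale = 1.0 / math.sqrt(code_dim)
    # Extract bits from most to least significant.
    bits = [(index >> (code_dim - 1 - i)) & 1 for i in range(code_dim)]
    # Bit 1 -> +1/sqrt(D), bit 0 -> -1/sqrt(D), so the result has unit L2 norm.
    return [scale if b else -scale for b in bits]


q = index_to_quantized(5, 3)  # bits 1, 0, 1 -> [+s, -s, +s] with s = 1/sqrt(3)
```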