speechbrain.lobes.models.PIQ module

This file implements the classes and functions needed for Posthoc Interpretations via Quantization (PIQ).

Authors
 * Cem Subakan 2023
 * Francesco Paissan 2023

Summary

Classes:

Conv2dEncoder_v2

This class implements a convolutional encoder to extract classification embeddings from logspectra.

ResBlockAudio

This class implements a residual block.

VQEmbedding

Implements the VQ dictionary.

VectorQuantization

This class defines the forward method for vector quantization.

VectorQuantizationStraightThrough

This class defines the forward method for vector quantization.

VectorQuantizedPSI_Audio

This class reconstructs log-power spectrograms from classifier's representations.

Functions:

get_irrelevant_regions

This function returns a binary matrix that indicates the irrelevant regions in the VQ dictionary, given the labels array.

weights_init

Applies Xavier initialization to network weights.

Reference

speechbrain.lobes.models.PIQ.get_irrelevant_regions(labels, K, num_classes, N_shared=5, stage='TRAIN')[source]

This function returns a binary matrix that indicates the irrelevant regions in the VQ dictionary, given the labels array.

Parameters:
  • labels (torch.tensor) – 1 dimensional torch.tensor of size [B]

  • K (int) – Number of keys in the dictionary

  • num_classes (int) – Number of possible classes

  • N_shared (int) – Number of shared keys

  • stage (str) – “TRAIN”, or any other value for non-training stages

Example

>>> labels = torch.Tensor([1, 0, 2])
>>> irrelevant_regions = get_irrelevant_regions(labels, 20, 3, 5)
>>> print(irrelevant_regions.shape)
torch.Size([3, 20])

speechbrain.lobes.models.PIQ.weights_init(m)[source]

Applies Xavier initialization to network weights.
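A typical usage pattern (a minimal sketch, assuming a standard torch.nn model whose layers the initializer targets) is to pass this function to Module.apply, which calls it on every submodule:

>>> import torch.nn as nn
>>> net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU())
>>> net = net.apply(weights_init)  # weights_init visits each submodule in turn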

class speechbrain.lobes.models.PIQ.VectorQuantization(*args, **kwargs)[source]

Bases: Function

This class defines the forward method for vector quantization. As VQ is not differentiable, it raises a RuntimeError if .grad() is called. Refer to VectorQuantizationStraightThrough for a straight-through estimation of the gradient of the VQ operation.

static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]

Applies VQ to the input vectors, using codebook as the VQ dictionary.

Parameters:
  • inputs (torch.Tensor) – Hidden representations to quantize. Expected shape is torch.Size([B, W, H, C]).

  • codebook (torch.Tensor) – VQ-dictionary for quantization. Expected shape of torch.Size([K, C]) with K dictionary elements.

  • labels (torch.Tensor) – Classification labels. Used to define irrelevant regions and divide the latent space based on predicted class. Shape should be torch.Size([B]).

  • num_classes (int) – Number of possible classes

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

  • training (bool) – True if stage is TRAIN.

Returns:

Codebook’s indices for the quantized representation

Return type:

torch.Tensor

Example

>>> inputs = torch.ones(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(VectorQuantization.apply(inputs, codebook, labels).shape)
torch.Size([3, 14, 25])

static backward(ctx, grad_output)[source]

Raises a RuntimeError if grad() is called on the VQ operation, since VQ is not differentiable.

class speechbrain.lobes.models.PIQ.VectorQuantizationStraightThrough(*args, **kwargs)[source]

Bases: Function

This class defines the forward method for vector quantization. As VQ is not differentiable, it approximates the gradient of the VQ operation with a straight-through estimator, as in https://arxiv.org/abs/1711.00937.
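For intuition, the straight-through trick from that paper can be sketched in plain PyTorch (a minimal sketch, not this module's autograd Function): the forward pass uses the quantized vectors, while the backward pass copies the gradient to the encoder output as if quantization were the identity.

>>> z_e = torch.randn(4, 8, requires_grad=True)
>>> codebook = torch.randn(16, 8)
>>> indices = torch.cdist(z_e, codebook).argmin(dim=1)  # nearest codebook entry per vector
>>> z_q = codebook[indices]  # quantized vectors; this lookup breaks the gradient
>>> z_st = z_e + (z_q - z_e).detach()  # forward equals z_q, backward is identity w.r.t. z_e
>>> z_st.sum().backward()
>>> print(z_e.grad.shape)
torch.Size([4, 8])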

static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]

Applies VQ to the input vectors, using codebook as the VQ dictionary, and estimates gradients with a straight-through (identity) approximation of the quantization step.

Parameters:
  • inputs (torch.Tensor) – Hidden representations to quantize. Expected shape is torch.Size([B, W, H, C]).

  • codebook (torch.Tensor) – VQ-dictionary for quantization. Expected shape of torch.Size([K, C]) with K dictionary elements.

  • labels (torch.Tensor) – Classification labels. Used to define irrelevant regions and divide the latent space based on predicted class. Shape should be torch.Size([B]).

  • num_classes (int) – Number of possible classes

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

  • training (bool) – True if stage is TRAIN.

Returns:

Quantized representation and codebook’s indices for the quantized representation

Return type:

tuple

Example

>>> inputs = torch.ones(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_ind = VectorQuantizationStraightThrough.apply(inputs, codebook, labels)
>>> print(quant.shape, quant_ind.shape)
torch.Size([3, 14, 25, 256]) torch.Size([1050])

static backward(ctx, grad_output, grad_indices, labels=None, num_classes=None, activate_class_partitioning=True, shared_keys=10, training=True)[source]

Estimates the gradient by treating vector quantization as the identity function (https://arxiv.org/abs/1711.00937).

class speechbrain.lobes.models.PIQ.Conv2dEncoder_v2(dim=256)[source]

Bases: Module

This class implements a convolutional encoder to extract classification embeddings from logspectra.

Parameters:

dim (int) – Number of channels of the extracted embeddings.

Returns:

Latent representations to feed into the classifier and/or interpreter.

Return type:

torch.Tensor

Example

>>> inputs = torch.ones(3, 431, 513)
>>> model = Conv2dEncoder_v2()
>>> print(model(inputs).shape)
torch.Size([3, 256, 26, 32])

__init__(dim=256)[source]

Extracts embeddings from logspectrograms.

forward(x)[source]

Computes the forward pass.

Parameters:

x (torch.Tensor) – Log-power spectrogram. Expected shape is torch.Size([B, T, F]).

Returns:

Embeddings

Return type:

torch.Tensor

training: bool

class speechbrain.lobes.models.PIQ.ResBlockAudio(dim)[source]

Bases: Module

This class implements a residual block.

Parameters:
dim (int) – Input channels of the tensor to process. Matches the output channels of the residual block.

Returns:

Residual block output

Return type:

torch.Tensor

Example

>>> res = ResBlockAudio(128)
>>> x = torch.randn(2, 128, 16, 16)
>>> print(res(x).shape)
torch.Size([2, 128, 16, 16])

__init__(dim)[source]

Implements a residual block.

forward(x)[source]

Forward step.

Parameters:

x (torch.Tensor) – Tensor to process. Expected shape is torch.Size([B, C, H, W]).

Returns:

Residual block output

Return type:

torch.Tensor

training: bool

class speechbrain.lobes.models.PIQ.VectorQuantizedPSI_Audio(dim=128, K=512, numclasses=50, activate_class_partitioning=True, shared_keys=0, use_adapter=True, adapter_reduce_dim=True)[source]

Bases: Module

This class reconstructs log-power spectrograms from classifier’s representations.

Parameters:
  • dim (int) – Dimensionality of VQ vectors.

  • K (int) – Number of elements of VQ dictionary.

  • numclasses (int) – Number of possible classes

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

  • use_adapter (bool) – True to learn an adapter for classifier’s representations.

  • adapter_reduce_dim (bool) – True if adapter should compress representations.

Returns:

Reconstructed log-power spectrograms, adapted classifier’s representations, and quantized classifier’s representations.

Return type:

tuple

Example

>>> psi = VectorQuantizedPSI_Audio(dim=256, K=1024)
>>> x = torch.randn(2, 256, 16, 16)
>>> labels = torch.Tensor([0, 2])
>>> logspectra, hcat, z_q_x = psi(x, labels)
>>> print(logspectra.shape, hcat.shape, z_q_x.shape)
torch.Size([2, 1, 257, 257]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])

forward(hs, labels)[source]

Forward step. Reconstructs the log-power spectrogram using the provided labels’ keys in the VQ dictionary.

Parameters:
  • hs (torch.Tensor) – Classifier’s representations.

  • labels (torch.Tensor) – Predicted labels for classifier’s representations.

Returns:

Reconstructed log-power spectrogram, reduced classifier’s representations and quantized classifier’s representations.

Return type:

tuple

training: bool
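Putting the pieces together, a hypothetical end-to-end sketch (an illustration only; it assumes the encoder's output channels match dim, following the shapes in the examples above) could look like:

>>> enc = Conv2dEncoder_v2(dim=256)
>>> psi = VectorQuantizedPSI_Audio(dim=256, K=1024)
>>> x = torch.ones(3, 431, 513)  # batch of log-power spectrograms [B, T, F]
>>> h = enc(x)  # classification embeddings, torch.Size([3, 256, 26, 32])
>>> labels = torch.Tensor([1, 0, 2])  # predicted classes for the batch
>>> logspectra, hcat, z_q_x = psi(h, labels)  # reconstructed log-power spectrograms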
class speechbrain.lobes.models.PIQ.VQEmbedding(K, D, numclasses=50, activate_class_partitioning=True, shared_keys=0)[source]

Bases: Module

Implements the VQ dictionary. Wraps VectorQuantization and VectorQuantizationStraightThrough; refer to those classes for details.

Parameters:
  • K (int) – Number of elements of VQ dictionary.

  • D (int) – Dimensionality of VQ vectors.

  • numclasses (int) – Number of possible classes

  • activate_class_partitioning (bool) – True if latent space should be quantized for different classes.

  • shared_keys (int) – Number of shared keys among classes.

forward(z_e_x, labels=None)[source]

Wraps VectorQuantization. Computes VQ-dictionary indices for input quantization. Note that this forward step is not differentiable.

Parameters:

  • z_e_x (torch.Tensor) – Input tensor to be quantized.

  • labels (torch.Tensor) – Predicted class for input representations (used for latent space quantization).

Returns:

Codebook’s indices for the quantized representation

Return type:

torch.Tensor

Example

>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(codebook(inputs, labels).shape)
torch.Size([3, 14, 25])

straight_through(z_e_x, labels=None)[source]

Implements the vector quantization with straight through approximation of the gradient.

Parameters:
  • z_e_x (torch.Tensor) – Input tensor to be quantized.

  • labels (torch.Tensor) – Predicted class for input representations (used for latent space quantization).

Returns:

Straight-through quantized representation and quantized representation

Return type:

tuple

Example

>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_ind = codebook.straight_through(inputs, labels)
>>> print(quant.shape, quant_ind.shape)
torch.Size([3, 256, 14, 25]) torch.Size([3, 256, 14, 25])

training: bool