speechbrain.lobes.models.PIQ module
This module provides the classes and functions needed to implement Posthoc Interpretations via Quantization (PIQ).
Authors * Cem Subakan 2023 * Francesco Paissan 2023
Summary
Classes:
Conv2dEncoder_v2 | This class implements a convolutional encoder to extract classification embeddings from logspectra.
ResBlockAudio | This class implements a residual block.
VQEmbedding | Implements VQ Dictionary.
VectorQuantization | This class defines the forward method for vector quantization.
VectorQuantizationStraightThrough | This class defines the forward method for vector quantization.
VectorQuantizedPSI_Audio | This class reconstructs log-power spectrograms from classifier's representations.
Functions:
get_irrelevant_regions | Returns a binary matrix that indicates the irrelevant regions in the VQ-dictionary given the labels array.
weights_init | Applies Xavier initialization to network weights.
Reference
- speechbrain.lobes.models.PIQ.get_irrelevant_regions(labels, K, num_classes, N_shared=5, stage='TRAIN')[source]
This function returns a binary matrix that indicates the irrelevant regions in the VQ-dictionary given the labels array.
- Parameters:
labels (torch.tensor) – 1 dimensional torch.tensor of size [B]
K (int) – Number of keys in the dictionary
num_classes (int) – Number of possible classes
N_shared (int) – Number of shared keys
stage (str) – “TRAIN” during training; any other value is treated as a non-training stage.
Example
>>> labels = torch.Tensor([1, 0, 2])
>>> irrelevant_regions = get_irrelevant_regions(labels, 20, 3, 5)
>>> print(irrelevant_regions.shape)
torch.Size([3, 20])
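To make the idea of a per-sample mask over dictionary keys concrete, here is a minimal sketch. The function name `irrelevant_mask_sketch` and the key layout (contiguous per-class blocks followed by shared keys that stay available to everyone) are assumptions made for illustration; the library function may lay out or treat keys differently, in particular depending on `stage`.

```python
import torch


def irrelevant_mask_sketch(labels, K, num_classes, n_shared=5):
    """Return a [B, K] boolean mask; True marks keys a sample may not use.

    Assumed (hypothetical) layout: the first K - n_shared keys are split into
    num_classes contiguous blocks, and the last n_shared keys are shared.
    """
    n_private = K - n_shared
    # Owner class of every private key, e.g. [0, 0, ..., 1, 1, ..., 2, 2, ...].
    owner = torch.div(
        torch.arange(n_private) * num_classes, n_private, rounding_mode="floor"
    )
    # A private key is irrelevant when its owner differs from the sample label.
    private = owner.unsqueeze(0) != labels.long().unsqueeze(1)
    # In this sketch the shared keys are never masked out.
    shared = torch.zeros(labels.shape[0], n_shared, dtype=torch.bool)
    return torch.cat([private, shared], dim=1)


labels = torch.tensor([1, 0, 2])
mask = irrelevant_mask_sketch(labels, K=20, num_classes=3, n_shared=5)
print(mask.shape)
```

With 20 keys, 3 classes and 5 shared keys, each class owns 5 private keys in this layout, so the first sample (label 1) may only use keys 5 to 9 plus the shared ones.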
- speechbrain.lobes.models.PIQ.weights_init(m)[source]
Applies Xavier initialization to network weights.
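As an illustration of what applying Xavier initialization to a network can look like, here is a minimal sketch. The helper name `xavier_init_sketch`, the choice of layer types, and the zeroed biases are assumptions for this example, not a description of what `weights_init` does internally.

```python
import math

import torch
from torch import nn


def xavier_init_sketch(m):
    """Xavier-uniform weights and zero biases for conv and linear layers."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 8, 3))
# Module.apply visits every submodule, so the helper sees both conv layers.
net.apply(xavier_init_sketch)

# For the first layer: fan_in = 1 * 3 * 3, fan_out = 4 * 3 * 3.
fan_in, fan_out = 9, 36
bound = math.sqrt(6.0 / (fan_in + fan_out))
print(float(net[0].weight.abs().max()) <= bound)
```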
- class speechbrain.lobes.models.PIQ.VectorQuantization(*args, **kwargs)[source]
Bases: Function
This class defines the forward method for vector quantization. As VQ is not differentiable, it raises a RuntimeError if .grad() is called. Refer to VectorQuantizationStraightThrough for a straight-through estimation of the gradient for the VQ operation.
- static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]
Applies VQ to the vectors in inputs, with codebook as the VQ dictionary.
- Parameters:
inputs (torch.Tensor) – Hidden representations to quantize. Expected shape is torch.Size([B, W, H, C]).
codebook (torch.Tensor) – VQ-dictionary for quantization. Expected shape is torch.Size([K, C]), with K dictionary elements.
labels (torch.Tensor) – Classification labels. Used to define irrelevant regions and divide the latent space based on predicted class. Shape should be torch.Size([B]).
num_classes (int) – Number of possible classes.
activate_class_partitioning (bool) – True if latent space should be quantized for different classes.
shared_keys (int) – Number of shared keys among classes.
training (bool) – True if stage is TRAIN.
- Returns:
Codebook’s indices for quantized representation (torch.Tensor)
Example
>>> inputs = torch.ones(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(VectorQuantization.apply(inputs, codebook, labels).shape)
torch.Size([3, 14, 25])
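The core of the quantization step described above is a nearest-neighbour search in the codebook. The sketch below shows that search without class partitioning; the function name `nearest_code_indices` and the use of `torch.cdist` are choices made for this example and are not taken from the library source.

```python
import torch


def nearest_code_indices(inputs, codebook):
    """Map each C-dimensional vector in inputs to its closest codebook row."""
    B, W, H, C = inputs.shape
    flat = inputs.reshape(-1, C)          # [B*W*H, C]
    dists = torch.cdist(flat, codebook)   # [B*W*H, K] Euclidean distances
    return dists.argmin(dim=1).view(B, W, H)


torch.manual_seed(0)
inputs = torch.randn(3, 14, 25, 8)
codebook = torch.randn(32, 8)

indices = nearest_code_indices(inputs, codebook)
# Looking the indices back up yields the quantized vectors.
quantized = codebook[indices]
print(indices.shape, quantized.shape)
```

The returned indices have one entry per spatial position, which matches the `[B, W, H]` shape of the example above.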
- class speechbrain.lobes.models.PIQ.VectorQuantizationStraightThrough(*args, **kwargs)[source]
Bases: Function
This class defines the forward method for vector quantization. As VQ is not differentiable, it approximates the gradient of the VQ as in https://arxiv.org/abs/1711.00937.
- static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[source]
Applies VQ to the vectors in inputs, with codebook as the VQ dictionary, and estimates gradients with a straight-through (identity) approximation of the quantization step.
- Parameters:
inputs (torch.Tensor) – Hidden representations to quantize. Expected shape is torch.Size([B, W, H, C]).
codebook (torch.Tensor) – VQ-dictionary for quantization. Expected shape is torch.Size([K, C]), with K dictionary elements.
labels (torch.Tensor) – Classification labels. Used to define irrelevant regions and divide the latent space based on predicted class. Shape should be torch.Size([B]).
num_classes (int) – Number of possible classes.
activate_class_partitioning (bool) – True if latent space should be quantized for different classes.
shared_keys (int) – Number of shared keys among classes.
training (bool) – True if stage is TRAIN.
- Returns:
Quantized representation and codebook’s indices for quantized representation (tuple)
Example
>>> inputs = torch.ones(3, 14, 25, 256)
>>> codebook = torch.randn(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_ind = VectorQuantizationStraightThrough.apply(inputs, codebook, labels)
>>> print(quant.shape, quant_ind.shape)
torch.Size([3, 14, 25, 256]) torch.Size([1050])
- static backward(ctx, grad_output, grad_indices, labels=None, num_classes=None, activate_class_partitioning=True, shared_keys=10, training=True)[source]
Estimates gradient assuming vector quantization as identity function. (https://arxiv.org/abs/1711.00937)
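The identity approximation of the gradient can be demonstrated in a few lines. Note that this sketch uses the common `detach()` formulation of the straight-through estimator rather than a custom autograd `Function`, so it illustrates the idea and is not the class's actual implementation; all variable names are hypothetical.

```python
import torch

torch.manual_seed(0)
codebook = torch.randn(16, 4)
z_e = torch.randn(5, 4, requires_grad=True)  # stand-in for encoder outputs

# Non-differentiable step: pick the nearest codebook row for each vector.
indices = torch.cdist(z_e.detach(), codebook).argmin(dim=1)
z_q = codebook[indices]

# Straight-through trick: the forward value equals z_q, while the backward
# pass treats the quantizer as the identity with respect to z_e.
z_st = z_e + (z_q - z_e).detach()

z_st.sum().backward()
forward_is_quantized = torch.allclose(z_st, z_q)
grad_is_identity = torch.equal(z_e.grad, torch.ones_like(z_e))
print(forward_is_quantized, grad_is_identity)
```

Because the quantized offset is detached, the gradient of the output with respect to `z_e` is exactly the incoming gradient, as if quantization had not happened.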
- class speechbrain.lobes.models.PIQ.Conv2dEncoder_v2(dim=256)[source]
Bases: Module
This class implements a convolutional encoder to extract classification embeddings from logspectra.
- Parameters:
dim (int) – Number of channels of the extracted embeddings.
- Returns:
Latent representations to feed into the classifier and/or interpreter.
Example
>>> inputs = torch.ones(3, 431, 513)
>>> model = Conv2dEncoder_v2()
>>> print(model(inputs).shape)
torch.Size([3, 256, 26, 32])
- class speechbrain.lobes.models.PIQ.ResBlockAudio(dim)[source]
Bases: Module
This class implements a residual block.
- Parameters:
dim (int) – Input channels of the tensor to process. Matches output channels of the residual block.
- Returns:
Residual block output
- Return type:
torch.Tensor
Example
>>> res = ResBlockAudio(128)
>>> x = torch.randn(2, 128, 16, 16)
>>> print(x.shape)
torch.Size([2, 128, 16, 16])
- forward(x)[source]
Forward step.
- Parameters:
x (torch.Tensor) – Tensor to process. Expected shape is torch.Size([B, C, H, W]).
- Returns:
Residual block output
- Return type:
torch.Tensor
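For readers unfamiliar with residual blocks, the following sketch shows the general pattern: the block adds a learned transformation back onto its input, so input and output shapes match. The class name `ResidualBlockSketch` and the particular layers are assumptions chosen for illustration and may differ from those in `ResBlockAudio`.

```python
import torch
from torch import nn


class ResidualBlockSketch(nn.Module):
    """Hypothetical residual block: output = x + f(x), same shape as x."""

    def __init__(self, dim):
        super().__init__()
        # The layer choice below is an assumption made for illustration.
        self.block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
        )

    def forward(self, x):
        # The skip connection requires f(x) to preserve the shape of x.
        return x + self.block(x)


res = ResidualBlockSketch(16)
x = torch.randn(2, 16, 8, 8)
y = res(x)
print(y.shape)
```

The padding of 1 on the 3x3 convolution is what keeps the spatial size unchanged, which the addition in `forward` depends on.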
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSI_Audio(dim=128, K=512, numclasses=50, activate_class_partitioning=True, shared_keys=0, use_adapter=True, adapter_reduce_dim=True)[source]
Bases: Module
This class reconstructs log-power spectrograms from classifier’s representations.
- Parameters:
dim (int) – Dimensionality of VQ vectors.
K (int) – Number of elements of VQ dictionary.
numclasses (int) – Number of possible classes
activate_class_partitioning (bool) – True if latent space should be quantized for different classes.
shared_keys (int) – Number of shared keys among classes.
use_adapter (bool) – True to learn an adapter for classifier’s representations.
adapter_reduce_dim (bool) – True if adapter should compress representations.
- Returns:
Reconstructed log-power spectrograms, adapted classifier’s representations, quantized classifier’s representations. (tuple)
Example
>>> psi = VectorQuantizedPSI_Audio(dim=256, K=1024)
>>> x = torch.randn(2, 256, 16, 16)
>>> labels = torch.Tensor([0, 2])
>>> logspectra, hcat, z_q_x = psi(x, labels)
>>> print(logspectra.shape, hcat.shape, z_q_x.shape)
torch.Size([2, 1, 257, 257]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- forward(hs, labels)[source]
Forward step. Reconstructs log-power based on provided label’s keys in VQ dictionary.
- Parameters:
hs (torch.Tensor) – Classifier’s representations.
labels (torch.Tensor) – Predicted labels for classifier’s representations.
- Returns:
Reconstructed log-power spectrogram, reduced classifier’s representations and quantized classifier’s representations.
- Return type:
tuple
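Restricting quantization to the keys that belong to a label can be sketched as a masked nearest-neighbour search. Everything here is hypothetical: the function name `partitioned_indices`, the tiny two-class key layout, and the use of one label per vector (the module itself takes one label per batch element).

```python
import torch


def partitioned_indices(flat, codebook, irrelevant):
    """Nearest-key lookup where masked keys can never be selected.

    flat:       [N, C] vectors to quantize
    codebook:   [K, C] dictionary
    irrelevant: [N, K] boolean mask, True = key not allowed for that vector
    """
    dists = torch.cdist(flat, codebook)
    # Push forbidden keys to infinite distance so argmin skips them.
    dists = dists.masked_fill(irrelevant, float("inf"))
    return dists.argmin(dim=1)


torch.manual_seed(0)
codebook = torch.randn(6, 4)
owner = torch.tensor([0, 0, 0, 1, 1, 1])  # keys 0-2: class 0, keys 3-5: class 1
flat = torch.randn(4, 4)
labels = torch.tensor([0, 0, 1, 1])       # one predicted class per vector
irrelevant = owner.unsqueeze(0) != labels.unsqueeze(1)

idx = partitioned_indices(flat, codebook, irrelevant)
print(idx)
```

Every selected key belongs to the class of the vector it quantizes, even when a key from another class would have been closer.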
- class speechbrain.lobes.models.PIQ.VQEmbedding(K, D, numclasses=50, activate_class_partitioning=True, shared_keys=0)[source]
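Bases: Module
Implements VQ Dictionary. Wraps VectorQuantization and VectorQuantizationStraightThrough. For more details refer to the specific class.
- Parameters:
K (int) – Number of elements of VQ dictionary.
D (int) – Dimensionality of VQ vectors.
numclasses (int) – Number of possible classes
activate_class_partitioning (bool) – True if latent space should be quantized for different classes.
shared_keys (int) – Number of shared keys among classes.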
- forward(z_e_x, labels=None)[source]
Wraps VectorQuantization. Computes VQ-dictionary indices for input quantization. Note that this forward step is not differentiable.
- Parameters:
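z_e_x (torch.Tensor) – Input tensor to be quantized.
labels (torch.Tensor) – Predicted class for input representations (used for latent space quantization).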
- Returns:
Codebook’s indices for quantized representation (torch.Tensor)
Example
>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> print(codebook(inputs, labels).shape)
torch.Size([3, 14, 25])
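The example above passes a channel-first tensor and gets back one index per spatial position, which implies a layout change before the lookup. A minimal sketch of that bookkeeping, assuming a plain tensor as codebook and no class partitioning (neither of which is taken from the library source):

```python
import torch

torch.manual_seed(0)
z_e_x = torch.ones(3, 256, 14, 25)   # [B, C, H, W], channel-first
codebook = torch.randn(1024, 256)    # [K, C]

# Move channels last so every spatial position is one C-dimensional vector.
channels_last = z_e_x.permute(0, 2, 3, 1).contiguous()  # [3, 14, 25, 256]
flat = channels_last.view(-1, 256)
indices = torch.cdist(flat, codebook).argmin(dim=1).view(3, 14, 25)

# Look the vectors up and restore the channel-first layout.
z_q_x = codebook[indices].permute(0, 3, 1, 2).contiguous()
print(indices.shape, z_q_x.shape)
```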
- straight_through(z_e_x, labels=None)[source]
Implements the vector quantization with straight through approximation of the gradient.
- Parameters:
z_e_x (torch.Tensor) – Input tensor to be quantized.
labels (torch.Tensor) – Predicted class for input representations (used for latent space quantization).
- Returns:
Straight-through quantized representation and quantized representation (tuple)
Example
>>> inputs = torch.ones(3, 256, 14, 25)
>>> codebook = VQEmbedding(1024, 256)
>>> labels = torch.Tensor([1, 0, 2])
>>> quant, quant_ind = codebook.straight_through(inputs, labels)
>>> print(quant.shape, quant_ind.shape)
torch.Size([3, 256, 14, 25]) torch.Size([3, 256, 14, 25])
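One common reason to keep both a straight-through output and a plain quantized output is that they feed different loss terms, as in the VQ-VAE formulation (https://arxiv.org/abs/1711.00937). The sketch below shows the codebook and commitment terms under that assumption; the variable names and the weight `beta` are hypothetical, and this is not claimed to be how the PIQ training recipe combines them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z_e_x = torch.randn(2, 8, 4, 4, requires_grad=True)  # stand-in encoder output
codebook = torch.randn(16, 8, requires_grad=True)    # learnable dictionary

# Nearest-key lookup on channel-last vectors (no gradient through argmin).
flat = z_e_x.permute(0, 2, 3, 1).reshape(-1, 8)
indices = torch.cdist(flat.detach(), codebook.detach()).argmin(dim=1)
z_q = codebook[indices].view(2, 4, 4, 8).permute(0, 3, 1, 2)

# Codebook term: move the selected keys toward the frozen encoder output.
loss_vq = F.mse_loss(z_q, z_e_x.detach())
# Commitment term: keep the encoder output close to the frozen keys.
beta = 0.25  # hypothetical weight
loss_commit = F.mse_loss(z_e_x, z_q.detach())

(loss_vq + beta * loss_commit).backward()
print(codebook.grad.shape, z_e_x.grad.shape)
```

The `detach()` calls decide which side each term updates: the first only produces gradients for the codebook, the second only for the encoder output.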