speechbrain.lobes.models.L2I module
This module provides the classes and functions needed to implement the Listen-to-Interpret (L2I) interpretation method from https://arxiv.org/abs/2202.11479v2
Authors
* Cem Subakan 2022
* Francesco Paissan 2022
Summary
Classes:
NMFDecoderAudio | This class implements an NMF decoder
NMFEncoder | This class implements an NMF encoder with a convolutional network
Psi | Convolutional Layers to estimate NMF Activations from Classifier Representations
PsiOptimized | Convolutional Layers to estimate NMF Activations from Classifier Representations, optimized for log-spectra.
Theta | This class implements a linear classifier on top of NMF activations
Functions:
weights_init | Applies Xavier initialization to network weights.
Reference
- class speechbrain.lobes.models.L2I.Psi(n_comp=100, T=431, in_emb_dims=[2048, 1024, 512])[source]
Bases: Module
Convolutional Layers to estimate NMF Activations from Classifier Representations
- Parameters:
n_comp (int) – Number of NMF components (or equivalently number of neurons at the output per timestep)
T (int) – The targeted length along the time dimension
in_emb_dims (list of int) – A list of length 3 giving the number of channels of each input classifier representation. The list must match the channel counts of the inputs, and the last entry should be the smallest.
Example
>>> inp = [torch.ones(2, 150, 6, 2), torch.ones(2, 100, 6, 2), torch.ones(2, 50, 12, 5)]
>>> psi = Psi(n_comp=100, T=120, in_emb_dims=[150, 100, 50])
>>> h = psi(inp)
>>> print(h.shape)
torch.Size([2, 100, 120])
- __init__(n_comp=100, T=431, in_emb_dims=[2048, 1024, 512])[source]
Computes NMF activations given classifier hidden representations
- class speechbrain.lobes.models.L2I.NMFDecoderAudio(n_comp=100, n_freq=513, device='cuda')[source]
Bases: Module
This class implements an NMF decoder
- Parameters:
n_comp (int) – Number of NMF components
n_freq (int) – The number of frequency bins in the NMF dictionary
device (str) – The device on which to run the model
Example
>>> NMF_dec = NMFDecoderAudio(20, 210, device='cpu')
>>> H = torch.rand(1, 20, 150)
>>> Xhat = NMF_dec.forward(H)
>>> print(Xhat.shape)
torch.Size([1, 210, 150])
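As a rough illustration of what an NMF decoder computes, the sketch below reconstructs a spectrogram-like matrix as the product of a non-negative dictionary W (n_freq x n_comp) and activations H (n_comp x T), which matches the shapes in the example above for a single batch item. It is plain Python and not the SpeechBrain implementation: the helper name `nmf_decode` is hypothetical, and clamping negative entries to zero is an assumption made to keep both factors non-negative.

```python
import random


def nmf_decode(W, H):
    """Return Xhat = max(W, 0) @ max(H, 0) for nested-list matrices.

    W has shape (n_freq, n_comp) and H has shape (n_comp, T).
    Clamping to zero is an assumption that keeps both factors non-negative.
    """
    n_comp = len(H)
    n_frames = len(H[0])
    if any(len(row) != n_comp for row in W):
        raise ValueError("W must have n_comp columns")
    W_pos = [[max(w, 0.0) for w in row] for row in W]
    H_pos = [[max(h, 0.0) for h in row] for row in H]
    # Each output bin is a non-negative combination of the dictionary atoms.
    return [
        [sum(W_pos[f][k] * H_pos[k][t] for k in range(n_comp))
         for t in range(n_frames)]
        for f in range(len(W_pos))
    ]


rng = random.Random(0)
n_freq, n_comp, n_frames = 210, 20, 150
W = [[rng.random() for _ in range(n_comp)] for _ in range(n_freq)]
H = [[rng.random() for _ in range(n_frames)] for _ in range(n_comp)]
Xhat = nmf_decode(W, H)
print(len(Xhat), len(Xhat[0]))  # 210 150
```

The output has n_freq rows and T columns, so the number of components only affects the inner dimension of the product.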
- speechbrain.lobes.models.L2I.weights_init(m)[source]
Applies Xavier initialization to network weights.
- Parameters:
m (nn.Module) – Module to initialize.
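For reference, Xavier (Glorot) uniform initialization draws each weight from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). The sketch below shows that rule in plain Python. The helper `xavier_uniform` is hypothetical and is not the routine called by `weights_init`; this page does not say whether the uniform or normal variant is used, which gain is applied, or which layer types are affected.

```python
import math
import random


def xavier_uniform(fan_in, fan_out, gain=1.0, seed=None):
    """Return a (fan_out x fan_in) weight matrix drawn from U(-a, a)."""
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
    rng = random.Random(seed)
    weights = [
        [rng.uniform(-bound, bound) for _ in range(fan_in)]
        for _ in range(fan_out)
    ]
    return weights, bound


weights, bound = xavier_uniform(fan_in=100, fan_out=50, seed=0)
print(round(bound, 4))  # 0.2
```

The bound shrinks as the layer gets wider, which keeps the variance of activations roughly constant from layer to layer.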
- class speechbrain.lobes.models.L2I.PsiOptimized(dim=128, K=100, numclasses=50, use_adapter=False, adapter_reduce_dim=True)[source]
Bases: Module
Convolutional Layers to estimate NMF Activations from Classifier Representations, optimized for log-spectra.
- Parameters:
dim (int) – Dimension of the hidden representations (input to the classifier).
K (int) – Number of NMF components (or equivalently number of neurons at the output per timestep)
numclasses (int) – Number of possible classes.
use_adapter (bool) – True if you wish to learn an adapter for the latent representations.
adapter_reduce_dim (bool) – True if the adapter should compress the latent representations.
Example
>>> inp = torch.randn(1, 256, 26, 32)
>>> psi = PsiOptimized(dim=256, K=100, use_adapter=False, adapter_reduce_dim=False)
>>> h, inp_ad = psi(inp)
>>> print(h.shape, inp_ad.shape)
torch.Size([1, 1, 417, 100]) torch.Size([1, 256, 26, 32])
- __init__(dim=128, K=100, numclasses=50, use_adapter=False, adapter_reduce_dim=True)[source]
Computes NMF activations from hidden state.
- class speechbrain.lobes.models.L2I.Theta(n_comp=100, T=431, num_classes=50)[source]
Bases: Module
This class implements a linear classifier on top of NMF activations
- Parameters:
n_comp (int) – Number of NMF components
T (int) – Number of time points in the NMF activations
num_classes (int) – Number of classes that the classifier works with
Example
>>> theta = Theta(30, 120, 50)
>>> H = torch.rand(1, 30, 120)
>>> c_hat = theta.forward(H)
>>> print(c_hat.shape)
torch.Size([1, 50])
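To make the idea of a linear classifier on NMF activations concrete, the sketch below collapses the time axis of an activation matrix and applies a single linear layer followed by a softmax. It is a minimal plain-Python illustration under stated assumptions: the helper `classify_activations` is hypothetical, and mean pooling over time is an assumption, since this page does not describe how Theta reduces the time dimension.

```python
import math


def classify_activations(H, weights, bias):
    """Map NMF activations H (n_comp x T) to class probabilities.

    Assumption: activations are averaged over time before the linear layer.
    weights has shape (num_classes, n_comp); bias has length num_classes.
    """
    pooled = [sum(row) / len(row) for row in H]
    logits = [
        sum(w * p for w, p in zip(w_row, pooled)) + b
        for w_row, b in zip(weights, bias)
    ]
    # Numerically stable softmax: subtract the largest logit first.
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


H = [[1.0, 3.0], [0.0, 0.0]]        # 2 components, 2 frames
weights = [[1.0, 0.0], [0.0, 1.0]]  # 2 classes
probs = classify_activations(H, weights, bias=[0.0, 0.0])
print(len(probs))  # 2
```

Because the classifier is linear in the pooled activations, each weight ties one NMF component directly to one class, which is what makes the activations usable for interpretation.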
- class speechbrain.lobes.models.L2I.NMFEncoder(n_freq, n_comp)[source]
Bases: Module
This class implements an NMF encoder with a convolutional network
- Parameters: