Source code for speechbrain.lobes.models.L2I

"""This file implements the necessary classes and functions to implement Listen-to-Interpret (L2I) interpretation method from https://arxiv.org/abs/2202.11479v2

 Authors
 * Cem Subakan 2022
 * Francesco Paissan 2022
"""
import torch.nn as nn
import torch.nn.functional as F
import torch

from speechbrain.lobes.models.PIQ import ResBlockAudio


class Psi(nn.Module):
    """Convolutional layers to estimate NMF activations from classifier representations.

    Arguments
    ---------
    n_comp : int
        Number of NMF components (or equivalently, the number of neurons at the
        output per timestep).
    T : int
        The targeted length along the time dimension.
    in_emb_dims : list of int
        A list of length 3 that contains the dimensionality of the input representations.
        The list needs to match the number of channels in the input classifier
        representations. The last entry should be the smallest entry.

    Example
    -------
    >>> inp = [torch.ones(2, 150, 6, 2), torch.ones(2, 100, 6, 2), torch.ones(2, 50, 12, 5)]
    >>> psi = Psi(n_comp=100, T=120, in_emb_dims=[150, 100, 50])
    >>> h = psi(inp)
    >>> print(h.shape)
    torch.Size([2, 100, 120])
    """
    def __init__(self, n_comp=100, T=431, in_emb_dims=[2048, 1024, 512]):
        """Computes NMF activations given classifier hidden representations."""
        super(Psi, self).__init__()
        self.in_emb_dims = in_emb_dims
        self.upsamp = nn.UpsamplingBilinear2d(scale_factor=(2, 2))
        self.upsamp_time = nn.UpsamplingBilinear2d(size=(T, 1))

        out_c = min(in_emb_dims)
        self.c1 = nn.Conv2d(
            in_emb_dims[0], out_c, kernel_size=3, padding="same"
        )
        self.c2 = nn.Conv2d(
            in_emb_dims[1], out_c, kernel_size=3, padding="same"
        )

        self.out_conv = nn.Conv2d(out_c, n_comp, kernel_size=3, padding="same")
        self.conv = nn.Sequential(
            nn.Conv2d(out_c * 3, out_c, kernel_size=3, padding="same"),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
        )
        self.act = nn.ReLU()
    def forward(self, inp):
        """Returns the NMF time activations given the classifier activations.

        Arguments
        ---------
        inp : list
            A length-3 list of classifier input representations.
        """
        error = (
            "in Psi doesn't match. The embedding dimensions need to be "
            "consistent with the list self.in_emb_dims"
        )
        for i, in_emb_dim in enumerate(self.in_emb_dims):
            # sanity check on the channel dimensions
            assert inp[i].shape[1] == self.in_emb_dims[i], (
                "Nr. of channels " + error
            )

        assert inp[0].shape[2] == inp[1].shape[2], "Spatial dimension " + error
        assert inp[0].shape[3] == inp[1].shape[3], "Spatial dimension " + error
        assert 2 * inp[0].shape[3] == (inp[2].shape[3] - 1), (
            "Spatial dimension "
            + error
            + f" 1st (idx 0) element has shape {inp[0].shape[3]}, "
            + f"3rd (idx 2) element has shape {inp[2].shape[3]}"
        )

        x1, x2, x3 = inp

        # upsample inp[0] and inp[1] once along the time and frequency axes
        x1 = self.upsamp(x1)
        x2 = self.upsamp(x2)

        # compress the channel dimension to the minimum among the given hidden representations
        x1 = self.act(self.c1(x1))
        x2 = self.act(self.c2(x2))

        # pad the frequency axis for compatibility with the fixed CNN14 frequency dimension
        x1 = F.pad(x1, (0, 1, 0, 0))
        x2 = F.pad(x2, (0, 1, 0, 0))

        x = torch.cat((x1, x2, x3), axis=1)

        # upsample the time axis and collapse the frequency axis
        x = self.upsamp_time(x)

        # mix contributions from the three hidden layers
        x = self.conv(x)
        x = self.act(self.out_conv(x)).squeeze(3)

        return x
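
# Illustrative sketch (an assumption, not part of the original module): the three
# inputs to ``Psi`` are intermediate classifier feature maps with decreasing channel
# counts. The shape checks in ``forward`` require the first two maps to share their
# spatial size and the third map's last dimension to equal 2 * (first map) + 1.
# >>> hs = [torch.ones(1, 2048, 6, 2), torch.ones(1, 1024, 6, 2), torch.ones(1, 512, 12, 5)]
# >>> psi = Psi(n_comp=100, T=431, in_emb_dims=[2048, 1024, 512])
# >>> psi(hs).shape
# torch.Size([1, 100, 431])
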
class NMFDecoderAudio(nn.Module):
    """This class implements an NMF decoder.

    Arguments
    ---------
    n_comp : int
        Number of NMF components.
    n_freq : int
        Number of frequency bins in the NMF dictionary.
    device : str
        The device on which to run the model.

    Example
    -------
    >>> NMF_dec = NMFDecoderAudio(20, 210, device='cpu')
    >>> H = torch.rand(1, 20, 150)
    >>> Xhat = NMF_dec.forward(H)
    >>> print(Xhat.shape)
    torch.Size([1, 210, 150])
    """

    def __init__(self, n_comp=100, n_freq=513, device="cuda"):
        super(NMFDecoderAudio, self).__init__()

        self.W = nn.Parameter(
            0.1 * torch.rand(n_freq, n_comp), requires_grad=True
        )
        self.activ = nn.ReLU()
    def forward(self, H):
        """The forward pass for NMF given the activations H.

        Arguments
        ---------
        H : torch.Tensor
            The activations tensor with shape B x n_comp x T, where
            B = batch size, n_comp = number of NMF components,
            T = number of timepoints.
        """
        # Assume input of shape n_batch x n_comp x T
        H = self.activ(H)

        temp = self.activ(self.W).unsqueeze(0)
        output = torch.einsum("bij, bjk -> bik", temp, H)

        return output
    def return_W(self):
        """Returns the NMF dictionary."""
        W = self.W
        return self.activ(W)
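
# Hedged usage sketch (shapes and variable names are assumptions): the activations
# estimated by ``Psi`` can be pushed through ``NMFDecoderAudio`` to reconstruct a
# non-negative spectrogram, which is the reconstruction step used for L2I
# interpretations; ``return_W`` exposes the learned NMF dictionary.
# >>> nmf_dec = NMFDecoderAudio(n_comp=100, n_freq=513, device="cpu")
# >>> H = torch.rand(1, 100, 431)          # B x n_comp x T activations
# >>> X_hat = nmf_dec(H)                   # B x n_freq x T reconstruction
# >>> W = nmf_dec.return_W()               # n_freq x n_comp dictionary
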
def weights_init(m):
    """Applies Xavier initialization to network weights.

    Arguments
    ---------
    m : nn.Module
        Module to initialize.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        try:
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.fill_(0)
        except AttributeError:
            print("Skipping initialization of ", classname)
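
# Minimal sketch (illustrative, names are assumptions): ``weights_init`` is intended
# to be used with ``nn.Module.apply`` so that every Conv/ConvTranspose submodule gets
# Xavier-initialized weights and zero biases, as done inside ``PsiOptimized`` below.
# >>> net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
# >>> _ = net.apply(weights_init)
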
class PsiOptimized(nn.Module):
    """Convolutional layers to estimate NMF activations from classifier representations, optimized for log-spectra.

    Arguments
    ---------
    dim : int
        Dimension of the hidden representations (input to the classifier).
    K : int
        Number of NMF components (or equivalently, the number of neurons at the
        output per timestep).
    numclasses : int
        Number of possible classes.
    use_adapter : bool
        `True` if you wish to learn an adapter for the latent representations.
    adapter_reduce_dim : bool
        `True` if the adapter should compress the latent representations.

    Example
    -------
    >>> inp = torch.randn(1, 256, 26, 32)
    >>> psi = PsiOptimized(dim=256, K=100, use_adapter=False, adapter_reduce_dim=False)
    >>> h, inp_ad = psi(inp)
    >>> print(h.shape, inp_ad.shape)
    torch.Size([1, 1, 417, 100]) torch.Size([1, 256, 26, 32])
    """
    def __init__(
        self,
        dim=128,
        K=100,
        numclasses=50,
        use_adapter=False,
        adapter_reduce_dim=True,
    ):
        """Computes NMF activations from the hidden state."""
        super().__init__()
        self.use_adapter = use_adapter
        self.adapter_reduce_dim = adapter_reduce_dim

        if use_adapter:
            self.adapter = ResBlockAudio(dim)

        if adapter_reduce_dim:
            self.down = nn.Conv2d(dim, dim, 4, (2, 2), 1)
            self.up = nn.ConvTranspose2d(dim, dim, 4, (2, 2), 1)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, 3, (2, 2), 1),
            nn.ReLU(True),
            nn.BatchNorm2d(dim),
            nn.ConvTranspose2d(dim, dim, 4, (2, 2), 1),
            nn.ReLU(),
            nn.BatchNorm2d(dim),
            nn.ConvTranspose2d(dim, dim, 4, (2, 2), 1),
            nn.ReLU(),
            nn.BatchNorm2d(dim),
            nn.ConvTranspose2d(dim, dim, 4, (2, 2), 1),
            nn.ReLU(),
            nn.BatchNorm2d(dim),
            nn.ConvTranspose2d(dim, 1, 12, 1, 1),
            nn.ReLU(),
            nn.Linear(513, K),
            nn.ReLU(),
        )
        self.apply(weights_init)
    def forward(self, hs):
        """Computes the forward step.

        Arguments
        ---------
        hs : torch.Tensor
            Latent representations (input to the classifier). Expected shape
            `torch.Size([B, C, H, W])`.

        Returns
        -------
        out : torch.Tensor
            NMF activations with shape `torch.Size([B, 1, T, K])`.
        hcat : torch.Tensor
            The (possibly adapted and downsampled) latent representations.
        """
        if self.use_adapter:
            hcat = self.adapter(hs)
        else:
            hcat = hs

        if self.adapter_reduce_dim:
            hcat = self.down(hcat)
            z_q_x_st = self.up(hcat)
            out = self.decoder(z_q_x_st)
        else:
            out = self.decoder(hcat)

        return out, hcat
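
# Hedged sketch (flags and shapes are assumptions based on the class docstring): with
# ``adapter_reduce_dim=True``, ``PsiOptimized`` first compresses the (adapted) latent
# map with ``self.down``, re-expands it with ``self.up``, and decodes the result into
# per-timestep NMF activations.
# >>> psi_opt = PsiOptimized(dim=256, K=100, use_adapter=True, adapter_reduce_dim=True)
# >>> hs = torch.randn(1, 256, 26, 32)
# >>> out, hcat = psi_opt(hs)              # NMF activations and adapted latents
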
class Theta(nn.Module):
    """This class implements a linear classifier on top of NMF activations.

    Arguments
    ---------
    n_comp : int
        Number of NMF components.
    T : int
        Number of timepoints in the NMF activations.
    num_classes : int
        Number of classes that the classifier works with.

    Example
    -------
    >>> theta = Theta(30, 120, 50)
    >>> H = torch.rand(1, 30, 120)
    >>> c_hat = theta.forward(H)
    >>> print(c_hat.shape)
    torch.Size([1, 50])
    """

    def __init__(self, n_comp=100, T=431, num_classes=50):
        super().__init__()

        # This linear layer collapses the time axis using "attention"-based pooling
        self.hard_att = nn.Linear(T, 1, bias=False)

        # The linear layer for classification
        self.classifier = nn.Sequential(
            nn.Linear(n_comp, num_classes, bias=False), nn.Softmax(dim=1)
        )
    def forward(self, H):
        """We first collapse the time axis, and then pass through the linear layer.

        Arguments
        ---------
        H : torch.Tensor
            The activations tensor with shape B x n_comp x T, where
            B = batch size, n_comp = number of NMF components,
            T = number of timepoints.
        """
        theta_out = self.hard_att(H).squeeze(2)
        theta_out = self.classifier(theta_out)

        return theta_out
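
# Hedged end-to-end sketch (all shapes and names are illustrative assumptions): in L2I
# the ``Psi`` activations serve two purposes, classification through ``Theta`` and
# spectrogram reconstruction through ``NMFDecoderAudio``.
# >>> hs = [torch.ones(1, 2048, 6, 2), torch.ones(1, 1024, 6, 2), torch.ones(1, 512, 12, 5)]
# >>> psi, theta = Psi(100, 431, [2048, 1024, 512]), Theta(100, 431, 50)
# >>> nmf_dec = NMFDecoderAudio(100, 513, device="cpu")
# >>> H = psi(hs)                          # B x n_comp x T
# >>> class_probs = theta(H)               # B x num_classes
# >>> X_hat = nmf_dec(H)                   # B x n_freq x T
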
class NMFEncoder(nn.Module):
    """This class implements an NMF encoder with a convolutional network.

    Arguments
    ---------
    n_freq : int
        Number of frequency bins in the NMF dictionary.
    n_comp : int
        Number of NMF components.

    Example
    -------
    >>> nmfencoder = NMFEncoder(513, 100)
    >>> X = torch.rand(1, 513, 240)
    >>> Hhat = nmfencoder(X)
    >>> print(Hhat.shape)
    torch.Size([1, 100, 240])
    """

    def __init__(self, n_freq, n_comp):
        super().__init__()
        self.convenc = nn.Sequential(
            nn.Conv1d(n_freq, 256, kernel_size=8, padding="same"),
            nn.ReLU(),
            nn.Conv1d(256, 128, kernel_size=8, padding="same"),
            nn.ReLU(),
            nn.Conv1d(128, n_comp, kernel_size=8, padding="same"),
            nn.ReLU(),
        )
    def forward(self, X):
        """Estimates the NMF activations for the input spectrogram.

        Arguments
        ---------
        X : torch.Tensor
            The input spectrogram tensor with shape B x n_freq x T, where
            B = batch size, n_freq = nfft of the input spectrogram,
            T = number of timepoints.
        """
        return self.convenc(X)
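
# Hedged sketch (illustrative only): ``NMFEncoder`` and ``NMFDecoderAudio`` form an
# autoencoding pair for fitting the NMF dictionary; the encoder maps a spectrogram to
# activations, and the decoder maps activations back to a non-negative spectrogram.
# >>> nmf_enc = NMFEncoder(n_freq=513, n_comp=100)
# >>> nmf_dec = NMFDecoderAudio(n_comp=100, n_freq=513, device="cpu")
# >>> X = torch.rand(1, 513, 240)
# >>> X_hat = nmf_dec(nmf_enc(X))          # reconstruction with shape 1 x 513 x 240
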