# Source code for speechbrain.lobes.models.Cnn14

""" This file implements the CNN14 model from https://arxiv.org/abs/1912.10211

 Authors
 * Cem Subakan 2022
 * Francesco Paissan 2022
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def init_layer(layer):
    """Initialize a Linear or Convolutional layer.

    The weight is drawn from a Xavier-uniform distribution; the bias,
    when the layer has one, is reset to zero.
    """
    nn.init.xavier_uniform_(layer.weight)
    # Layers built with bias=False expose bias=None; skip those too.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn):
    """Initialize a Batchnorm layer.

    Resets the affine transform to the identity: scale (weight) to one
    and shift (bias) to zero.
    """
    bn.weight.data.fill_(1.0)
    bn.bias.data.fill_(0.0)
class ConvBlock(nn.Module):
    """This class implements the convolutional block used in CNN14

    Arguments
    ---------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    norm_type : str in ['bn', 'in', 'ln']
        The type of normalization

    Example
    -------
    >>> convblock = ConvBlock(10, 20, 'ln')
    >>> x = torch.rand(5, 10, 20, 30)
    >>> y = convblock(x)
    >>> print(y.shape)
    torch.Size([5, 20, 10, 15])
    """

    def __init__(self, in_channels, out_channels, norm_type):
        super().__init__()

        # Two 3x3 same-padding convolutions; the normalization layers carry
        # the affine parameters, so the convolutions need no bias.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        self.norm_type = norm_type

        if norm_type == "bn":
            self.norm1 = nn.BatchNorm2d(out_channels)
            self.norm2 = nn.BatchNorm2d(out_channels)
        elif norm_type == "in":
            self.norm1 = nn.InstanceNorm2d(
                out_channels, affine=True, track_running_stats=True
            )
            self.norm2 = nn.InstanceNorm2d(
                out_channels, affine=True, track_running_stats=True
            )
        elif norm_type == "ln":
            # GroupNorm with a single group normalizes over (C, H, W),
            # i.e. a layer-norm-style normalization.
            self.norm1 = nn.GroupNorm(1, out_channels)
            self.norm2 = nn.GroupNorm(1, out_channels)
        else:
            raise ValueError("Unknown norm type {}".format(norm_type))

        self.init_weight()

    def init_weight(self):
        """Initializes the convolutional layers and the normalization layers."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.norm1)
        init_bn(self.norm2)

    def forward(self, x, pool_size=(2, 2), pool_type="avg"):
        """The forward pass for convblocks in CNN14

        Arguments
        ---------
        x : torch.Tensor
            input tensor with shape B x C_in x D1 x D2 where
            B = Batchsize
            C_in = Number of input channel
            D1 = Dimensionality of the first spatial dim
            D2 = Dimensionality of the second spatial dim
        pool_size : tuple with integer values
            Amount of pooling at each layer
        pool_type : str in ['max', 'avg', 'avg+max']
            The type of pooling

        Returns
        -------
        The output of one conv block

        Raises
        ------
        ValueError
            If `pool_type` is not one of 'max', 'avg', 'avg+max'.
        """
        x = F.relu_(self.norm1(self.conv1(x)))
        x = F.relu_(self.norm2(self.conv2(x)))
        if pool_type == "max":
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg":
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg+max":
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            # Fix: raise ValueError (not a bare Exception) for an invalid
            # argument, consistent with the norm_type check in __init__.
            raise ValueError("Unknown pooling type {}".format(pool_type))
        return x
class Cnn14(nn.Module):
    """This class implements the Cnn14 model from https://arxiv.org/abs/1912.10211

    Arguments
    ---------
    mel_bins : int
        Number of mel frequency bins in the input
    emb_dim : int
        The dimensionality of the output embeddings
    norm_type: str in ['bn', 'in', 'ln']
        The type of normalization
    return_reps: bool (default=False)
        If True the model returns intermediate representations as well
        for interpretation
    l2i : bool
        If True, remove one of the outputs.

    Example
    -------
    >>> cnn14 = Cnn14(120, 256)
    >>> x = torch.rand(3, 400, 120)
    >>> h = cnn14.forward(x)
    >>> print(h.shape)
    torch.Size([3, 1, 256])
    """

    def __init__(
        self, mel_bins, emb_dim, norm_type="bn", return_reps=False, l2i=False
    ):
        super().__init__()
        self.return_reps = return_reps
        self.l2i = l2i
        self.norm_type = norm_type

        # Input normalization applied over the mel-frequency axis.
        if norm_type == "bn":
            self.norm0 = nn.BatchNorm2d(mel_bins)
        elif norm_type == "in":
            self.norm0 = nn.InstanceNorm2d(
                mel_bins, affine=True, track_running_stats=True
            )
        elif norm_type == "ln":
            self.norm0 = nn.GroupNorm(1, mel_bins)
        else:
            raise ValueError("Unknown norm type {}".format(norm_type))

        # Six conv blocks doubling the channel count up to emb_dim;
        # attribute names conv_block1..conv_block6 are preserved for
        # checkpoint compatibility.
        widths = (1, 64, 128, 256, 512, 1024, emb_dim)
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(
                self,
                "conv_block{}".format(idx),
                ConvBlock(
                    in_channels=c_in, out_channels=c_out, norm_type=norm_type
                ),
            )

        self.init_weight()

    def init_weight(self):
        """Initializes the input normalization layer."""
        init_bn(self.norm0)

    def forward(self, x):
        """The forward pass for the CNN14 encoder

        Arguments
        ---------
        x : torch.Tensor
            input tensor with shape B x C_in x D1 x D2 where
            B = Batchsize
            C_in = Number of input channel
            D1 = Dimensionality of the first spatial dim
            D2 = Dimensionality of the second spatial dim

        Returns
        -------
        Outputs of CNN14 encoder
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)

        # Move the mel axis to the channel position, normalize, move back.
        x = x.transpose(1, 3)
        x = self.norm0(x)
        x = x.transpose(1, 3)

        blocks = (
            self.conv_block1,
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
            self.conv_block5,
            self.conv_block6,
        )
        reps = []
        for idx, block in enumerate(blocks):
            # The last block keeps its spatial resolution (no pooling).
            pool = (1, 1) if idx == 5 else (2, 2)
            x = block(x, pool_size=pool, pool_type="avg")
            if idx >= 2:
                # Keep the pre-dropout activations of blocks 3..6 as
                # intermediate representations.
                reps.append(x)
            x = F.dropout(x, p=0.2, training=self.training)

        x4_out, x3_out, x2_out, x1_out = reps

        # Collapse the mel axis, then combine max- and mean-pooling over time.
        x = torch.mean(x, dim=3)
        x_max, _ = torch.max(x, dim=2)
        x = x_max + torch.mean(x, dim=2)

        # [B x 1 x emb_dim]
        if not self.return_reps:
            return x.unsqueeze(1)
        if self.l2i:
            return x.unsqueeze(1), (x1_out, x2_out, x3_out)
        return x.unsqueeze(1), (x1_out, x2_out, x3_out, x4_out)
class CNN14PSI(nn.Module):
    """
    This class estimates a mel-domain saliency mask

    Arguments
    ---------
    dim : int
        Dimensionality of the embeddings

    Returns
    -------
    Estimated saliency map (before sigmoid)

    Example
    -------
    >>> from speechbrain.lobes.models.Cnn14 import Cnn14
    >>> classifier_embedder = Cnn14(mel_bins=80, emb_dim=2048, return_reps=True)
    >>> x = torch.randn(2, 201, 80)
    >>> _, hs = classifier_embedder(x)
    >>> psimodel = CNN14PSI(2048)
    >>> xhat = psimodel.forward(hs)
    >>> print(xhat.shape)
    torch.Size([2, 1, 201, 80])
    """

    def __init__(self, dim=128):
        super().__init__()

        # Odd-numbered layers upsample the running estimate; even-numbered
        # layers upsample the skip representations hs[1..3], whose channel
        # counts are dim/2, dim/4 and dim/8 respectively.
        self.convt1 = nn.ConvTranspose2d(
            dim, dim, kernel_size=3, stride=(2, 2), padding=1
        )
        self.convt2 = nn.ConvTranspose2d(
            dim // 2, dim, kernel_size=3, stride=(2, 2), padding=1
        )
        self.convt3 = nn.ConvTranspose2d(
            dim, dim, kernel_size=(7, 4), stride=(2, 4), padding=1
        )
        self.convt4 = nn.ConvTranspose2d(
            dim // 4, dim, kernel_size=(5, 4), stride=(2, 2), padding=1
        )
        self.convt5 = nn.ConvTranspose2d(
            dim, dim, kernel_size=(3, 3), stride=(2, 2), padding=1
        )
        self.convt6 = nn.ConvTranspose2d(
            dim // 8, dim, kernel_size=(3, 3), stride=(2, 2), padding=1
        )
        self.convt7 = nn.ConvTranspose2d(
            dim, dim, kernel_size=(4, 3), stride=(2, 2), padding=0
        )
        self.convt8 = nn.ConvTranspose2d(
            dim, 1, kernel_size=(3, 4), stride=(2, 2), padding=0
        )
        self.nonl = nn.ReLU(True)

    def forward(self, hs, labels=None):
        """
        Forward step. Given the classifier representations estimates
        a saliency map.

        Arguments
        ---------
        hs : torch.Tensor
            Classifier's representations.
        labels : None
            Unused

        Returns
        -------
        xhat : torch.Tensor
            Estimated saliency map (before sigmoid)
        """
        relu = self.nonl

        # Progressively upsample, merging in a skip representation at
        # every stage by elementwise addition.
        h = relu(self.convt1(hs[0])) + relu(self.convt2(hs[1]))
        h = relu(self.convt3(h)) + relu(self.convt4(hs[2]))
        h = relu(self.convt5(h)) + relu(self.convt6(hs[3]))

        h = relu(self.convt7(h))
        xhat = self.convt8(h)
        return xhat
class CNN14PSI_stft(nn.Module):
    """
    This class estimates a saliency map on the STFT domain, given
    classifier representations.

    Arguments
    ---------
    dim : int
        Dimensionality of the input representations.
    outdim : int
        Defines the number of output channels in the saliency map.

    Example
    -------
    >>> from speechbrain.lobes.models.Cnn14 import Cnn14
    >>> classifier_embedder = Cnn14(mel_bins=80, emb_dim=2048, return_reps=True)
    >>> x = torch.randn(2, 201, 80)
    >>> _, hs = classifier_embedder(x)
    >>> psimodel = CNN14PSI_stft(2048, 1)
    >>> xhat = psimodel.forward(hs)
    >>> print(xhat.shape)
    torch.Size([2, 1, 201, 513])
    """

    def __init__(self, dim=128, outdim=1):
        super().__init__()

        # Odd-numbered layers upsample the running estimate; even-numbered
        # layers upsample the skip representations hs[1..3] (channel counts
        # dim/2, dim/4, dim/8). The wider frequency strides expand the
        # output to STFT resolution.
        self.convt1 = nn.ConvTranspose2d(
            dim, dim, kernel_size=3, stride=(2, 4), padding=1
        )
        self.convt2 = nn.ConvTranspose2d(
            dim // 2, dim, kernel_size=3, stride=(2, 4), padding=1
        )
        self.convt3 = nn.ConvTranspose2d(
            dim, dim, kernel_size=(7, 4), stride=(2, 4), padding=1
        )
        self.convt4 = nn.ConvTranspose2d(
            dim // 4, dim, kernel_size=(5, 4), stride=(2, 4), padding=1
        )
        self.convt5 = nn.ConvTranspose2d(
            dim, dim // 2, kernel_size=(3, 5), stride=(2, 2), padding=1
        )
        self.convt6 = nn.ConvTranspose2d(
            dim // 8, dim // 2, kernel_size=(3, 3), stride=(2, 4), padding=1
        )
        self.convt7 = nn.ConvTranspose2d(
            dim // 2, dim // 4, kernel_size=(4, 3), stride=(2, 2), padding=(0, 5)
        )
        self.convt8 = nn.ConvTranspose2d(
            dim // 4, dim // 8, kernel_size=(3, 4), stride=(2, 2), padding=(0, 2)
        )
        self.convt9 = nn.ConvTranspose2d(
            dim // 8, outdim, kernel_size=(1, 5), stride=(1, 4), padding=0
        )
        self.nonl = nn.ReLU(True)

    def forward(self, hs):
        """
        Forward step to estimate the saliency map

        Arguments
        --------
        hs : torch.Tensor
            Classifier's representations.

        Returns
        --------
        xhat : torch.Tensor
            An Estimate for the saliency map
        """
        relu = self.nonl

        # Progressively upsample, merging in a skip representation at
        # every stage by elementwise addition.
        h = relu(self.convt1(hs[0])) + relu(self.convt2(hs[1]))
        h = relu(self.convt3(h)) + relu(self.convt4(hs[2]))
        h = relu(self.convt5(h)) + relu(self.convt6(hs[3]))

        h = relu(self.convt7(h))
        h = self.convt8(h)
        xhat = self.convt9(h)
        return xhat