speechbrain.lobes.models.Cnn14 module
This file implements the CNN14 model from https://arxiv.org/abs/1912.10211
Authors
 * Cem Subakan 2022
 * Francesco Paissan 2022
Summary
Classes:
Cnn14 | This class implements the Cnn14 model from https://arxiv.org/abs/1912.10211
ConvBlock | This class implements the convolutional block used in CNN14
Functions:
init_bn | Initialize a Batchnorm layer.
init_layer | Initialize a Linear or Convolutional layer.
Reference
- speechbrain.lobes.models.Cnn14.init_layer(layer)
Initialize a Linear or Convolutional layer.
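A common scheme for this kind of initializer draws weights from a Xavier/Glorot uniform distribution and sets biases to zero. That this function uses exactly that scheme is an assumption here (it follows the PANNs reference code this model is based on), not something stated above. The sketch below computes only the Xavier-uniform bound sqrt(6 / (fan_in + fan_out)) for a PyTorch-style weight shape. `xavier_uniform_bound` is a hypothetical helper, not part of SpeechBrain.

```python
import math


def xavier_uniform_bound(weight_shape):
    """Return the bound a such that Xavier-uniform weights are drawn from U(-a, a).

    `weight_shape` follows PyTorch's layout: (out_features, in_features) for a
    Linear layer, (out_channels, in_channels, kH, kW) for a Conv2d layer.
    """
    out_ch, in_ch = weight_shape[0], weight_shape[1]
    receptive_field = 1
    for k in weight_shape[2:]:
        receptive_field *= k  # kernel area for convolutions, 1 for Linear
    fan_in = in_ch * receptive_field
    fan_out = out_ch * receptive_field
    return math.sqrt(6.0 / (fan_in + fan_out))


# Hypothetical layer shapes, for illustration only
print(xavier_uniform_bound((20, 10, 3, 3)))  # Conv2d(10 -> 20, 3x3)
print(xavier_uniform_bound((256, 2048)))     # Linear(2048 -> 256)
```

In PyTorch itself, the same bound is what `torch.nn.init.xavier_uniform_` uses with its default gain of 1.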
- class speechbrain.lobes.models.Cnn14.ConvBlock(in_channels, out_channels, norm_type)
Bases: torch.nn.Module
This class implements the convolutional block used in CNN14.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
norm_type (str in ['bn', 'in', 'ln']) – The type of normalization: 'bn' (batch normalization), 'in' (instance normalization), or 'ln' (layer normalization).
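The three `norm_type` options differ in which axes the mean and variance are computed over. The pure-Python sketch below uses the standard textbook definitions for a B x C x D1 x D2 input, without learnable scale/shift or running statistics. It illustrates the idea and is not claimed to reproduce the exact layers ConvBlock instantiates. `NORM_AXES` and `normalize` are hypothetical names.

```python
import math
from itertools import product

# Axes (B=0, C=1, D1=2, D2=3) over which statistics are pooled.
NORM_AXES = {
    "bn": (0, 2, 3),  # batch norm: one mean/var per channel
    "in": (2, 3),     # instance norm: one mean/var per (sample, channel)
    "ln": (1, 2, 3),  # layer norm: one mean/var per sample
}


def normalize(x, shape, norm_type, eps=1e-5):
    """Normalize a flat, row-major list `x` holding a 4-D tensor of `shape`."""
    axes = NORM_AXES[norm_type]
    kept = [d for d in range(4) if d not in axes]
    strides = [shape[1] * shape[2] * shape[3], shape[2] * shape[3], shape[3], 1]

    # Group flat positions by their coordinates on the non-normalized axes.
    groups = {}
    for idx in product(*(range(n) for n in shape)):
        key = tuple(idx[d] for d in kept)
        flat = sum(i * s for i, s in zip(idx, strides))
        groups.setdefault(key, []).append(flat)

    out = [0.0] * len(x)
    for positions in groups.values():
        vals = [x[p] for p in positions]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)  # biased variance
        for p in positions:
            out[p] = (x[p] - mean) / math.sqrt(var + eps)
    return out


shape = (2, 2, 1, 2)  # B=2, C=2, D1=1, D2=2
x = [float(v) for v in range(8)]
for nt in ("bn", "in", "ln"):
    print(nt, [round(v, 3) for v in normalize(x, shape, nt)])
```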
Example
--------
>>> import torch
>>> from speechbrain.lobes.models.Cnn14 import ConvBlock
>>> convblock = ConvBlock(10, 20, 'ln')
>>> x = torch.rand(5, 10, 20, 30)
>>> y = convblock(x)
>>> print(y.shape)
torch.Size([5, 20, 10, 15])
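The shapes in this example follow from simple arithmetic: the block maps C_in channels to C_out channels and then pools the two spatial dimensions. The sketch below assumes, based on the example rather than on documented internals, that the convolutions preserve the spatial size and that pooling uses a stride equal to the window with floor division. `convblock_output_shape` is a hypothetical helper.

```python
def convblock_output_shape(input_shape, out_channels, pool_size=(2, 2)):
    """Predict the output shape of a ConvBlock-style layer (sketch).

    Assumes size-preserving convolutions followed by non-overlapping pooling.
    """
    batch, _in_channels, d1, d2 = input_shape
    p1, p2 = pool_size
    return (batch, out_channels, d1 // p1, d2 // p2)


print(convblock_output_shape((5, 10, 20, 30), out_channels=20))
# (5, 20, 10, 15), consistent with torch.Size([5, 20, 10, 15]) in the example
```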
- forward(x, pool_size=(2, 2), pool_type='avg')
The forward pass of a ConvBlock in CNN14.
Arguments:
- x : torch.Tensor
  Input tensor with shape B x C_in x D1 x D2, where B is the batch size, C_in is the number of input channels, D1 is the size of the first spatial dimension, and D2 is the size of the second spatial dimension.
- pool_size : tuple with integer values
  Size of the pooling window along the two spatial dimensions (D1, D2).
- pool_type : str in ['max', 'avg', 'avg+max']
  The type of pooling.
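The three `pool_type` values can be illustrated on a single 2-D feature map. The plain-Python sketch below treats 'avg+max' as the element-wise sum of average pooling and max pooling. That interpretation is an assumption borrowed from the PANNs reference implementation, and `pool2d` is a hypothetical helper, not SpeechBrain API.

```python
def pool2d(grid, pool_size=(2, 2), pool_type="avg"):
    """Non-overlapping 2-D pooling of a list-of-lists feature map (sketch)."""
    if pool_type not in ("avg", "max", "avg+max"):
        raise ValueError(f"unknown pool_type: {pool_type!r}")
    p1, p2 = pool_size
    rows, cols = len(grid) // p1, len(grid[0]) // p2  # leftover rows/cols dropped
    out = []
    for r in range(rows):
        row = []
        for c in range(cols):
            window = [grid[r * p1 + i][c * p2 + j]
                      for i in range(p1) for j in range(p2)]
            avg = sum(window) / len(window)
            mx = max(window)
            if pool_type == "avg":
                row.append(avg)
            elif pool_type == "max":
                row.append(mx)
            else:  # "avg+max": assumed to be the sum of both pooled maps
                row.append(avg + mx)
        out.append(row)
    return out


fmap = [[1, 2, 3, 4],
        [5, 6, 7, 8]]
for pt in ("avg", "max", "avg+max"):
    print(pt, pool2d(fmap, (2, 2), pt))
```

With a (2, 2) window, the 2 x 4 map above shrinks to 1 x 2, matching the halving of both spatial dimensions seen in the ConvBlock example.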
- class speechbrain.lobes.models.Cnn14.Cnn14(mel_bins, emb_dim, norm_type='bn', return_reps=False)
Bases: torch.nn.Module
This class implements the Cnn14 model from https://arxiv.org/abs/1912.10211
- Parameters:
mel_bins (int) – Number of mel frequency bins in the input
emb_dim (int) – The dimensionality of the output embeddings
norm_type (str in ['bn', 'in', 'ln']) – The type of normalization
return_reps (bool (default=False)) – If True, the model also returns intermediate representations, e.g., for interpretation.
Example
--------
>>> import torch
>>> from speechbrain.lobes.models.Cnn14 import Cnn14
>>> cnn14 = Cnn14(120, 256)
>>> x = torch.rand(3, 400, 120)
>>> h = cnn14(x)
>>> print(h.shape)
torch.Size([3, 1, 256])
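As a rough guide to how a (3, 400, 120) input becomes a (3, 1, 256) embedding, the sketch below traces sizes through a PANNs-style CNN14. The architecture details are assumptions drawn from the PANNs design, not read from this module: the (B, T, mel_bins) input gets a single channel, six ConvBlocks follow (the first five pooling by (2, 2), the last by (1, 1)), and the remaining time and frequency axes are then collapsed into one emb_dim-dimensional vector per example. `trace_cnn14_shapes` is a hypothetical helper.

```python
def trace_cnn14_shapes(batch, time_steps, mel_bins, emb_dim):
    """Trace (assumed) time/frequency sizes through a PANNs-style CNN14."""
    # Assumed pooling per ConvBlock: five (2, 2) stages, then (1, 1).
    pool_sizes = [(2, 2)] * 5 + [(1, 1)]
    t, f = time_steps, mel_bins
    shapes = [("input (B, T, F)", (batch, t, f))]
    for i, (pt, pf) in enumerate(pool_sizes, start=1):
        t, f = t // pt, f // pf  # non-overlapping pooling with floor
        shapes.append((f"block{i} (T, F)", (t, f)))
    # Global pooling over time and frequency leaves one embedding per example.
    shapes.append(("output (B, 1, emb_dim)", (batch, 1, emb_dim)))
    return shapes


for name, shape in trace_cnn14_shapes(3, 400, 120, 256):
    print(name, shape)
```

Under these assumptions, the 400 x 120 time-frequency map shrinks to 12 x 3 before global pooling, which is why the time and frequency axes no longer appear in the output shape.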