speechbrain.lobes.models.Cnn14 module
This file implements the CNN14 model from https://arxiv.org/abs/1912.10211
Authors
* Cem Subakan 2022
* Francesco Paissan 2022
Summary
Classes:
Cnn14 | This class implements the Cnn14 model from https://arxiv.org/abs/1912.10211
ConvBlock | This class implements the convolutional block used in CNN14
Functions:
init_bn | Initialize a Batchnorm layer.
init_layer | Initialize a Linear or Convolutional layer.
Reference
- speechbrain.lobes.models.Cnn14.init_layer(layer)
Initialize a Linear or Convolutional layer.
- class speechbrain.lobes.models.Cnn14.ConvBlock(in_channels, out_channels, norm_type)
Bases: Module
This class implements the convolutional block used in CNN14
- Parameters:
in_channels (int) – Number of input channels
out_channels (int) – Number of output channels
norm_type (str in ['bn', 'in', 'ln']) – The type of normalization: batch ('bn'), instance ('in'), or layer ('ln') normalization
Example
>>> convblock = ConvBlock(10, 20, 'ln')
>>> x = torch.rand(5, 10, 20, 30)
>>> y = convblock(x)
>>> print(y.shape)
torch.Size([5, 20, 10, 15])
- forward(x, pool_size=(2, 2), pool_type='avg')
The forward pass for convblocks in CNN14
Arguments:
x (torch.Tensor) – Input tensor with shape B x C_in x D1 x D2, where B is the batch size, C_in is the number of input channels, D1 is the size of the first spatial dimension, and D2 is the size of the second spatial dimension
pool_size (tuple of int) – Amount of pooling at each layer
pool_type (str in ['max', 'avg', 'avg+max']) – The type of pooling
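The effect of pool_size and pool_type on a single feature map can be illustrated with a small pure-Python sketch. It is not the SpeechBrain implementation: pool2d and convblock_output_shape are hypothetical helpers written for this illustration. It assumes that 'avg+max' sums the average-pooled and max-pooled results, that windows do not overlap, and that the convolutions inside the block preserve D1 and D2 (which is consistent with the ConvBlock example above, where 20 x 30 becomes 10 x 15 under the default pool_size).

```python
def pool2d(grid, pool_size=(2, 2), pool_type="avg"):
    """Pool a 2-D list of numbers with non-overlapping windows (illustration only)."""
    ph, pw = pool_size
    # Rows/columns that do not fill a whole window are dropped (floor division).
    rows, cols = len(grid) // ph, len(grid[0]) // pw
    out = []
    for i in range(rows):
        out_row = []
        for j in range(cols):
            window = [
                grid[i * ph + a][j * pw + b]
                for a in range(ph)
                for b in range(pw)
            ]
            avg = sum(window) / len(window)
            mx = max(window)
            if pool_type == "avg":
                out_row.append(avg)
            elif pool_type == "max":
                out_row.append(mx)
            elif pool_type == "avg+max":
                # Assumption: the two pooled results are added together.
                out_row.append(avg + mx)
            else:
                raise ValueError(f"Unknown pool_type: {pool_type!r}")
        out.append(out_row)
    return out


def convblock_output_shape(input_shape, out_channels, pool_size=(2, 2)):
    """Expected output shape, assuming only the pooling changes D1 and D2."""
    b, _c_in, d1, d2 = input_shape
    return (b, out_channels, d1 // pool_size[0], d2 // pool_size[1])


grid = [
    [1.0, 3.0, 5.0, 7.0],
    [1.0, 3.0, 5.0, 7.0],
]
print(pool2d(grid, (2, 2), "avg"))      # [[2.0, 6.0]]
print(pool2d(grid, (2, 2), "max"))      # [[3.0, 7.0]]
print(pool2d(grid, (2, 2), "avg+max"))  # [[5.0, 13.0]]
print(convblock_output_shape((5, 10, 20, 30), out_channels=20))  # (5, 20, 10, 15)
```

The last line reproduces the shape from the ConvBlock example: the channel count changes from C_in to out_channels, and each spatial dimension is divided by the corresponding entry of pool_size.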
- class speechbrain.lobes.models.Cnn14.Cnn14(mel_bins, emb_dim, norm_type='bn', return_reps=False)
Bases: Module
This class implements the Cnn14 model from https://arxiv.org/abs/1912.10211
- Parameters:
mel_bins (int) – Number of mel frequency bins in the input
emb_dim (int) – The dimensionality of the output embeddings
norm_type (str in ['bn', 'in', 'ln']) – The type of normalization: batch ('bn'), instance ('in'), or layer ('ln') normalization
return_reps (bool (default=False)) – If True, the model also returns intermediate representations for interpretation
Example
>>> cnn14 = Cnn14(120, 256)
>>> x = torch.rand(3, 400, 120)
>>> h = cnn14.forward(x)
>>> print(h.shape)
torch.Size([3, 1, 256])