"""
Assorted reusable neural network modules.
Authors
* Artem Ploujnikov 2023
"""
from torch import nn
from speechbrain.dataio.dataio import length_to_mask
[docs]
class DoneDetector(nn.Module):
"""A wrapper for the done detector using a model (e.g. a CRDNN) and
an output layer.
The goal of using a wrapper is to apply masking before the output layer
(e.g. Softmax) so that the model can't "cheat" by outputting probabilities
in the masked area
Arguments
---------
model: torch.nn.Module
the model used to make the prediction
out: torch.nn.Module
the output function
Example
-------
>>> import torch
>>> from torch import nn
>>> from speechbrain.nnet.activations import Softmax
>>> from speechbrain.nnet.containers import Sequential
>>> from speechbrain.nnet.linear import Linear
>>> from speechbrain.lobes.models.CRDNN import CRDNN
>>> crdnn = CRDNN(
... input_size=80,
... cnn_blocks=1,
... cnn_kernelsize=3,
... rnn_layers=1,
... rnn_neurons=16,
... dnn_blocks=1,
... dnn_neurons=16
... )
>>> model_out = Linear(n_neurons=1, input_size=16)
>>> model_act = nn.Sigmoid()
>>> model = Sequential(
... crdnn,
... model_out,
... model_act
... )
>>> out = Softmax(
... apply_log=False,
... )
>>> done_detector = DoneDetector(
... model=model,
... out=out,
... )
>>> preds = torch.randn(4, 10, 80) # Batch x Length x Feats
>>> length = torch.tensor([1., .8, .5, 1.])
>>> preds_len = done_detector(preds, length)
>>> preds_len.shape
torch.Size([4, 10, 1])
"""
def __init__(self, model, out):
super().__init__()
self.model = model
self.out = out
[docs]
def forward(self, feats, length=None):
"""Computes the forward pass
Arguments
---------
feats: torch.Tensor
the features used for the model (e.g. spectrograms)
length: torch.Tensor
a tensor of relative lengths
Returns
-------
preds: torch.Tensor
predictions
"""
out = self.model(feats)
if length is not None:
max_len = feats.size(1)
mask = length_to_mask(length=length * max_len, max_len=max_len)
out = out * mask.unsqueeze(-1)
out = self.out(out)
return out