speechbrain.nnet.losses module

Losses for training neural networks.

Authors
  • Mirco Ravanelli 2020

  • Samuele Cornell 2020

  • Hwidong Na 2020

  • Yan Gao 2020

  • Titouan Parcollet 2020

Summary

Classes:

AdditiveAngularMargin

An implementation of Additive Angular Margin (AAM) proposed in the paper "Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition" (https://arxiv.org/abs/1906.07317).

AngularMargin

An implementation of Angular Margin (AM) proposed in the paper "Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition" (https://arxiv.org/abs/1906.07317).

LogSoftmaxWrapper

A wrapper that passes the wrapped function's outputs through log-softmax and computes the negative log-likelihood loss.

PitWrapper

Permutation Invariant Wrapper to allow Permutation Invariant Training (PIT) with existing losses.

Functions:

bce_loss

Computes binary cross-entropy (BCE) loss.

cal_si_snr

Calculate SI-SNR.

cal_snr

Calculate binaural channel SNR. Expects source and estimate_source of shape [T, E, B, C], where T is the length of the sources, E is the number of binaural channels, B is the batch size, and C is the number of sources; this ordering makes the loss compatible with the class PitWrapper.

ce_kd

Simple version of distillation for cross-entropy loss.

classification_error

Computes the classification error at frame or batch level.

compute_masked_loss

Compute the true average loss of a set of waveforms of unequal length.

ctc_loss

CTC loss.

ctc_loss_kd

Knowledge distillation for CTC loss.

get_mask

Create a binary mask of shape [T, B, 1] marking the valid time steps of each source, given the source lengths.

get_si_snr_with_pitwrapper

This function wraps si_snr calculation with the speechbrain pit-wrapper.

get_snr_with_pitwrapper

This function wraps the SNR calculation with the speechbrain pit-wrapper. Expects source and estimate_source of shape [B, T, E, C], where B is the batch size, T is the length of the sources, E is the number of binaural channels, and C is the number of sources; this ordering makes the loss compatible with the class PitWrapper.

kldiv_loss

Computes the KL-divergence error at the batch level.

l1_loss

Compute the true l1 loss, accounting for length differences.

mse_loss

Compute the true mean squared error, accounting for length differences.

nll_loss

Computes negative log likelihood loss.

nll_loss_kd

Knowledge distillation for negative log-likelihood loss.

transducer_loss

Transducer loss, see speechbrain/nnet/loss/transducer_loss.py.

truncate

Ensure that predictions and targets are the same length.

Reference

speechbrain.nnet.losses.transducer_loss(logits, targets, input_lens, target_lens, blank_index, reduction='mean', use_torchaudio=True)[source]

Transducer loss, see speechbrain/nnet/loss/transducer_loss.py.

Parameters
  • logits (torch.Tensor) – Predicted tensor, of shape [batch, maxT, maxU, num_labels].

  • targets (torch.Tensor) – Target tensor, without any blanks, of shape [batch, target_len].

  • input_lens (torch.Tensor) – Length of each utterance.

  • target_lens (torch.Tensor) – Length of each target sequence.

  • blank_index (int) – The location of the blank symbol among the label indices.

  • reduction (str) – Specifies the reduction to apply to the output: ‘mean’ | ‘batchmean’ | ‘sum’.

  • use_torchaudio (bool) – If True, use the Transducer loss implementation from torchaudio; otherwise, use SpeechBrain's Numba implementation.
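
Example

A minimal usage sketch (not from the original docs). As with the other losses in this module, it assumes input_lens and target_lens are relative lengths in [0, 1], and that torchaudio is available:

>>> logits = torch.randn(2, 8, 5, 4)       # [batch, maxT, maxU, num_labels]
>>> targets = torch.randint(1, 4, (2, 4))  # no blanks; maxU = target_len + 1
>>> input_lens = torch.ones(2)             # relative lengths: full utterances
>>> target_lens = torch.ones(2)
>>> loss = transducer_loss(logits, targets, input_lens, target_lens, blank_index=0)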

class speechbrain.nnet.losses.PitWrapper(base_loss)[source]

Bases: Module

Permutation Invariant Wrapper to allow Permutation Invariant Training (PIT) with existing losses.

Permutation invariance is calculated over the sources/classes axis which is assumed to be the rightmost dimension: predictions and targets tensors are assumed to have shape [batch, …, channels, sources].

Parameters

base_loss (function) – Base loss function, e.g. torch.nn.MSELoss. It is assumed to take two arguments (predictions and targets) and to perform no reduction (if a PyTorch loss is used, the user must specify reduction=”none”).

Returns

pit_loss – Torch module supporting forward method for PIT.

Return type

torch.nn.Module

Example

>>> pit_mse = PitWrapper(nn.MSELoss(reduction="none"))
>>> targets = torch.rand((2, 32, 4))
>>> p = (3, 0, 2, 1)
>>> predictions = targets[..., p]
>>> loss, opt_p = pit_mse(predictions, targets)
>>> loss
tensor([0., 0.])
reorder_tensor(tensor, p)[source]

Reorders the input tensor along the sources axis according to the optimal permutation p.

Parameters
  • tensor (torch.Tensor) – Tensor to reorder given the optimal permutation, of shape [batch, …, sources].

  • p (list of tuples) – List of optimal permutations, e.g. for batch=2 and n_sources=3: [(0, 1, 2), (0, 2, 1)].

Returns

reordered – Reordered tensor given permutation p.

Return type

torch.Tensor
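
Example

A small continuation of the class example above (opt_p is the permutation list returned by the wrapper's forward call):

>>> reordered = pit_mse.reorder_tensor(predictions, opt_p)
>>> reordered.shape
torch.Size([2, 32, 4])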

forward(preds, targets)[source]

Computes the permutation-invariant loss and the corresponding optimal permutations.

Parameters
  • preds (torch.Tensor) – Network predictions tensor, of shape [batch, …, channels, sources].

  • targets (torch.Tensor) – Target tensor, of shape [batch, …, channels, sources].

Returns

  • loss (torch.Tensor) – Permutation invariant loss for current examples, tensor of shape [batch]

  • perms (list) – List of indexes for optimal permutation of the inputs over sources. e.g., [(0, 1, 2), (2, 1, 0)] for three sources and 2 examples per batch.

speechbrain.nnet.losses.ctc_loss(log_probs, targets, input_lens, target_lens, blank_index, reduction='mean')[source]

CTC loss.

Parameters
  • log_probs (torch.Tensor) – Predicted tensor, of shape [batch, time, chars].

  • targets (torch.Tensor) – Target tensor, without any blanks, of shape [batch, target_len].

  • input_lens (torch.Tensor) – Length of each utterance.

  • target_lens (torch.Tensor) – Length of each target sequence.

  • blank_index (int) – The location of the blank symbol among the character indexes.

  • reduction (str) – What reduction to apply to the output. ‘mean’, ‘sum’, ‘batch’, ‘batchmean’, ‘none’. See pytorch for ‘mean’, ‘sum’, ‘none’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.
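
Example

A minimal sketch (not part of the original docs), assuming input_lens and target_lens are relative lengths in [0, 1], as is the convention in this module:

>>> log_probs = torch.randn(2, 10, 5).log_softmax(dim=-1)
>>> targets = torch.randint(1, 5, (2, 3))  # label indices only, no blanks
>>> input_lens = torch.ones(2)             # full-length utterances
>>> target_lens = torch.ones(2)
>>> loss = ctc_loss(log_probs, targets, input_lens, target_lens, blank_index=0)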

speechbrain.nnet.losses.l1_loss(predictions, targets, length=None, allowed_len_diff=3, reduction='mean')[source]

Compute the true l1 loss, accounting for length differences.

Parameters
  • predictions (torch.Tensor) – Predicted tensor, of shape [batch, time, *].

  • targets (torch.Tensor) – Target tensor with the same size as predicted tensor.

  • length (torch.Tensor) – Length of each utterance for computing true error with a mask.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.

  • reduction (str) – Options are ‘mean’, ‘batch’, ‘batchmean’, ‘sum’. See pytorch for ‘mean’, ‘sum’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.

Example

>>> probs = torch.tensor([[0.9, 0.1, 0.1, 0.9]])
>>> l1_loss(probs, torch.tensor([[1., 0., 0., 1.]]))
tensor(0.1000)
speechbrain.nnet.losses.mse_loss(predictions, targets, length=None, allowed_len_diff=3, reduction='mean')[source]

Compute the true mean squared error, accounting for length differences.

Parameters
  • predictions (torch.Tensor) – Predicted tensor, of shape [batch, time, *].

  • targets (torch.Tensor) – Target tensor with the same size as predicted tensor.

  • length (torch.Tensor) – Length of each utterance for computing true error with a mask.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.

  • reduction (str) – Options are ‘mean’, ‘batch’, ‘batchmean’, ‘sum’. See pytorch for ‘mean’, ‘sum’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.

Example

>>> probs = torch.tensor([[0.9, 0.1, 0.1, 0.9]])
>>> mse_loss(probs, torch.tensor([[1., 0., 0., 1.]]))
tensor(0.0100)
speechbrain.nnet.losses.classification_error(probabilities, targets, length=None, allowed_len_diff=3, reduction='mean')[source]

Computes the classification error at frame or batch level.

Parameters
  • probabilities (torch.Tensor) – The posterior probabilities of shape [batch, prob] or [batch, frames, prob]

  • targets (torch.Tensor) – The targets, of shape [batch] or [batch, frames]

  • length (torch.Tensor) – Length of each utterance, if frame-level loss is desired.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.

  • reduction (str) – Options are ‘mean’, ‘batch’, ‘batchmean’, ‘sum’. See pytorch for ‘mean’, ‘sum’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.

Example

>>> probs = torch.tensor([[[0.9, 0.1], [0.1, 0.9]]])
>>> classification_error(probs, torch.tensor([1, 1]))
tensor(0.5000)
speechbrain.nnet.losses.nll_loss(log_probabilities, targets, length=None, label_smoothing=0.0, allowed_len_diff=3, reduction='mean')[source]

Computes negative log likelihood loss.

Parameters
  • log_probabilities (torch.Tensor) – The probabilities after log has been applied. Format is [batch, log_p] or [batch, frames, log_p].

  • targets (torch.Tensor) – The targets, of shape [batch] or [batch, frames].

  • length (torch.Tensor) – Length of each utterance, if frame-level loss is desired.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.

  • reduction (str) – Options are ‘mean’, ‘batch’, ‘batchmean’, ‘sum’. See pytorch for ‘mean’, ‘sum’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.

  • label_smoothing (float) – The amount of label smoothing to apply (default 0.0).

Example

>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9]])
>>> nll_loss(torch.log(probs), torch.tensor([1, 1]))
tensor(1.2040)
speechbrain.nnet.losses.bce_loss(inputs, targets, length=None, weight=None, pos_weight=None, reduction='mean', allowed_len_diff=3, label_smoothing=0.0)[source]

Computes binary cross-entropy (BCE) loss. The sigmoid is applied internally, which improves numerical stability.

Parameters
  • inputs (torch.Tensor) – The output before applying the final sigmoid. Format is [batch[, 1]?] or [batch, frames[, 1]?] (works with or without a singleton dimension at the end).

  • targets (torch.Tensor) – The targets, of shape [batch] or [batch, frames].

  • length (torch.Tensor) – Length of each utterance, if frame-level loss is desired.

  • weight (torch.Tensor) – A manual rescaling weight if provided it’s repeated to match input tensor shape.

  • pos_weight (torch.Tensor) – A weight of positive examples. Must be a vector with length equal to the number of classes.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.

  • reduction (str) – Options are ‘mean’, ‘batch’, ‘batchmean’, ‘sum’. See pytorch for ‘mean’, ‘sum’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.

  • label_smoothing (float) – The amount of label smoothing to apply to the binary targets (default 0.0).

Example

>>> inputs = torch.tensor([10.0, -6.0])
>>> targets = torch.tensor([1, 0])
>>> bce_loss(inputs, targets)
tensor(0.0013)
speechbrain.nnet.losses.kldiv_loss(log_probabilities, targets, length=None, label_smoothing=0.0, allowed_len_diff=3, pad_idx=0, reduction='mean')[source]

Computes the KL-divergence error at the batch level. This loss applies label smoothing directly to the targets.

Parameters
  • log_probabilities (torch.Tensor) – The predicted log-probabilities, of shape [batch, prob] or [batch, frames, prob].

  • targets (torch.Tensor) – The targets, of shape [batch] or [batch, frames].

  • length (torch.Tensor) – Length of each utterance, if frame-level loss is desired.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.

  • reduction (str) – Options are ‘mean’, ‘batch’, ‘batchmean’, ‘sum’. See pytorch for ‘mean’, ‘sum’. The ‘batch’ option returns one loss per item in the batch, ‘batchmean’ returns sum / batch size.

  • label_smoothing (float) – The amount of label smoothing applied directly to the targets (default 0.0).

  • pad_idx (int) – Index of the padding token, which is excluded from the loss (default 0).

Example

>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9]])
>>> kldiv_loss(torch.log(probs), torch.tensor([1, 1]))
tensor(1.2040)
speechbrain.nnet.losses.truncate(predictions, targets, allowed_len_diff=3)[source]

Ensure that predictions and targets are the same length.

Parameters
  • predictions (torch.Tensor) – First tensor for checking length.

  • targets (torch.Tensor) – Second tensor for checking length.

  • allowed_len_diff (int) – Length difference that will be tolerated before raising an exception.
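
Example

A small sketch of the expected behavior: when the length difference is within allowed_len_diff, the longer tensor is truncated along the time axis.

>>> predictions = torch.rand(2, 10, 5)
>>> targets = torch.rand(2, 8, 5)
>>> predictions, targets = truncate(predictions, targets, allowed_len_diff=3)
>>> predictions.shape
torch.Size([2, 8, 5])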

speechbrain.nnet.losses.compute_masked_loss(loss_fn, predictions, targets, length=None, label_smoothing=0.0, reduction='mean')[source]

Compute the true average loss of a set of waveforms of unequal length.

Parameters
  • loss_fn (function) – A function for computing the loss, taking just predictions and targets. It should return the per-element losses, not a reduction (e.g. reduction=”none”).

  • predictions (torch.Tensor) – First argument to loss function.

  • targets (torch.Tensor) – Second argument to loss function.

  • length (torch.Tensor) – Length of each utterance to compute mask. If None, global average is computed and returned.

  • label_smoothing (float) – The proportion of label smoothing. Should only be used for NLL loss. Ref: Regularizing Neural Networks by Penalizing Confident Output Distributions. https://arxiv.org/abs/1701.06548

  • reduction (str) – One of ‘mean’, ‘batch’, ‘batchmean’, ‘none’ where ‘mean’ returns a single value and ‘batch’ returns one per item in the batch and ‘batchmean’ is sum / batch_size and ‘none’ returns all.
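
Example

A minimal sketch (not from the original docs), assuming length holds relative lengths in [0, 1]:

>>> loss_fn = torch.nn.MSELoss(reduction="none")
>>> predictions = torch.rand(2, 10, 2)
>>> targets = torch.rand(2, 10, 2)
>>> length = torch.tensor([1.0, 0.5])  # second utterance is half padding
>>> loss = compute_masked_loss(loss_fn, predictions, targets, length)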

speechbrain.nnet.losses.get_si_snr_with_pitwrapper(source, estimate_source)[source]

This function wraps si_snr calculation with the speechbrain pit-wrapper.

Parameters
  • source (torch.Tensor) – Tensor of shape [B, T, C], where B is the batch size, T is the length of the sources, and C is the number of sources; this ordering makes the loss compatible with the class PitWrapper.

  • estimate_source (torch.Tensor) – The estimated source, of shape [B, T, C].

Example

>>> x = torch.arange(600).reshape(3, 100, 2)
>>> xhat = x[:, :, (1, 0)]
>>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
>>> print(si_snr)
tensor([135.2284, 135.2284, 135.2284])
speechbrain.nnet.losses.get_snr_with_pitwrapper(source, estimate_source)[source]

This function wraps the SNR calculation with the speechbrain pit-wrapper.

Parameters
  • source (torch.Tensor) – Tensor of shape [B, T, E, C], where B is the batch size, T is the length of the sources, E is the number of binaural channels, and C is the number of sources; this ordering makes the loss compatible with the class PitWrapper.

  • estimate_source (torch.Tensor) – The estimated source, of shape [B, T, E, C].
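
Example

A call-only sketch using the shapes documented above (the loss value depends on the random inputs, so it is not shown):

>>> source = torch.rand(2, 100, 2, 3)  # [B, T, E, C]
>>> estimate_source = source + 0.1 * torch.randn_like(source)
>>> loss = get_snr_with_pitwrapper(source, estimate_source)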

speechbrain.nnet.losses.cal_si_snr(source, estimate_source)[source]

Calculate SI-SNR.

Parameters
  • source (torch.Tensor) – Tensor of shape [T, B, C], where T is the length of the sources, B is the batch size, and C is the number of sources; this ordering makes the loss compatible with the class PitWrapper.

  • estimate_source (torch.Tensor) – The estimated source, of shape [T, B, C].

Example

>>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
>>> xhat = x[:, (1, 0)]
>>> x = x.unsqueeze(-1).repeat(1, 1, 2)
>>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
>>> si_snr = -cal_si_snr(x, xhat)
>>> print(si_snr)
tensor([[[ 25.2142, 144.1789],
         [130.9283,  25.2142]]])
speechbrain.nnet.losses.cal_snr(source, estimate_source)[source]

Calculate binaural channel SNR.

Parameters
  • source (torch.Tensor) – Tensor of shape [T, E, B, C], where T is the length of the sources, E is the number of binaural channels, B is the batch size, and C is the number of sources; this ordering makes the loss compatible with the class PitWrapper.

  • estimate_source (torch.Tensor) – The estimated source, of shape [T, E, B, C].
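
Example

A call-only sketch using the shapes documented above:

>>> source = torch.rand(100, 2, 2, 3)  # [T, E, B, C]
>>> estimate_source = source + 0.1 * torch.randn_like(source)
>>> snr = cal_snr(source, estimate_source)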

speechbrain.nnet.losses.get_mask(source, source_lengths)[source]
Parameters
  • source (torch.Tensor) – Tensor of shape [T, B, C].

  • source_lengths (torch.Tensor) – Tensor of shape [B].

Returns

mask – Binary mask of shape [T, B, 1].

Example

>>> source = torch.randn(4, 3, 2)
>>> source_lengths = torch.Tensor([2, 1, 4]).int()
>>> mask = get_mask(source, source_lengths)
>>> print(mask)
tensor([[[1.],
         [1.],
         [1.]],
<BLANKLINE>
        [[1.],
         [0.],
         [1.]],
<BLANKLINE>
        [[0.],
         [0.],
         [1.]],
<BLANKLINE>
        [[0.],
         [0.],
         [1.]]])

class speechbrain.nnet.losses.AngularMargin(margin=0.0, scale=1.0)[source]

Bases: Module

An implementation of Angular Margin (AM) proposed in the paper "Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition" (https://arxiv.org/abs/1906.07317).

Parameters
  • margin (float) – The margin for cosine similarity.

  • scale (float) – The scale for cosine similarity.

Returns

predictions

Return type

torch.Tensor

Example

>>> pred = AngularMargin()
>>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
>>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0.,  1.] ])
>>> predictions = pred(outputs, targets)
>>> predictions[:,0] > predictions[:,1]
tensor([ True, False,  True, False])
forward(outputs, targets)[source]

Compute AM between two tensors.

Parameters
  • outputs (torch.Tensor) – The outputs of shape [N, C], cosine similarity is required.

  • targets (torch.Tensor) – The targets of shape [N, C], to which the margin is applied.

Returns

predictions

Return type

torch.Tensor

class speechbrain.nnet.losses.AdditiveAngularMargin(margin=0.0, scale=1.0, easy_margin=False)[source]

Bases: AngularMargin

An implementation of Additive Angular Margin (AAM) proposed in the paper "Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition" (https://arxiv.org/abs/1906.07317).

Parameters
  • margin (float) – The margin for cosine similarity.

  • scale (float) – The scale for cosine similarity.

  • easy_margin (bool) – If True, the additive margin is applied only where the cosine similarity is positive.

Returns

predictions

Return type

torch.Tensor

Example

>>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
>>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0.,  1.] ])
>>> pred = AdditiveAngularMargin()
>>> predictions = pred(outputs, targets)
>>> predictions[:,0] > predictions[:,1]
tensor([ True, False,  True, False])
forward(outputs, targets)[source]

Compute AAM between two tensors.

Parameters
  • outputs (torch.Tensor) – The outputs of shape [N, C], cosine similarity is required.

  • targets (torch.Tensor) – The targets of shape [N, C], to which the margin is applied.

Returns

predictions

Return type

torch.Tensor

class speechbrain.nnet.losses.LogSoftmaxWrapper(loss_fn)[source]

Bases: Module

Parameters

loss_fn (Callable) – The function to wrap, e.g. AngularMargin; its outputs are passed through log-softmax before the loss is computed.

Returns
  • loss (torch.Tensor) – Learning loss

  • predictions (torch.Tensor) – Log probabilities

Example

>>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
>>> outputs = outputs.unsqueeze(1)
>>> targets = torch.tensor([ [0], [1], [0], [1] ])
>>> log_prob = LogSoftmaxWrapper(nn.Identity())
>>> loss = log_prob(outputs, targets)
>>> 0 <= loss < 1
tensor(True)
>>> log_prob = LogSoftmaxWrapper(AngularMargin(margin=0.2, scale=32))
>>> loss = log_prob(outputs, targets)
>>> 0 <= loss < 1
tensor(True)
>>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
>>> log_prob = LogSoftmaxWrapper(AdditiveAngularMargin(margin=0.3, scale=32))
>>> loss = log_prob(outputs, targets)
>>> 0 <= loss < 1
tensor(True)
forward(outputs, targets, length=None)[source]

Computes the loss for the given outputs and targets.

Parameters
  • outputs (torch.Tensor) – Network output tensor, of shape [batch, 1, outdim].

  • targets (torch.Tensor) – Target tensor, of shape [batch, 1].

Returns

loss – Loss for current examples.

Return type

torch.Tensor

speechbrain.nnet.losses.ctc_loss_kd(log_probs, targets, input_lens, blank_index, device)[source]

Knowledge distillation for CTC loss.

Distilling Knowledge from Ensembles of Acoustic Models for Joint CTC-Attention End-to-End Speech Recognition. https://arxiv.org/abs/2005.09310

Parameters
  • log_probs (torch.Tensor) – Predicted tensor from student model, of shape [batch, time, chars].

  • targets (torch.Tensor) – Predicted tensor from a single teacher model, of shape [batch, time, chars].

  • input_lens (torch.Tensor) – Length of each utterance.

  • blank_index (int) – The location of the blank symbol among the character indexes.

  • device (str) – Device for computing.

speechbrain.nnet.losses.ce_kd(inp, target)[source]

Simple version of distillation for cross-entropy loss.

Parameters
  • inp (torch.Tensor) – The probabilities from the student model, of shape [batch_size * length, feature].

  • target (torch.Tensor) – The probabilities from the teacher model, of shape [batch_size * length, feature].
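
Example

A shape-only sketch. The parameters above describe both arguments as probabilities, so that is what is passed here; the call only requires matching shapes:

>>> inp = torch.softmax(torch.randn(4, 10), dim=1)     # student
>>> target = torch.softmax(torch.randn(4, 10), dim=1)  # teacher
>>> loss = ce_kd(inp, target)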

speechbrain.nnet.losses.nll_loss_kd(probabilities, targets, rel_lab_lengths)[source]

Knowledge distillation for negative log-likelihood loss.

Distilling Knowledge from Ensembles of Acoustic Models for Joint CTC-Attention End-to-End Speech Recognition. https://arxiv.org/abs/2005.09310

Parameters
  • probabilities (torch.Tensor) – The predicted probabilities from the student model. Format is [batch, frames, p]

  • targets (torch.Tensor) – The target probabilities from the teacher model. Format is [batch, frames, p]

  • rel_lab_lengths (torch.Tensor) – Relative length of each utterance, used if frame-level loss is desired.

Example

>>> probabilities = torch.tensor([[[0.8, 0.2], [0.2, 0.8]]])
>>> targets = torch.tensor([[[0.9, 0.1], [0.1, 0.9]]])
>>> rel_lab_lengths = torch.tensor([1.])
>>> nll_loss_kd(probabilities, targets, rel_lab_lengths)
tensor(-0.7400)