speechbrain.k2_integration.losses module

This file contains the loss functions for k2 training. Currently, we only support CTC loss.

Authors:
  • Pierre Champion 2023

  • Zeyu Zhao 2023

  • Georgios Karakasidis 2023

Summary

Functions:

ctc_k2

CTC loss implemented with k2.

Reference

speechbrain.k2_integration.losses.ctc_k2(log_probs, input_lens, graph_compiler, texts, reduction='mean', beam_size=10, use_double_scores=True, is_training=True)[source]

CTC loss implemented with k2. Make sure that k2 has been installed properly. Note that the blank index must be 0 in this implementation.
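
Under the hood this follows k2's standard CTC recipe: the graph compiler turns the transcripts into a decoding FSA, the log-probs are wrapped in a k2.DenseFsaVec, and the two are intersected to get the loss. A minimal sketch of that underlying recipe, using a bare CTC topology (k2.ctc_topo) in place of this module's lexicon-based graph, so it is illustrative rather than this function's exact implementation:

>>> import torch
>>> import k2
>>> logp = torch.randn(1, 50, 30).log_softmax(dim=-1)  # (batch, time, classes)
>>> # One supervision segment per utterance: (sequence_idx, start_frame, num_frames)
>>> supervision = torch.tensor([[0, 0, 50]], dtype=torch.int32)
>>> dense = k2.DenseFsaVec(logp, supervision)
>>> # Bare CTC topology; in practice it is composed with transcript FSAs,
>>> # which is what the graph compiler does. Blank must be token 0.
>>> topo = k2.create_fsa_vec([k2.ctc_topo(max_token=29)])
>>> loss = k2.ctc_loss(topo, dense, output_beam=10.0, reduction="mean")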

Parameters:
  • log_probs (torch.Tensor) – Log-probs of shape (batch, time, num_classes).

  • input_lens (torch.Tensor) – Relative length of each utterance, as a fraction of the longest (padded) utterance in the batch (see the sketch after this list).

  • graph_compiler (CtcGraphCompiler) – Graph compiler that builds the decoding graph (a k2.Fsa) from the given texts.

  • texts (List[str]) – Reference transcripts, one per utterance in the batch.

  • reduction (str) – Reduction to apply to the output: ‘mean’, ‘sum’, or ‘none’. See k2.ctc_loss for details.

  • beam_size (int) – Beam size used for pruning in the k2 intersection; larger values prune less but are slower.

  • use_double_scores (bool) – If true, use double precision for scores.

  • is_training (bool) – If true, the returned loss requires gradient.
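
For a padded batch, the relative lengths in input_lens are derived from the absolute frame counts, e.g. (a hedged illustration, not part of this module):

>>> import torch
>>> abs_lens = torch.tensor([100, 80, 100, 60])  # frames per utterance
>>> input_lens = abs_lens / abs_lens.max()       # relative lengths in [0, 1]
>>> input_lens
tensor([1.0000, 0.8000, 1.0000, 0.6000])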

Returns:

loss – CTC loss: a scalar if reduction is ‘mean’ or ‘sum’, or a tensor of per-utterance losses if reduction is ‘none’.

Return type:

torch.Tensor

Example

>>> import torch
>>> from speechbrain.k2_integration.losses import ctc_k2
>>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
>>> from speechbrain.k2_integration.lexicon import Lexicon
>>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> # Create a random batch of log-probs
>>> batch_size = 4
>>> log_probs = torch.randn(batch_size, 100, 30)
>>> log_probs.requires_grad = True
>>> # Assume all utterances have the same length so no padding was needed.
>>> input_lens = torch.ones(batch_size)
>>> # Create a small lexicon containing only two words and write it to a file.
>>> lang_tmpdir = getfixture('tmpdir')
>>> lexicon_sample = "hello h e l l o\nworld w o r l d\n<UNK> <unk>"
>>> lexicon_file = lang_tmpdir.join("lexicon.txt")
>>> lexicon_file.write(lexicon_sample)
>>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
>>> prepare_lang(lang_tmpdir)
>>> # Create a lexicon object
>>> lexicon = Lexicon(lang_tmpdir)
>>> # Create a graph compiler from the lexicon
>>> graph = CtcGraphCompiler(
...     lexicon,
...     log_probs.device,
... )
>>> # Create a batch of texts
>>> texts = ["hello world", "world hello", "hello", "world"]
>>> # Compute the loss
>>> loss = ctc_k2(
...     log_probs=log_probs,
...     input_lens=input_lens,
...     graph_compiler=graph,
...     texts=texts,
...     reduction="mean",
...     beam_size=10,
...     use_double_scores=True,
...     is_training=True,
... )
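
Since is_training=True, the returned loss supports backpropagation; continuing the session above (the gradient shape shown is for this batch):

>>> loss.backward()
>>> log_probs.grad.shape
torch.Size([4, 100, 30])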