speechbrain.k2_integration.losses module
This module contains the loss functions for k2 training. Currently, only CTC loss is supported.
- Authors:
Pierre Champion 2023
Zeyu Zhao 2023
Georgios Karakasidis 2023
Summary
Functions:
ctc_k2 – CTC loss implemented with k2.
Reference
- speechbrain.k2_integration.losses.ctc_k2(log_probs, input_lens, graph_compiler, texts, reduction='mean', beam_size=10, use_double_scores=True, is_training=True)
CTC loss implemented with k2. Make sure that k2 has been installed properly. Note that the blank index must be 0 in this implementation.
- Parameters:
log_probs (torch.Tensor) – Log-probs of shape (batch, time, num_classes).
input_lens (torch.Tensor) – Relative length of each utterance, expressed as a fraction of the padded time dimension (see the note before the example below).
graph_compiler (CtcGraphCompiler) – Graph compiler that builds the decoding graph from the reference texts.
texts (List[str]) – Reference transcripts, one per utterance in the batch.
reduction (str) – Reduction to apply to the output: ‘mean’, ‘sum’, or ‘none’. See k2.ctc_loss for details.
beam_size (int) – Output beam used by k2 when intersecting the decoding graph with the network output.
use_double_scores (bool) – If true, use double precision for scores.
is_training (bool) – If true, the returned loss requires gradient.
- Returns:
loss – CTC loss.
- Return type:
torch.Tensor
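Note: SpeechBrain conventionally expresses input_lens as fractions of the padded time dimension rather than absolute frame counts; in the example below, torch.ones(batch_size) marks every utterance as full-length. A minimal sketch of the conversion to absolute frame counts (the tensors here are illustrative, not part of the API):
>>> import torch
>>> log_probs = torch.randn(2, 100, 30)  # (batch, time, num_classes)
>>> input_lens = torch.tensor([1.0, 0.5])  # fractions of the 100 padded frames
>>> frames = (input_lens * log_probs.shape[1]).long()  # absolute lengths: 100 and 50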
Example
>>> import torch
>>> from speechbrain.k2_integration.losses import ctc_k2
>>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
>>> from speechbrain.k2_integration.lexicon import Lexicon
>>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> # Create a random batch of log-probs
>>> batch_size = 4
>>> log_probs = torch.randn(batch_size, 100, 30)
>>> log_probs.requires_grad = True
>>> # Assume all utterances have the same length so no padding was needed.
>>> input_lens = torch.ones(batch_size)
>>> # Create a small lexicon containing only two words and write it to a file.
>>> lang_tmpdir = getfixture('tmpdir')
>>> lexicon_sample = "hello h e l l o\nworld w o r l d\n<UNK> <unk>"
>>> lexicon_file = lang_tmpdir.join("lexicon.txt")
>>> lexicon_file.write(lexicon_sample)
>>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
>>> prepare_lang(lang_tmpdir)
>>> # Create a lexicon object
>>> lexicon = Lexicon(lang_tmpdir)
>>> # Create a graph compiler that builds the decoding graph
>>> graph = CtcGraphCompiler(
...     lexicon,
...     log_probs.device,
... )
>>> # Create a batch of texts
>>> texts = ["hello world", "world hello", "hello", "world"]
>>> # Compute the loss
>>> loss = ctc_k2(
...     log_probs=log_probs,
...     input_lens=input_lens,
...     graph_compiler=graph,
...     texts=texts,
...     reduction="mean",
...     beam_size=10,
...     use_double_scores=True,
...     is_training=True,
... )
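Since is_training=True and log_probs requires gradient, the returned loss can be backpropagated directly. A minimal continuation of the example above, reusing the objects defined there (with reduction="mean" the loss is a scalar, so backward() needs no arguments):
>>> # Backpropagate through the network output
>>> loss.backward()
>>> assert log_probs.grad is not None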