""" This file contains the loss functions for k2 training. Currently, we only
support CTC loss.
Authors:
* Pierre Champion 2023
* Zeyu Zhao 2023
* Georgios Karakasidis 2023
"""
import torch
from . import k2 # import k2 from ./__init__.py


def ctc_k2(
    log_probs,
    input_lens,
    graph_compiler,
    texts,
    reduction="mean",
    beam_size=10,
    use_double_scores=True,
    is_training=True,
):
"""
CTC loss implemented with k2. Make sure that k2 has been installed properly.
Note that the blank index must be 0 in this implementation.
Arguments
---------
log_probs: torch.Tensor
Log-probs of shape (batch, time, num_classes).
input_lens : torch.Tensor
Length of each utterance.
graph_compiler : k2.Fsa
Decoding graph.
texts : List[str]
List of texts.
reduction : str
What reduction to apply to the output. 'mean', 'sum', 'none'.
See k2.ctc_loss for 'mean', 'sum', 'none'.
beam_size : int
Beam size.
use_double_scores : bool
If true, use double precision for scores.
is_training : bool
If true, the returned loss requires gradient.
Returns
-------
loss: torch.Tensor
CTC loss.

    Example
    -------
    >>> import torch
    >>> from speechbrain.k2_integration.losses import ctc_k2
    >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
    >>> from speechbrain.k2_integration.lexicon import Lexicon
    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang
    >>> # Create a random batch of log-probs
    >>> batch_size = 4
    >>> log_probs = torch.randn(batch_size, 100, 30)
    >>> log_probs.requires_grad = True
    >>> # Assume all utterances have the same length so no padding was needed.
    >>> input_lens = torch.ones(batch_size)
    >>> # Create a small lexicon containing only two words and write it to a file.
    >>> lang_tmpdir = getfixture('tmpdir')
    >>> lexicon_sample = "hello h e l l o\\nworld w o r l d\\n<UNK> <unk>"
    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
    >>> lexicon_file.write(lexicon_sample)
    >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
    >>> prepare_lang(lang_tmpdir)
    >>> # Create a lexicon object
    >>> lexicon = Lexicon(lang_tmpdir)
    >>> # Create a decoding graph compiler from the lexicon
    >>> graph = CtcGraphCompiler(
    ...     lexicon,
    ...     log_probs.device,
    ... )
    >>> # Create a batch of texts matching the lexicon
    >>> texts = ["hello world", "world hello", "hello", "world"]
    >>> # Compute the loss
    >>> loss = ctc_k2(
    ...     log_probs=log_probs,
    ...     input_lens=input_lens,
    ...     graph_compiler=graph,
    ...     texts=texts,
    ...     reduction="mean",
    ...     beam_size=10,
    ...     use_double_scores=True,
    ...     is_training=True,
    ... )
    """
    input_lens = (input_lens * log_probs.shape[1]).round().int()
    batch_size = log_probs.shape[0]
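    # Each row of supervision_segments is (utterance_index, start_frame, num_frames).
    # k2 requires this tensor to be int32 and to live on the CPU.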
    supervision_segments = torch.tensor(
        [[i, 0, input_lens[i]] for i in range(batch_size)],
        device="cpu",
        dtype=torch.int32,
    )
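    # Compile the transcripts into per-utterance decoding graphs and get the
    # target lengths that k2.ctc_loss needs.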
    decoding_graph, target_lens = graph_compiler.compile(
        texts, is_training=is_training
    )
    # An introduction to DenseFsaVec:
    # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
    # A DenseFsaVec can be seen as an FSA representation of log_probs,
    # whose arc weights are initialized from log_probs. Converting the
    # tensor into an FSA lets us use k2's FSA operations, e.g. k2.ctc_loss.
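    # The decoding graph is moved to the same device as log_probs below, since
    # k2.ctc_loss intersects it with the DenseFsaVec built from log_probs.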
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
    loss = k2.ctc_loss(
        decoding_graph=decoding_graph.to(log_probs.device),
        dense_fsa_vec=dense_fsa_vec,
        target_lengths=target_lens.to(log_probs.device),
        output_beam=beam_size,
        reduction=reduction,
        use_double_scores=use_double_scores,
    )
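    # Sanity check: the loss should track gradients only when training.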
    assert loss.requires_grad == is_training
    return loss