speechbrain.k2_integration.graph_compiler module

Graph compiler class to create, store, and use k2 decoding graphs in SpeechBrain. The output words are restricted to those in the lexicon.

This code is an extension of icefall's graph compiler (https://github.com/k2-fsa/icefall); parts of it are heavily inspired by, or taken directly from, that code.

Authors:
  • Pierre Champion 2023

  • Zeyu Zhao 2023

  • Georgios Karakasidis 2023

Summary

Classes:

CtcGraphCompiler

This class is used to compile decoding graphs for CTC training.

GraphCompiler

This abstract class is used to compile graphs for training and decoding.

Reference

class speechbrain.k2_integration.graph_compiler.GraphCompiler

Bases: ABC

This abstract class is used to compile graphs for training and decoding.

abstract property topo: Fsa

Return the topology used to compile the graph.

abstract property lexicon: Lexicon

Return the lexicon used to compile the graph.

abstract property device

Return the device used to compile the graph.

abstract compile(texts: List[str], is_training: bool = True) → Tuple[Fsa, Tensor]

Compile the graph for the given texts.

Parameters:
  • texts (List[str]) –

    A list of strings, each containing the sentence for one utterance. A sentence consists of space-separated words. For example, texts may look like:

    ['hello world', 'CTC training with k2']

  • is_training (bool) – Whether the graphs are compiled for training. If True, a warning is logged for out-of-vocabulary (OOV) words.

Returns:

  • graph (k2.Fsa) – An FsaVec: the result of composing self.ctc_topo with the transcript FSA.

  • target_lens (torch.Tensor) – A long tensor of shape (batch,) containing the length of each target sequence.
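The texts format and the OOV warning tied to is_training can be illustrated with a small pure-Python sketch. It assumes that the lexicon contains an <UNK> entry to which out-of-vocabulary words fall back; texts_to_word_ids and word_table are hypothetical names used only for illustration, not the SpeechBrain API.

```python
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)


def texts_to_word_ids(
    texts: List[str],
    word_table: Dict[str, int],
    unk: str = "<UNK>",
    is_training: bool = True,
) -> List[List[int]]:
    """Hypothetical helper: map space-separated words to ids.

    Words missing from word_table are replaced by `unk`; a warning is
    logged for them only when is_training is True.
    """
    result = []
    for text in texts:
        ids = []
        for word in text.split():
            if word not in word_table:
                if is_training:
                    logger.warning("OOV word %r replaced by %s", word, unk)
                word = unk
            ids.append(word_table[word])
        result.append(ids)
    return result
```

For example, with word_table = {"<UNK>": 1, "hello": 2, "world": 3}, the texts ["hello world", "hello there"] map to [[2, 3], [2, 1]], and a warning is logged for "there".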

compile_HL(cache_dir: str | None = None, cache: bool = False)

Compile the decoding graph by composing H with L. This is for decoding without a language model.

Parameters:
  • cache_dir (str) – The directory in which to store the composition in .pt format.

  • cache (bool) – Whether to load a previously stored composition from its .pt file in cache_dir.

Returns:

HL – The HL composition

Return type:

k2.Fsa
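The cache_dir and cache parameters can be read as a load-or-build pattern. The sketch below is a hypothetical helper, load_or_build, with pickle standing in for the .pt format. It assumes that a stored file is reused only when cache is True and that a freshly built result is written whenever cache_dir is set; the actual implementation may differ in detail (file naming, logging, CPU placement).

```python
import os
import pickle
from typing import Any, Callable, Optional


def load_or_build(
    build: Callable[[], Any],
    name: str,
    cache_dir: Optional[str] = None,
    cache: bool = False,
) -> Any:
    """Hypothetical helper illustrating the cache_dir / cache semantics."""
    path = os.path.join(cache_dir, name + ".pkl") if cache_dir else None
    # Reuse a stored result only when caching is explicitly requested.
    if cache and path is not None and os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    # Stands in for the expensive step, e.g. composing H with L.
    result = build()
    # Store the fresh result whenever a cache directory is given.
    if path is not None:
        with open(path, "wb") as f:
            pickle.dump(result, f)
    return result
```

Under these assumptions, passing cache=False forces a rebuild (and refreshes the stored file), which is useful when the lexicon or topology has changed.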

compile_HLG(G, cache_dir: str | None = None, cache: bool = False)

Compile the decoding graph by composing H with LG. This is for decoding with a small language model.

Parameters:
  • G (k2.Fsa) – The language model FSA.

  • cache_dir (str) – The directory in which to store the composition in .pt format.

  • cache (bool) – Whether to load a previously stored composition from its .pt file in cache_dir.

Returns:

HLG – The HLG composition

Return type:

k2.Fsa

class speechbrain.k2_integration.graph_compiler.CtcGraphCompiler(_lexicon: Lexicon, device: device, need_repeat_flag: bool = False)

Bases: GraphCompiler

This class is used to compile decoding graphs for CTC training.

Parameters:
  • _lexicon (Lexicon) – The lexicon, built from data/lang/lexicon.txt.

  • device (torch.device) – The device to use for operations compiling transcripts to FSAs.

  • need_repeat_flag (bool) – If True, adds an attribute named _is_repeat_token_ to ctc_topo indicating, for each arc, whether its token is a repeat token in the CTC graph. This attribute is needed to implement the delay penalty for phone-based CTC loss. See https://github.com/k2-fsa/k2/pull/1086 for more details. Note: this flag requires a k2 version that includes the change above, so make sure your k2 installation is up to date.
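What a "repeat token" is can be seen on a textbook (non-modified) CTC topology, in the spirit of k2.ctc_topo(max_token, modified=False). The pure-Python sketch below lists its arcs as tuples and marks the arcs that a flag like _is_repeat_token_ would set; ctc_topo_arcs is a hypothetical name, and the arc layout is an illustrative assumption rather than k2's internal representation.

```python
from typing import List, Tuple

# (src_state, dst_state, input_label, output_label, is_repeat)
Arc = Tuple[int, int, int, int, bool]


def ctc_topo_arcs(max_token: int) -> List[Arc]:
    """Arcs of a textbook CTC topology; token 0 is the blank.

    State i means "the previous frame was labelled i". Illustrative only.
    """
    final_state = max_token + 1
    arcs: List[Arc] = []
    for src in range(max_token + 1):
        for dst in range(max_token + 1):
            # Staying on the same token consumes a frame but outputs
            # nothing (0 = epsilon); blank (0) also never reaches the
            # output side because olabel == dst == 0 for it.
            olabel = 0 if src == dst else dst
            # A self-loop on a non-blank token is a "repeat token" arc.
            is_repeat = src == dst and dst != 0
            arcs.append((src, dst, dst, olabel, is_repeat))
        # k2 convention: arcs entering the final state carry label -1.
        arcs.append((src, final_state, -1, -1, False))
    return arcs
```

With max_token=2 this yields 12 arcs, and only the self-loops on states 1 and 2 are flagged as repeats.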

Example

>>> import torch
>>> import k2
>>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
>>> from speechbrain.k2_integration.lexicon import Lexicon
>>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> # Create a random batch of log-probs
>>> batch_size = 4
>>> log_probs = torch.randn(batch_size, 100, 30)
>>> log_probs.requires_grad = True
>>> # Assume all utterances have the same length so no padding was needed.
>>> input_lens = torch.ones(batch_size)
>>> # Create a small lexicon with two words plus an <UNK> entry, and write it to a file.
>>> lang_tmpdir = getfixture('tmpdir')
>>> lexicon_sample = "hello h e l l o\nworld w o r l d\n<UNK> <unk>"
>>> lexicon_file = lang_tmpdir.join("lexicon.txt")
>>> lexicon_file.write(lexicon_sample)
>>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
>>> prepare_lang(lang_tmpdir)
>>> # Create a lexicon object
>>> lexicon = Lexicon(lang_tmpdir)
>>> # Create the CTC graph compiler
>>> graph = CtcGraphCompiler(
...     lexicon,
...     log_probs.device,
... )
>>> isinstance(graph.topo, k2.Fsa)
True

property topo

Return the ctc_topo.

property lexicon

Return the lexicon.

property device

Return the device used for compiling graphs.

compile(texts: List[str], is_training: bool = True) → Tuple[Fsa, Tensor]

Build decoding graphs by composing ctc_topo with the given transcripts.

Parameters:
  • texts (List[str]) –

    A list of strings, each containing the sentence for one utterance. A sentence consists of space-separated words. For example, texts may look like:

    ['hello world', 'CTC training with k2']

  • is_training (bool) – Whether the graphs are compiled for training. If True, a warning is logged for out-of-vocabulary (OOV) words.

Returns:

  • graph (k2.Fsa) – An FsaVec: the result of composing self.ctc_topo with the transcript FSA.

  • target_lens (torch.Tensor) – A long tensor of shape (batch,) containing the length of each target sequence.
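Each entry of target_lens is the number of tokens the transcript expands to through the lexicon, i.e. the concatenated pronunciations of its words. A minimal pure-Python sketch, assuming a lexicon given as a dict from words to token lists (target_lengths is a hypothetical name; the real method returns a torch long tensor and handles OOV words, whereas this sketch raises KeyError for them):

```python
from typing import Dict, List


def target_lengths(texts: List[str], lexicon: Dict[str, List[str]]) -> List[int]:
    """Count tokens per utterance by concatenating each word's pronunciation."""
    lens = []
    for text in texts:
        # Flatten the per-word token lists into one token sequence.
        tokens = [tok for word in text.split() for tok in lexicon[word]]
        lens.append(len(tokens))
    return lens
```

With the lexicon from the example above ("hello h e l l o", "world w o r l d"), ["hello world", "world"] gives [10, 5].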