speechbrain.k2_integration.graph_compiler module
Graph compiler class to create, store, and use k2 decoding graphs in speechbrain. Limits the output words to the ones in the lexicon.
This code extends the graph compiler from icefall (https://github.com/k2-fsa/icefall), and much of it is inspired by or taken directly from that implementation.
- Authors:
Pierre Champion 2023
Zeyu Zhao 2023
Georgios Karakasidis 2023
Summary
Classes:
CtcGraphCompiler – This class is used to compile decoding graphs for CTC training.
GraphCompiler – This abstract class is used to compile graphs for training and decoding.
Reference
- class speechbrain.k2_integration.graph_compiler.GraphCompiler
Bases: ABC
This abstract class is used to compile graphs for training and decoding.
- abstract property topo: Fsa
Return the topology used to compile the graph.
- abstract property device
Return the device used to compile the graph.
- abstract compile(texts: List[str], is_training: bool = True) → Tuple[Fsa, Tensor]
Compile the graph for the given texts.
- Parameters:
texts (List[str]) – A list of strings. Each string contains the sentence for one utterance, made up of space-separated words. For example, texts may look like: ['hello world', 'CTC training with k2']
is_training (bool) – Whether the graphs are compiled for training. If True, a warning is emitted for out-of-vocabulary (OOV) words.
- Returns:
graph (k2.Fsa) – An FsaVec, the composition result of self.ctc_topo and the transcript FSA.
target_lens (torch.Tensor) – A long tensor of shape (batch,) containing the length of each target sequence.
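To make the relationship between texts, is_training and target_lens concrete, here is a minimal pure-Python sketch. It assumes that each word is expanded into its lexicon tokens, that OOV words are replaced by the <UNK> entry (as in the example lexicon used by CtcGraphCompiler), and that target_lens counts tokens per utterance. The helpers parse_lexicon and texts_to_targets are hypothetical and are not part of the SpeechBrain Lexicon API; no FSAs are built here.

```python
import warnings
from typing import Dict, List, Tuple


def parse_lexicon(text: str) -> Dict[str, List[str]]:
    """Parse lexicon lines of the form 'word tok1 tok2 ...' into a dict."""
    lexicon: Dict[str, List[str]] = {}
    for line in text.strip().splitlines():
        word, *tokens = line.split()
        lexicon[word] = tokens
    return lexicon


def texts_to_targets(
    texts: List[str],
    lexicon: Dict[str, List[str]],
    token2id: Dict[str, int],
    oov: str = "<UNK>",
    is_training: bool = True,
) -> Tuple[List[List[int]], List[int]]:
    """Return per-utterance token ids and their lengths (a target_lens analogue)."""
    targets: List[List[int]] = []
    for text in texts:
        ids: List[int] = []
        for word in text.split():  # words are separated by spaces
            if word not in lexicon:
                if is_training:  # only warn about OOV words during training
                    warnings.warn(f"OOV word {word!r} replaced by {oov}")
                word = oov
            ids.extend(token2id[tok] for tok in lexicon[word])
        targets.append(ids)
    return targets, [len(ids) for ids in targets]


lexicon = parse_lexicon("hello h e l l o\nworld w o r l d\n<UNK> <unk>")
# Id 0 is left free for the CTC blank (assumption).
tokens = sorted({tok for pron in lexicon.values() for tok in pron})
token2id = {tok: i + 1 for i, tok in enumerate(tokens)}
targets, target_lens = texts_to_targets(["hello world", "hello foo"], lexicon, token2id)
print(target_lens)  # [10, 6]
```

Here "foo" is not in the lexicon, so it becomes the single <unk> token and a warning is issued because is_training defaults to True.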
- compile_HL(cache_dir: str | None = None, cache: bool = False)
Compile the decoding graph by composing H (the CTC topology) with L (the lexicon). This graph is used for decoding without a language model.
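Conceptually, H maps frame-level CTC label sequences to token sequences, and L maps token sequences to words, so an HL graph only accepts frame sequences whose collapsed tokens spell out lexicon words. This is what limits the output to lexicon words. The sketch below mimics that behaviour on a single path in pure Python; it builds no FSAs, ignores scores and disambiguation symbols, and assumes that token id 0 is the blank. All function names and token ids are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

BLANK = 0  # assumption: token id 0 is the CTC blank


def ctc_collapse(frames: List[int]) -> List[int]:
    """H step: merge consecutive repeated labels, then drop blanks."""
    out: List[int] = []
    prev: Optional[int] = None
    for label in frames:
        if label != prev and label != BLANK:
            out.append(label)
        prev = label
    return out


def tokens_to_words(
    tokens: List[int], lexicon: Dict[str, Tuple[int, ...]]
) -> Optional[List[str]]:
    """L step: segment tokens into lexicon words; None if no segmentation exists."""
    n = len(tokens)
    # best[i] holds a word sequence covering tokens[:i], or None
    best: List[Optional[List[str]]] = [None] * (n + 1)
    best[0] = []
    for end in range(1, n + 1):
        for word, pron in lexicon.items():
            start = end - len(pron)
            if start < 0:
                continue
            prefix = best[start]
            if prefix is not None and tuple(tokens[start:end]) == pron:
                best[end] = prefix + [word]
                break  # keep the first match; real decoding would use scores
    return best[n]


def hl_path_to_words(
    frames: List[int], lexicon: Dict[str, Tuple[int, ...]]
) -> Optional[List[str]]:
    """Apply the H step, then the L step, to one frame-level path."""
    return tokens_to_words(ctc_collapse(frames), lexicon)


# Hypothetical token ids: h=1 e=2 l=3 o=4 w=5 r=6 d=7
lexicon = {"hello": (1, 2, 3, 3, 4), "world": (5, 4, 6, 3, 7)}
print(hl_path_to_words([1, 1, 2, 3, 0, 3, 4, 4, 0, 5, 4, 6, 3, 7], lexicon))
print(hl_path_to_words([1, 2, 3, 3, 4], lexicon))  # collapses to "helo": None
```

The second path has no blank between the two "l" frames, so they merge into a single token and the result is not a lexicon word.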
- class speechbrain.k2_integration.graph_compiler.CtcGraphCompiler(_lexicon: Lexicon, device: device, need_repeat_flag: bool = False)
Bases: GraphCompiler
This class is used to compile decoding graphs for CTC training.
- Parameters:
lexicon (Lexicon) – The lexicon, built from data/lang/lexicon.txt.
device (torch.device) – The device used when compiling transcripts into FSAs.
need_repeat_flag (bool) – If True, adds an attribute named _is_repeat_token_ to ctc_topo, indicating whether each token is a repeat token in the CTC graph. This attribute is needed to implement the delay penalty for phone-based CTC loss. See https://github.com/k2-fsa/k2/pull/1086 for more details. Note: this flag requires a version of k2 that includes the change above, so make sure your k2 installation is up to date.
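In a CTC topology, a "repeat" corresponds to following the self-loop on a non-blank token, that is, a frame that emits the same non-blank token as the previous frame. The sketch below marks such frames along a single alignment path. It is a hypothetical illustration of the concept, assuming token id 0 is the blank; it is not how k2 stores _is_repeat_token_, which is an attribute on the arcs of ctc_topo.

```python
from typing import List, Optional

BLANK = 0  # assumption: token id 0 is the CTC blank


def repeat_token_mask(frames: List[int]) -> List[bool]:
    """Mark frames that re-emit the previous frame's non-blank token.

    In a CTC topology these frames follow the self-loop on a non-blank
    token, i.e. they are repeat steps rather than new emissions.
    """
    mask: List[bool] = []
    prev: Optional[int] = None
    for label in frames:
        mask.append(label != BLANK and label == prev)
        prev = label
    return mask


print(repeat_token_mask([1, 1, 0, 2, 2, 2]))
# [False, True, False, False, True, True]
```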
Example
>>> import torch
>>> import k2
>>> from speechbrain.k2_integration.losses import ctc_k2
>>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
>>> from speechbrain.k2_integration.lexicon import Lexicon
>>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> # Create a random batch of log-probs
>>> batch_size = 4
>>> log_probs = torch.randn(batch_size, 100, 30)
>>> log_probs.requires_grad = True
>>> # Assume all utterances have the same length so no padding was needed.
>>> input_lens = torch.ones(batch_size)
>>> # Create a small lexicon containing only two words and write it to a file.
>>> lang_tmpdir = getfixture('tmpdir')
>>> lexicon_sample = "hello h e l l o\nworld w o r l d\n<UNK> <unk>"
>>> lexicon_file = lang_tmpdir.join("lexicon.txt")
>>> lexicon_file.write(lexicon_sample)
>>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
>>> prepare_lang(lang_tmpdir)
>>> # Create a lexicon object
>>> lexicon = Lexicon(lang_tmpdir)
>>> # Create a graph compiler
>>> graph = CtcGraphCompiler(
...     lexicon,
...     log_probs.device,
... )
>>> isinstance(graph.topo, k2.Fsa)
True
- property topo
Return the ctc_topo.
- property lexicon
Return the lexicon.
- property device
Return the device used for compiling graphs.
- compile(texts: List[str], is_training: bool = True) → Tuple[Fsa, Tensor]
Build decoding graphs by composing ctc_topo with given transcripts.
- Parameters:
texts (List[str]) – A list of strings. Each string contains the sentence for one utterance, made up of space-separated words. For example, texts may look like: ['hello world', 'CTC training with k2']
is_training (bool) – Whether the graphs are compiled for training. If True, a warning is emitted for out-of-vocabulary (OOV) words.
- Returns:
graph (k2.Fsa) – An FsaVec, the composition result of self.ctc_topo and the transcript FSA.
target_lens (torch.Tensor) – A long tensor of shape (batch,) containing the length of each target sequence.
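Composing a standard CTC topology with a linear FSA for one transcript yields the familiar blank-interleaved CTC alignment graph. The pure-Python sketch below builds that graph's state and arc structure directly and checks whether a frame-level label path is a valid alignment of the transcript. It assumes token id 0 is the blank and omits scores; it illustrates the concept and is not the SpeechBrain or k2 implementation, which operates on k2.Fsa objects. The helper names are hypothetical.

```python
from typing import Dict, List, Set, Tuple

BLANK = 0  # assumption: token id 0 is the CTC blank


def ctc_alignment_graph(tokens: List[int]) -> Tuple[Dict[int, Set[int]], List[int]]:
    """Build the blank-interleaved CTC graph for one transcript.

    State s stands for label ext[s] of the extended sequence
    [blank, t1, blank, t2, ..., tL, blank]; arcs[s] lists successor states.
    """
    ext = [BLANK]
    for tok in tokens:
        ext += [tok, BLANK]
    arcs: Dict[int, Set[int]] = {}
    for s in range(len(ext)):
        nxt = {s}  # self-loop: the same label is emitted again
        if s + 1 < len(ext):
            nxt.add(s + 1)  # move on to the next label
        # Skipping the blank is allowed only between two different tokens.
        if s + 2 < len(ext) and ext[s + 2] != BLANK and ext[s + 2] != ext[s]:
            nxt.add(s + 2)
        arcs[s] = nxt
    return arcs, ext


def accepts(tokens: List[int], frames: List[int]) -> bool:
    """Check whether a frame-level label path is a valid alignment of tokens."""
    if not frames:
        return False
    arcs, ext = ctc_alignment_graph(tokens)
    # The first frame enters either the leading blank (0) or the first token (1).
    active = {s for s in (0, 1) if s < len(ext) and ext[s] == frames[0]}
    for label in frames[1:]:
        active = {n for s in active for n in arcs[s] if ext[n] == label}
    # A valid path ends on the last token or on the trailing blank.
    return bool(active & {len(ext) - 2, len(ext) - 1})


print(accepts([1, 2], [1, 0, 2]))  # True
print(accepts([1, 1], [1, 1]))     # False: repeated tokens need a blank in between
```

The skip rule is what distinguishes CTC from a plain left-to-right model: two identical consecutive tokens must be separated by at least one blank frame, otherwise the repeated frames would collapse into a single token.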