"""Graph compiler class to create, store, and use k2 decoding graphs in
speechbrain. Limits the output words to the ones in the lexicon.
This code is an extension, and therefore heavily inspired or taken from
icefall's (https://github.com/k2-fsa/icefall) graph compiler.
Authors:
* Pierre Champion 2023
* Zeyu Zhao 2023
* Georgios Karakasidis 2023
"""
import abc
import os
from typing import List, Optional, Tuple
import torch
from speechbrain.utils.logger import get_logger
from . import k2 # import k2 from ./__init__.py
from . import lexicon
# Module-level logger, named after this module, used by the graph compilers
# below to report compilation and caching progress.
logger = get_logger(__name__)
[docs]
class GraphCompiler(abc.ABC):
    """
    This abstract class is used to compile graphs for training and decoding.

    Subclasses provide the topology (`topo`), the lexicon (`lexicon`), the
    device (`device`) and the per-batch `compile` method. The decoding graph
    builders `compile_HL` and `compile_HLG` are implemented here on top of
    those properties.
    """

    @property
    @abc.abstractmethod
    def topo(self) -> k2.Fsa:
        """
        Return the topology used to compile the graph.
        """

    @property
    @abc.abstractmethod
    def lexicon(self) -> lexicon.Lexicon:
        """
        Return the lexicon used to compile the graph.
        """

    @property
    @abc.abstractmethod
    def device(self):
        """
        Return the device used to compile the graph.
        """

    @abc.abstractmethod
    def compile(
        self, texts: List[str], is_training: bool = True
    ) -> Tuple[k2.Fsa, torch.Tensor]:
        """
        Compile the graph for the given texts.

        Arguments
        ---------
        texts: List[str]
            A list of strings. Each string contains a sentence for an utterance.
            A sentence consists of spaces separated words. An example `texts`
            looks like:

                ['hello world', 'CTC training with k2']
        is_training: bool
            Indicating whether this is for training or not
            (OOV warning in training).

        Returns
        -------
        graph: k2.Fsa
            An FsaVec, the composition result of the topology and the
            transcript FSA.
        target_lens: torch.Tensor
            It is a long tensor of shape (batch,). It contains lengths of
            each target sequence.
        """

    @staticmethod
    def _cache_path(
        cache_dir: Optional[str], name: str, file_hash: str
    ) -> Optional[str]:
        """Return the cache file for graph `name`, or None without cache_dir."""
        if cache_dir is None:
            return None
        return os.path.join(cache_dir, "." + name + "_" + file_hash + ".pt")

    @staticmethod
    def _load_cached(path: str, name: str) -> k2.Fsa:
        """Load graph `name` from the cached .pt file at `path` (on CPU)."""
        logger.warning(
            f"Loading {name} '{path}' from its cached .pt format."
            " Set 'caching: False' in the yaml"
            " if this is not what you want."
        )
        # NOTE(review): torch.load unpickles the file; cache_dir should only
        # point to trusted locations.
        return k2.Fsa.from_dict(torch.load(path, map_location="cpu"))

    @staticmethod
    def _save_to_cache(graph: k2.Fsa, path: str, name: str) -> None:
        """Store graph `name` in .pt format at `path`, creating its directory."""
        # torch.save does not create missing directories, so make sure the
        # target directory exists before writing.
        os.makedirs(os.path.dirname(path) or os.curdir, exist_ok=True)
        logger.info(f"Caching {name} to: {path}")
        torch.save(graph.as_dict(), path)

    def compile_HL(
        self, cache_dir: Optional[str] = None, cache: bool = False
    ) -> k2.Fsa:
        """
        Compile the decoding graph by composing H with L.
        This is for decoding without language model.

        Arguments
        ---------
        cache_dir: str
            The path to store the composition in a .pt format. When given,
            a freshly compiled graph is always written there.
        cache: bool
            Whether or not to load the composition from the .pt format (in the
            cache_dir dir).

        Returns
        -------
        HL: k2.Fsa
            The HL composition
        """
        logger.info("Arc sorting L")
        L = k2.arc_sort(self.lexicon.L).to("cpu")
        H = self.topo.to("cpu")

        # NOTE(review): the cache key only depends on the first shape entry of
        # H and L, so two different graphs of equal size share a cache file.
        file_hash = str(hash(H.shape[0])) + str(hash(L.shape[0]))
        path = self._cache_path(cache_dir, "HL", file_hash)

        if cache and path is not None and os.path.exists(path):
            return self._load_cached(path, "HL")

        logger.info("Composing H and L")
        HL = k2.compose(H, L, inner_labels="tokens")

        logger.info("Connecting HL")
        HL = k2.connect(HL)

        logger.info("Arc sorting HL")
        HL = k2.arc_sort(HL)
        logger.debug(f"HL.shape: {HL.shape}")

        if path is not None:
            self._save_to_cache(HL, path, "HL")

        return HL

    def compile_HLG(
        self, G, cache_dir: Optional[str] = None, cache: bool = False
    ) -> k2.Fsa:
        """
        Compile the decoding graph by composing H with LG.
        This is for decoding with small language model.

        Arguments
        ---------
        G: k2.Fsa
            The language model FSA.
        cache_dir: str
            The path to store the composition in a .pt format. When given,
            a freshly compiled graph is always written there.
        cache: bool
            Whether or not to load the composition from the .pt format (in the
            cache_dir dir).

        Returns
        -------
        HLG: k2.Fsa
            The HLG composition
        """
        logger.info("Arc sorting L")
        L = k2.arc_sort(self.lexicon.L_disambig).to("cpu")
        G = k2.arc_sort(G).to("cpu")
        H = self.topo.to("cpu")

        # NOTE(review): the cache key only depends on the first shape entry of
        # H, L and G, so different graphs of equal size share a cache file.
        file_hash = (
            str(hash(H.shape[0]))
            + str(hash(L.shape[0]))
            + str(hash(G.shape[0]))
        )
        path = self._cache_path(cache_dir, "HLG", file_hash)

        if cache and path is not None and os.path.exists(path):
            return self._load_cached(path, "HLG")

        logger.info("Intersecting L and G")
        LG = k2.compose(L, G)

        logger.info("Connecting LG")
        LG = k2.connect(LG)

        logger.info("Determinizing LG")
        LG = k2.determinize(LG)

        logger.info("Connecting LG after k2.determinize")
        LG = k2.connect(LG)

        # The disambiguation symbols of L_disambig are only needed up to the
        # determinization step above; strip them before composing with H.
        LG = self.lexicon.remove_LG_disambig_symbols(LG)
        LG = k2.remove_epsilon(LG)
        LG = k2.connect(LG)
        # Drop the 0-valued entries from the auxiliary labels.
        LG.aux_labels = LG.aux_labels.remove_values_eq(0)

        logger.info("Arc sorting LG")
        LG = k2.arc_sort(LG)

        logger.info("Composing H and LG")
        HLG = k2.compose(H, LG, inner_labels="tokens")

        logger.info("Connecting HLG")
        HLG = k2.connect(HLG)

        logger.info("Arc sorting HLG")
        HLG = k2.arc_sort(HLG)
        logger.debug(f"HLG.shape: {HLG.shape}")

        if path is not None:
            self._save_to_cache(HLG, path, "HLG")

        return HLG
[docs]
class CtcGraphCompiler(GraphCompiler):
    """
    This class is used to compile decoding graphs for CTC training.

    Arguments
    ---------
    _lexicon: Lexicon
        It is built from `data/lang/lexicon.txt`.
    device: torch.device
        The device to use for operations compiling transcripts to FSAs.
    need_repeat_flag: bool
        If True, will add an attribute named `_is_repeat_token_` to ctc_topo
        indicating whether this token is a repeat token in ctc graph.
        This attribute is needed to implement delay-penalty for phone-based
        ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
        details. Note: The above change MUST be included in k2 to enable this
        flag so make sure you have an up-to-date version.

    Example
    -------
    >>> import torch
    >>> from speechbrain.k2_integration.losses import ctc_k2
    >>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
    >>> from speechbrain.k2_integration.lexicon import Lexicon
    >>> from speechbrain.k2_integration.prepare_lang import prepare_lang

    >>> # Create a random batch of log-probs
    >>> batch_size = 4

    >>> log_probs = torch.randn(batch_size, 100, 30)
    >>> log_probs.requires_grad = True
    >>> # Assume all utterances have the same length so no padding was needed.
    >>> input_lens = torch.ones(batch_size)
    >>> # Create a small lexicon containing only two words and write it to a file.
    >>> lang_tmpdir = getfixture('tmpdir')
    >>> lexicon_sample = "hello h e l l o\\nworld w o r l d\\n<UNK> <unk>"
    >>> lexicon_file = lang_tmpdir.join("lexicon.txt")
    >>> lexicon_file.write(lexicon_sample)
    >>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
    >>> prepare_lang(lang_tmpdir)
    >>> # Create a lexicon object
    >>> lexicon = Lexicon(lang_tmpdir)
    >>> # Create a random decoding graph
    >>> graph = CtcGraphCompiler(
    ...     lexicon,
    ...     log_probs.device,
    ... )
    >>> isinstance(graph.topo, k2.Fsa)
    True
    """

    def __init__(
        self,
        _lexicon: lexicon.Lexicon,
        device: torch.device,
        need_repeat_flag: bool = False,
    ):
        self._device = device

        # Move the lexicon to the target device and arc-sort it once, so
        # `compile` can use `L_inv` directly.
        self._lexicon = _lexicon
        self.lexicon.to(device)
        assert self.lexicon.L_inv.requires_grad is False
        self.lexicon.arc_sort()

        # Standard (non-modified) CTC topology covering every token id.
        max_token_id = max(self.lexicon.tokens)
        ctc_topo = k2.ctc_topo(max_token_id, modified=False)
        self.ctc_topo = ctc_topo.to(device)

        if need_repeat_flag:
            # An arc is flagged when its label differs from its aux_label.
            self.ctc_topo._is_repeat_token_ = (
                self.ctc_topo.labels != self.ctc_topo.aux_labels
            )

    @property
    def topo(self):
        """
        Return the ctc_topo.
        """
        return self.ctc_topo

    @property
    def lexicon(self):
        """
        Return the lexicon.
        """
        return self._lexicon

    @property
    def device(self):
        """Return the device used for compiling graphs."""
        return self._device

    def compile(
        self, texts: List[str], is_training: bool = True
    ) -> Tuple[k2.Fsa, torch.Tensor]:
        """
        Build decoding graphs by composing ctc_topo with given transcripts.

        Arguments
        ---------
        texts: List[str]
            A list of strings. Each string contains a sentence for an utterance.
            A sentence consists of spaces separated words. An example `texts`
            looks like:

                ['hello world', 'CTC training with k2']
        is_training: bool
            Indicating whether this is for training or not
            (OOV warning in training).

        Returns
        -------
        graph: k2.Fsa
            An FsaVec, the composition result of `self.ctc_topo` and the
            transcript FSA.
        target_lens: torch.Tensor
            It is a long tensor of shape (batch,). It contains lengths of
            each target sequence.
        """
        word_idx = self.lexicon.texts_to_word_ids(
            texts, log_unknown_warning=is_training
        )

        # ["test", "testa"] -> [[23, 8, 22, 23], [23, 8, 22, 23, 5]] -> [4, 5]
        word2tids = self.lexicon.texts_to_token_ids(
            texts, log_unknown_warning=is_training
        )
        # Only the number of tokens per sentence is needed, so sum the
        # per-word lengths instead of concatenating the token lists.
        target_lens = torch.tensor(
            [sum(len(tids) for tids in sentence) for sentence in word2tids],
            dtype=torch.long,
        )

        word_fsa_with_self_loops = k2.add_epsilon_self_loops(
            k2.linear_fsa(word_idx, self.device)
        )

        fsa = k2.intersect(
            self.lexicon.L_inv,
            word_fsa_with_self_loops,
            treat_epsilons_specially=False,
        )
        # fsa has word ID as labels and token ID as aux_labels, so
        # we need to invert it
        ans_fsa = fsa.invert_()
        transcript_fsa = k2.arc_sort(ans_fsa)

        # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
        # is False, so we add epsilon self-loops here
        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            transcript_fsa
        )
        fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

        graph = k2.compose(
            self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
        )

        assert graph.requires_grad is False

        return graph, target_lens