"""Different decoding graph algorithms for k2, be it HL or HLG (with G LM
and bigger rescoring LM).
This code was adjusted from icefall (https://github.com/k2-fsa/icefall/blob/master/icefall/decode.py).
Authors:
* Pierre Champion 2023
* Zeyu Zhao 2023
* Georgios Karakasidis 2023
"""
from pathlib import Path
from typing import Dict, List, Optional, Union
from collections import OrderedDict
from . import k2 # import k2 from ./__init__.py
from speechbrain.utils.distributed import run_on_main
from speechbrain.lm.arpa import arpa_to_fst
import torch
import logging
from . import graph_compiler, utils
logger = logging.getLogger(__name__)
def get_decoding(
hparams: Dict, graphCompiler: graph_compiler.GraphCompiler, device="cpu"
):
"""
This function reads a config and creates the decoder for k2 graph
decoding.
There are the following cases:
- HLG is compiled and LM rescoring is used. In that case,
compose_HL_with_G and use_G_rescoring are both True and we will
create, for example, G_3_gram.fst.txt and G_4_gram.fst.txt. Note that
the 3-gram and 4-gram ARPA LMs need to exist under
`hparams['lm_dir']`.
- HLG is compiled but LM rescoring is not used. In that case,
compose_HL_with_G is True and use_G_rescoring is False and we will
create, for example, G_3_gram.fst.txt. Note that the 3-gram ARPA LM
needs to exist under `hparams['lm_dir']`.
- HLG is not compiled (only the HL graph is used) and LM rescoring is
used. In that case, compose_HL_with_G is False and use_G_rescoring is
True. Note that the 4-gram ARPA LM needs to exist under
`hparams['lm_dir']`.
- HLG is not compiled (only the HL graph is used) and LM rescoring is
not used. In that case, compose_HL_with_G is False and use_G_rescoring
is False and no LM is converted to FST.
Arguments
---------
hparams: dict
The hyperparameters.
graphCompiler: graph_compiler.GraphCompiler
The graph compiler (H).
device : torch.device
The device to use.
Returns
-------
Dict:
decoding_graph: k2.Fsa
An HL or HLG decoding graph.
Used with a nnet output and the function `get_lattice` to
obtain a decoding lattice `k2.Fsa`.
decoding_method: Callable[[k2.Fsa], k2.Fsa]
A function to call with a decoding lattice `k2.Fsa` (obtained
after the nnet output is intersected with the HL or HLG graph).
Returns an FsaVec containing linear FSAs.
Example
-------
>>> import torch
>>> from speechbrain.k2_integration.losses import ctc_k2
>>> from speechbrain.k2_integration.utils import lattice_paths_to_text
>>> from speechbrain.k2_integration.graph_compiler import CtcGraphCompiler
>>> from speechbrain.k2_integration.lexicon import Lexicon
>>> from speechbrain.k2_integration.prepare_lang import prepare_lang
>>> from speechbrain.k2_integration.lattice_decoder import get_decoding
>>> from speechbrain.k2_integration.lattice_decoder import get_lattice
>>> batch_size = 1
>>> log_probs = torch.randn(batch_size, 40, 10)
>>> log_probs.requires_grad = True
>>> # Assume all utterances have the same length so no padding was needed.
>>> input_lens = torch.ones(batch_size)
>>> # Create a small lexicon containing only two words and write it to a file.
>>> lang_tmpdir = getfixture('tmpdir')
>>> lexicon_sample = "hello h e l l o\\nworld w o r l d\\n<UNK> <unk>"
>>> lexicon_file = lang_tmpdir.join("lexicon.txt")
>>> lexicon_file.write(lexicon_sample)
>>> # Create a lang directory with the lexicon and L.pt, L_inv.pt, L_disambig.pt
>>> prepare_lang(lang_tmpdir)
>>> # Create a lexicon object
>>> lexicon = Lexicon(lang_tmpdir)
>>> # Create a random decoding graph
>>> graph = CtcGraphCompiler(
... lexicon,
... log_probs.device,
... )
>>> decode = get_decoding(
... {"compose_HL_with_G": False,
... "decoding_method": "onebest",
... "lang_dir": lang_tmpdir},
... graph)
>>> lattice = get_lattice(log_probs, input_lens, decode["decoding_graph"])
>>> path = decode["decoding_method"](lattice)['1best']
>>> text = lattice_paths_to_text(path, lexicon.word_table)
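For HLG decoding with whole-lattice rescoring, the call would look roughly
as follows (a sketch, not a doctest; the ARPA file names are illustrative
and the LMs are assumed to exist under ``lm_dir``)::

    decode = get_decoding(
        {"compose_HL_with_G": True,
         "decoding_method": "whole-lattice-rescoring",
         "lang_dir": lang_tmpdir,
         "lm_dir": "lm",
         "G_arpa": "G_3_gram.arpa",
         "G_rescoring_arpa": "G_4_gram.arpa",
         "rescoring_lm_scale": [0.4, 0.6]},
        graph)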
"""
compose_HL_with_G = hparams.get("compose_HL_with_G")
use_G_rescoring = (
hparams.get("decoding_method") == "whole-lattice-rescoring"
)
caching = hparams.get("caching", True) is not False
if compose_HL_with_G or use_G_rescoring:
lm_dir = Path(hparams["lm_dir"])
G_path = lm_dir / (hparams["G_arpa"].replace("arpa", "fst.txt"))
G_rescoring_path = (
lm_dir / (hparams["G_rescoring_arpa"].replace("arpa", "fst.txt"))
if use_G_rescoring
else None
)
if compose_HL_with_G:
run_on_main(
arpa_to_fst,
kwargs={
"words_txt": Path(hparams["lang_dir"]) / "words.txt",
"in_arpa": lm_dir / hparams["G_arpa"],
"out_fst": G_path,
"ngram_order": 3, # by default use 3-gram for HLG's LM
"cache": caching,
},
)
if use_G_rescoring:
run_on_main(
arpa_to_fst,
kwargs={
"words_txt": Path(hparams["lang_dir"]) / "words.txt",
"in_arpa": lm_dir / hparams["G_rescoring_arpa"],
"out_fst": G_rescoring_path,
"ngram_order": 4, # by default use 4-gram for rescoring LM
"cache": caching,
},
)
output_folder = None
if "output_folder" in hparams:
output_folder = hparams["output_folder"]
if compose_HL_with_G:
G = utils.load_G(G_path, cache=caching)
decoding_graph = graphCompiler.compile_HLG(
G, cache_dir=output_folder, cache=caching
)
else:
decoding_graph = graphCompiler.compile_HL(
cache_dir=output_folder, cache=caching
)
if hparams.get("decoding_method") == "whole-lattice-rescoring":
G_rescoring = None
if not isinstance(hparams["rescoring_lm_scale"], list):
hparams["rescoring_lm_scale"] = [hparams["rescoring_lm_scale"]]
def decoding_method(lattice: k2.Fsa) -> Dict[str, k2.Fsa]:
"""Get the best path from a lattice given rescoring_lm_scale."""
# Lazy load rescoring G (takes a lot of time) for developer happiness
nonlocal G_rescoring
if G_rescoring is None:
logger.info("Decoding method: whole-lattice-rescoring")
logger.info(f"Loading rescoring LM: {G_rescoring_path}")
G_rescoring_pt = utils.load_G(G_rescoring_path, cache=caching)
graphCompiler.lexicon.remove_G_rescoring_disambig_symbols(
G_rescoring_pt
)
G_rescoring = utils.prepare_rescoring_G(G_rescoring_pt)
# rescore_with_whole_lattice returns a list of paths depending on
# lm_scale values.
return rescore_with_whole_lattice(
lattice,
G_rescoring,
lm_scale_list=hparams["rescoring_lm_scale"],
)
elif hparams.get("decoding_method") in ["1best", "onebest"]:
logger.info("Decoding method: one-best-decoding")
def decoding_method(lattice: k2.Fsa) -> Dict[str, k2.Fsa]:
"""Get the best path from a lattice."""
return OrderedDict({"1best": one_best_decoding(lattice)})
else:
def decoding_method(lattice: k2.Fsa):
"""A dummy decoding method that raises an error."""
raise NotImplementedError(
f"{hparams.get('decoding_method')} not implemented as a decoding_method"
)
return {
"decoding_graph": decoding_graph.to(device),
"decoding_method": decoding_method,
}
@torch.no_grad()
def get_lattice(
log_probs_nnet_output: torch.Tensor,
input_lens: torch.Tensor,
decoder: k2.Fsa,
search_beam: int = 5,
output_beam: int = 5,
min_active_states: int = 300,
max_active_states: int = 1000,
ac_scale: float = 1.0,
subsampling_factor: int = 1,
) -> k2.Fsa:
"""
Get the decoding lattice from a decoding graph and neural network output.
Arguments
---------
log_probs_nnet_output:
It is the output of a neural model of shape `(batch, seq_len, num_tokens)`.
input_lens:
It is an int tensor of shape (batch,). It contains lengths of
each sequence in `log_probs_nnet_output`.
decoder:
It is an instance of :class:`k2.Fsa` that represents the decoding graph.
search_beam:
Decoding beam, e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be modified by
`min_active_states` and `max_active_states`.
output_beam:
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
min_active_states:
Minimum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to have fewer than this number active.
Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
ac_scale:
acoustic scale applied to `log_probs_nnet_output`
subsampling_factor:
The subsampling factor of the model.
Returns
-------
lattice:
An FsaVec containing the decoding result. It has axes [utt][state][arc].
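Example
-------
A minimal usage sketch (not a doctest), assuming ``decode`` was obtained
from :func:`get_decoding` as in that function's example::

    lattice = get_lattice(log_probs, input_lens, decode["decoding_graph"])
    best_paths = decode["decoding_method"](lattice)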
"""
device = log_probs_nnet_output.device
input_lens = input_lens.to(device)
if decoder.device != device:
logger.warning(
"Decoding graph (HL or HLG) is not on the same device as the"
" nnet output; this will slow down decoding."
)
decoder = decoder.to(device)
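# Convert SpeechBrain's relative input lengths into absolute frame counts.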
input_lens = (input_lens * log_probs_nnet_output.shape[1]).round().int()
# NOTE: low ac_scale values may result in very big lattices and OOM errors.
log_probs_nnet_output *= ac_scale
lattice = k2.get_lattice(
log_probs_nnet_output,
input_lens,
decoder,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states,
subsampling_factor=subsampling_factor,
)
return lattice
@torch.no_grad()
def one_best_decoding(
lattice: k2.Fsa, use_double_scores: bool = True,
) -> k2.Fsa:
"""
Get the best path from a lattice.
Arguments
---------
lattice:
The decoding lattice returned by :func:`get_lattice`.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
Returns
-------
best_path:
An FsaVec containing linear paths.
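Example
-------
A minimal usage sketch (not a doctest); ``word_table`` would typically
come from a ``Lexicon``, as in the example of :func:`get_decoding`::

    best_paths = one_best_decoding(lattice)
    texts = lattice_paths_to_text(best_paths, word_table)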
"""
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
return best_path
@torch.no_grad()
def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
use_double_scores: bool = True,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
"""
Intersect the lattice with an n-gram LM and use shortest path to decode.
The input lattice is obtained by intersecting `HLG` with
a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM.
The input `G_with_epsilon_loops` is usually a 4-gram LM. You can think of
this function as a second decoding pass: the first pass uses a small G,
while the second pass rescores with a larger G.
Arguments
---------
lattice: k2.Fsa
An FsaVec with axes [utt][state][arc]. Its `aux_labels` are word IDs.
It must have an attribute `lm_scores`.
G_with_epsilon_loops: k2.Fsa
An FsaVec containing only a single FSA. It contains epsilon self-loops.
It is an acceptor and its labels are word IDs.
lm_scale_list: Optional[List[float]]
If None, return the intersection of `lattice` and `G_with_epsilon_loops`.
If not None, it contains a list of values to scale LM scores.
For each scale, there is a corresponding decoding result contained in
the resulting dict.
use_double_scores: bool
True to use double precision in the computation.
False to use single precision.
Returns
-------
If `lm_scale_list` is None, return a new lattice which is the intersection
result of `lattice` and `G_with_epsilon_loops`.
Otherwise, return a dict whose key is an entry in `lm_scale_list` and the
value is the decoding result (i.e., an FsaVec containing linear FSAs).
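Example
-------
A minimal usage sketch (not a doctest), assuming ``G_4_gram`` has been
prepared with ``utils.load_G`` and ``utils.prepare_rescoring_G`` (as done
in :func:`get_decoding`)::

    paths = rescore_with_whole_lattice(
        lattice, G_4_gram, lm_scale_list=[0.4, 0.6]
    )
    best_path = paths["whole_lattice_rescore_lm_scale_0.4"]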
"""
assert G_with_epsilon_loops.shape == (1, None, None)
G_with_epsilon_loops = G_with_epsilon_loops.to(lattice.device)
device = lattice.device
if hasattr(lattice, "lm_scores"):
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lats.lm_scores here
del lattice.lm_scores
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
# Its `aux_labels` is token IDs
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
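# Every sequence in inv_lattice is intersected with the single FSA in
# G_with_epsilon_loops (index 0), hence the all-zero b_to_a_map.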
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
# NOTE: The choice of the threshold list is arbitrary here to avoid OOM.
# You may need to fine-tune it.
prune_th_list = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6]
prune_th_list += [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
try:
if device == "cpu":
rescoring_lattice = k2.intersect(
G_with_epsilon_loops,
inv_lattice,
treat_epsilons_specially=True,
)
else:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
inv_lattice,
b_to_a_map,
sorted_match_a=True,
)
rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
break
except RuntimeError as e:
logger.info(f"Caught exception:\n{e}\n")
if loop_count >= max_loop_count:
logger.info(
"Return None as the resulting lattice is too large."
)
return None
logger.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
logger.info(
"This OOM is not an error. You can ignore it. "
"If your model does not converge well, or the segment length "
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
inv_lattice = k2.prune_on_arc_post(
inv_lattice, prune_th_list[loop_count], True,
)
logger.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
loop_count += 1
# lat has token IDs as labels
# and word IDs as aux_labels.
lat = k2.invert(rescoring_lattice)
if lm_scale_list is None:
return lat
ans = OrderedDict()
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores
best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"whole_lattice_rescore_lm_scale_{lm_scale:.1f}"
ans[key] = best_path
return ans