Source code for speechbrain.k2_integration.utils

"""Utilities for k2 integration with SpeechBrain.

This code was adjusted from icefall (https://github.com/k2-fsa/icefall).


Authors:
  * Pierre Champion 2023
  * Zeyu Zhao 2023
  * Georgios Karakasidis 2023
"""

import os
import logging
from pathlib import Path
from typing import List, Union
import torch

from . import k2  # import k2 from ./__init__.py

logger = logging.getLogger(__name__)


def lattice_path_to_textid(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
    """
    Extract the texts (as word IDs) from the best-path FSAs.

    Arguments
    ---------
    best_paths: k2.Fsa
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
    return_ragged: bool
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.

    Returns
    -------
    The decoded label sequences: a k2.RaggedTensor with two axes when
    ``return_ragged`` is True, otherwise a list of lists of int.
    """
    labels = best_paths.aux_labels

    if isinstance(labels, k2.RaggedTensor):
        # Drop 0's and -1's before composing the shapes.
        filtered = labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        composed = best_paths.arcs.shape().compose(filtered.shape)
        # Remove the states axis, then the arcs axis (which has moved
        # to index 1 after the first removal).
        utt_shape = composed.remove_axis(1).remove_axis(1)
        word_ids = k2.RaggedTensor(utt_shape, filtered.values)
    else:
        # Remove the axis corresponding to states.
        utt_shape = best_paths.arcs.shape().remove_axis(1)
        # Drop 0's and -1's once the labels are wrapped as ragged.
        word_ids = k2.RaggedTensor(utt_shape, labels).remove_values_leq(0)

    assert word_ids.num_axes == 2

    return word_ids if return_ragged else word_ids.tolist()
def lattice_paths_to_text(best_paths: k2.Fsa, word_table) -> List[str]:
    """
    Convert the best path to a list of strings.

    Arguments
    ---------
    best_paths: k2.Fsa
        It is the path in the lattice with the highest score for a
        given utterance.
    word_table: List[str] or Dict[int,str]
        It is a list or dict that maps word IDs to words.

    Returns
    -------
    texts: List[str]
        A list of strings, each of which is the decoding result of the
        corresponding utterance (words joined by a single space).
    """
    word_id_seqs: List[List[int]] = lattice_path_to_textid(
        best_paths, return_ragged=False
    )
    # One space-separated string per utterance, in the same order.
    return [
        " ".join(word_table[word_id] for word_id in seq)
        for seq in word_id_seqs
    ]
def load_G(path: Union[str, Path], cache: bool = True) -> k2.Fsa:
    """
    Load a LM to be used in the decoding graph creation (or LM rescoring).

    Arguments
    ---------
    path: str
        The path to an FST LM (ending with .fst.txt) or a k2-converted
        LM (in pytorch .pt format). A path ending with ``.pt`` is always
        loaded directly with ``torch.load``, regardless of ``cache``.
    cache: bool
        Whether or not to load/cache the LM from/to the .pt format
        (in the same dir). When False, an existing cached .pt file is
        ignored and no .pt file is written.

    Returns
    -------
    G: k2.Fsa
        An FSA representing the LM.

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing file (and no cached .pt version
        is being used instead).
    """
    path = str(path)
    fst_suffix = ".fst.txt"
    not_found_msg = (
        f"File {path} not found. " "You need to run arpa_to_fst to get it."
    )

    # Case 1: the caller already points at a k2-converted LM. It must never
    # be parsed as OpenFst text, whatever the value of `cache`.
    if path.endswith(".pt"):
        logger.info(f"Loading G LM: {path}")
        if not os.path.isfile(path):
            raise FileNotFoundError(not_found_msg)
        return k2.Fsa.from_dict(torch.load(path, map_location="cpu"))

    # Case 2: a text FST. Derive the cache location from the *suffix* only,
    # so the path used for reading the cache and the one used for writing it
    # always agree (and other occurrences of ".fst.txt" in the path, e.g. in
    # a directory name, are left alone).
    if path.endswith(fst_suffix):
        pt_path = path[: -len(fst_suffix)] + ".pt"
    else:
        pt_path = path + ".pt"

    if cache and os.path.exists(pt_path):
        logger.warning(
            f"Loading '{path}' from its cached .pt format."
            " Set 'caching: False' in the yaml"
            " if this is not what you want."
        )
        return k2.Fsa.from_dict(torch.load(pt_path, map_location="cpu"))

    logger.info(f"Loading G LM: {path}")
    if not os.path.isfile(path):
        raise FileNotFoundError(not_found_msg)
    with open(path) as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
    # Only write the converted LM when caching was requested.
    if cache:
        torch.save(G.as_dict(), pt_path)
    return G
def prepare_rescoring_G(G: k2.Fsa) -> k2.Fsa:
    """
    Prepare a LM with the purpose of using it for LM rescoring.
    For instance, in the librispeech recipe this is a 4-gram LM (while a
    3gram LM is used for HLG construction).

    Note that the FSA passed in is modified in place (its cached
    ``_properties`` entry is reset and its ``aux_labels`` are deleted)
    before a new FSA is built from it.

    Arguments
    ---------
    G: k2.Fsa
        An FSA representing the LM.

    Returns
    -------
    G: k2.Fsa
        An FSA representing the LM, with the following modifications:
        - G.aux_labels is removed
        - G.lm_scores is set to G.scores
        - G is arc-sorted
    """
    # Reset the cached properties entry, if there is one.
    # NOTE(review): presumably this forces k2 to recompute the properties
    # after the attribute removal below - confirm against k2 docs.
    if "_properties" in G.__dict__:
        G.__dict__["_properties"] = None
    # An FSA without aux_labels (e.g. an acceptor) has nothing to delete;
    # an unconditional `del` would raise AttributeError for it.
    if hasattr(G, "aux_labels"):
        del G.aux_labels
    G = k2.Fsa.from_fsas([G]).to("cpu")  # only used for decoding
    G = k2.arc_sort(G)
    G = k2.add_epsilon_self_loops(G)
    # Sort again after the epsilon self-loops have been added.
    G = k2.arc_sort(G)
    # G.lm_scores is used to replace HLG.lm_scores during LM rescoring.
    if not hasattr(G, "lm_scores"):
        G.lm_scores = G.scores.clone()
    return G