# Source code for speechbrain.integrations.decoders.kenlm_scorer

"""Language model wrapper for kenlm n-gram.

This file is based on the implementation of the kenLM wrapper from
PyCTCDecode (see: https://github.com/kensho-technologies/pyctcdecode) and
is used in CTC decoders.

See: speechbrain.decoders.ctc

Authors
 * Adel Moumen 2023
 * Peter Plantinga 2024
"""

import math
from typing import Collection, Optional, Set, Tuple, cast

from pygtrie import CharTrie

from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)

# kenlm is an optional dependency: fail at import time with installation
# instructions, keeping the original ImportError as the explicit cause.
try:
    import kenlm
except ImportError as err:
    raise ImportError(
        "kenlm python bindings are not installed. To install it use: "
        "pip install https://github.com/kpu/kenlm/archive/master.zip"
    ) from err


def LanguageModel(*args, **kwargs):
    """Deprecated alias of ``KenlmScorer``.

    Emits a deprecation notice through :func:`warnings.warn` (attributed to
    the caller) and forwards all arguments unchanged to ``KenlmScorer``.
    This can be removed once deprecation is complete.

    Arguments
    ---------
    *args : tuple
        Positional arguments forwarded to ``KenlmScorer``.
    **kwargs : dict
        Keyword arguments forwarded to ``KenlmScorer``.

    Returns
    -------
    KenlmScorer
        The scorer built from the forwarded arguments.
    """
    from warnings import warn

    # Name the module that actually defines this alias, and the real
    # replacement class, so users can follow the advice verbatim.
    # stacklevel=2 points the warning at the caller, not at this wrapper.
    warn(
        f"The class name {__name__}.LanguageModel is deprecated. "
        "Please use the updated name KenlmScorer",
        stacklevel=2,
    )
    return KenlmScorer(*args, **kwargs)
def load_unigram_set_from_arpa(arpa_path: str) -> Set[str]:
    r"""Read unigrams from arpa file.

    Every non-empty line of the ``\1-grams:`` section is parsed as
    ``log10_prob word [backoff_weight]``. The backoff weight is optional in
    the ARPA format (for instance a unigram-only model has none), so lines
    with either two or three whitespace-separated fields are accepted.
    Reading stops at the ``\2-grams:`` header.

    Adapted from: https://github.com/kensho-technologies/pyctcdecode

    Arguments
    ---------
    arpa_path : str
        Path to arpa file.

    Returns
    -------
    unigrams : set
        Set of unigrams.

    Raises
    ------
    ValueError
        If no unigram could be read from the file.

    Example
    -------
    >>> arpa_file = getfixture("tmpdir").join("bigram.arpa")
    >>> arpa_file.write(
    ...     "Anything can be here\n"
    ...     + "\n"
    ...     + "\\data\\\n"
    ...     + "ngram 1=3\n"
    ...     + "ngram 2=4\n"
    ...     + "\n"
    ...     + "\\1-grams:\n"
    ...     + "0 <s>\n"
    ...     + "-0.6931 a 0.\n"
    ...     + "-0.6931 b 0.\n"
    ...     + ""  # Ends unigram section
    ...     + "\\2-grams:\n"
    ...     + "-0.6931 <s> a\n"
    ...     + "-0.6931 a a\n"
    ...     + "-0.6931 a b\n"
    ...     + "-0.6931 b a\n"
    ...     + "\n"  # Ends bigram section
    ...     + "\\end\\\n"
    ... )  # Ends whole file
    >>> sorted(load_unigram_set_from_arpa(arpa_file))
    ['<s>', 'a', 'b']
    """
    unigrams = set()
    with open(arpa_path, encoding="utf-8") as f:
        in_unigram_section = False
        for line in f:
            line = line.strip()
            if line == "\\1-grams:":
                in_unigram_section = True
                continue
            if line == "\\2-grams:":
                # Everything after this header is a higher-order n-gram.
                break
            if not in_unigram_section or not line:
                continue
            parts = line.split()
            # "log10_prob word" or "log10_prob word backoff_weight"
            if len(parts) in (2, 3):
                unigrams.add(parts[1])
    if not unigrams:
        raise ValueError(
            "No unigrams found in arpa file. Something is wrong with the file."
        )
    return unigrams
class KenlmState:
    """Read-only holder around a raw kenlm state object.

    Scoring code passes instances of this class around instead of the
    mutable ``kenlm.State`` itself; the wrapped object is exposed through
    the read-only :attr:`state` property (there is no setter).

    Taken from: https://github.com/kensho-technologies/pyctcdecode

    Arguments
    ---------
    state : kenlm.State
        Kenlm state object to wrap.
    """

    def __init__(self, state: "kenlm.State") -> None:
        self._state = state

    @property
    def state(self) -> "kenlm.State":
        """The wrapped kenlm state, exactly as given to the constructor."""
        return self._state
def _prepare_unigram_set( unigrams: Collection[str], kenlm_model: "kenlm.Model" ) -> Set[str]: """Filter unigrams down to vocabulary that exists in kenlm_model. Taken from: https://github.com/kensho-technologies/pyctcdecode Arguments --------- unigrams : list List of unigrams. kenlm_model : kenlm.Model Kenlm model. Returns ------- unigram_set : set Set of unigrams. """ if len(unigrams) < 1000: logger.warning( "Only %s unigrams passed as vocabulary. Is this small or artificial data?", len(unigrams), ) unigram_set = set(unigrams) unigram_set = set([t for t in unigram_set if t in kenlm_model]) retained_fraction = ( 1.0 if len(unigrams) == 0 else len(unigram_set) / len(unigrams) ) if retained_fraction < 0.1: logger.warning( "Only %s%% of unigrams in vocabulary found in kenlm model-- this might mean that your " "vocabulary and language model are incompatible. Is this intentional?", round(retained_fraction * 100, 1), ) return unigram_set def _get_empty_lm_state() -> "kenlm.State": """Get uninitialized kenlm state. Taken from: https://github.com/kensho-technologies/pyctcdecode Returns ------- kenlm_state : kenlm.State Empty kenlm state. """ try: kenlm_state = kenlm.State() except ImportError: raise ValueError("To use a language model, you need to install kenlm.") return kenlm_state
class KenlmScorer:
    r"""KenLM language model container class to consolidate functionality.

    This class is a wrapper around the KenLM language model. It provides
    functionality to score tokens and to get the initial state.

    Taken from: https://github.com/kensho-technologies/pyctcdecode

    Arguments
    ---------
    kenlm_model : kenlm.Model
        Kenlm model.
    unigrams : list
        List of known word unigrams.
    alpha : float
        Weight for language model during shallow fusion.
    beta : float
        Weight for the length score adjustment during scoring.
    unk_score_offset : float
        Amount of log score offset for unknown tokens.
    score_boundary : bool
        Whether to have kenlm respect boundaries when scoring.

    Example
    -------
    >>> arpa_file = getfixture("tmpdir").join("bigram_hello.arpa")
    >>> arpa_file.write(
    ...     "\\data\\\n"
    ...     + "ngram 1=4\n"
    ...     + "ngram 2=1\n\n"
    ...     + "\\1-grams:\n"
    ...     + "-1.0\t<s>\t-1.0\n"
    ...     + "-1.0\t</s>\t-1.0\n"
    ...     + "-1.0\tHello\t-0.23\n"
    ...     + "-0.7\tworld\t-0.25\n\n"
    ...     + "\\2-grams:\n"
    ...     + "-0.3\tHello world\n\n"
    ...     + "\\end\\"
    ... )
    >>> model = kenlm.Model(str(arpa_file))
    >>> scorer = KenlmScorer(kenlm_model=model, unigrams=["Hello", "world"])
    >>> state = scorer.get_start_state()
    >>> score, new_state = scorer.score(state, "Hello")
    >>> round(score, 3)
    -0.803
    """

    def __init__(
        self,
        kenlm_model: "kenlm.Model",
        unigrams: Optional[Collection[str]] = None,
        alpha: float = 0.5,
        beta: float = 1.5,
        unk_score_offset: float = -10.0,
        score_boundary: bool = True,
    ) -> None:
        self._kenlm_model = kenlm_model
        if unigrams is None:
            logger.warning(
                "No known unigrams provided, decoding results might be a lot worse."
            )
            unigram_set = set()
            char_trie = None
        else:
            # Keep only unigrams the kenlm model knows; the trie enables
            # prefix lookups for partial (still being spelled) tokens.
            unigram_set = _prepare_unigram_set(unigrams, self._kenlm_model)
            char_trie = CharTrie.fromkeys(unigram_set)
        self._unigram_set = unigram_set
        self._char_trie = char_trie
        self.alpha = alpha
        self.beta = beta
        self.unk_score_offset = unk_score_offset
        self.score_boundary = score_boundary

    @property
    def order(self) -> int:
        """Get the order of the n-gram language model."""
        return cast(int, self._kenlm_model.order)

    def get_start_state(self) -> "KenlmState":
        """Get initial lm state.

        Returns
        -------
        KenlmState
            A sentence-begin context when ``score_boundary`` is True,
            otherwise a null (no-context) state.
        """
        start_state = _get_empty_lm_state()
        if self.score_boundary:
            self._kenlm_model.BeginSentenceWrite(start_state)
        else:
            self._kenlm_model.NullContextWrite(start_state)
        return KenlmState(start_state)

    def _get_raw_end_score(self, start_state: "kenlm.State") -> float:
        """Calculate final lm score.

        Returns the raw kenlm score of ``</s>`` following ``start_state``
        when ``score_boundary`` is True, and 0.0 otherwise. The temporary
        end state is discarded.
        """
        if self.score_boundary:
            end_state = _get_empty_lm_state()
            score: float = self._kenlm_model.BaseScore(
                start_state, "</s>", end_state
            )
        else:
            score = 0.0
        return score

    def score_partial_token(self, partial_token: str) -> float:
        """Get partial token score.

        Arguments
        ---------
        partial_token : str
            Characters of a word that is still being decoded.

        Returns
        -------
        float
            0.0 if the character trie has a node for ``partial_token``
            (it is a known unigram or a prefix of one); otherwise
            ``unk_score_offset``, scaled by ``len(partial_token) / 6`` for
            tokens longer than 6 characters. Without a trie (no unigrams
            given) every partial token is treated as out-of-vocabulary.
        """
        if self._char_trie is None:
            is_oov = 1.0
        else:
            is_oov = int(self._char_trie.has_node(partial_token) == 0)
        unk_score = self.unk_score_offset * is_oov
        # if unk token length exceeds expected length then additionally decrease score
        if len(partial_token) > 6:
            unk_score = unk_score * len(partial_token) / 6
        return unk_score

    def score(
        self, prev_state, word: str, is_last_word: bool = False
    ) -> Tuple[float, "KenlmState"]:
        """Score word conditional on start state.

        Arguments
        ---------
        prev_state : KenlmState
            State to extend with ``word``.
        word : str
            Word to score.
        is_last_word : bool
            If True, also add the end-of-sentence score (only non-zero when
            ``score_boundary`` is True).

        Returns
        -------
        score : float
            ``alpha * s + beta`` where ``s`` is the accumulated kenlm log10
            score (including any unknown-word offset and end score)
            converted to natural log.
        new_state : KenlmState
            State after ``word``, without the end-of-sentence transition.

        Raises
        ------
        AssertionError
            If ``prev_state`` is not a ``KenlmState``.
        """
        if not isinstance(prev_state, KenlmState):
            raise AssertionError(
                f"Wrong input state type found. Expected KenlmState, got {type(prev_state)}"
            )
        end_state = _get_empty_lm_state()
        lm_score = self._kenlm_model.BaseScore(
            prev_state.state, word, end_state
        )
        # Override UNK prob. ``_unigram_set`` only holds words the kenlm
        # model contains, so when it is non-empty it alone decides (a set
        # lookup is faster than querying kenlm); otherwise fall back to the
        # model's own vocabulary.
        if self._unigram_set:
            is_unknown = word not in self._unigram_set
        else:
            is_unknown = word not in self._kenlm_model
        if is_unknown:
            lm_score += self.unk_score_offset
        # add end of sentence context if needed
        if is_last_word:
            # note that we want to return the unmodified end_state to keep extension capabilities
            lm_score = lm_score + self._get_raw_end_score(end_state)
        # Dividing by log10(e) converts the log10 score to natural log.
        lm_score = self.alpha * lm_score / math.log10(math.e) + self.beta
        return lm_score, KenlmState(end_state)