Source code for speechbrain.lm.arpa

r"""
Tools for working with ARPA format N-gram models

Expects the ARPA format to have:
- a \data\ header
- counts of ngrams in the order that they are later listed
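- empty lines between the \data\ section and each \N-grams: section (a section
  header directly following the previous section's last line is also accepted)
- \end\
E.g.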
    ```
    \data\
    ngram 1=2
    ngram 2=1

    \1-grams:
    -1.0000 Hello -0.23
    -0.6990 world -0.2553

    \2-grams:
    -0.2553 Hello world

    \end\
    ```


Example
-------
>>> # This example loads an ARPA model and queries it with BackoffNgramLM
>>> import io
>>> from speechbrain.lm.ngram import BackoffNgramLM
>>> # First we'll put an ARPA format model in TextIO and load it:
>>> with io.StringIO() as f:
...     print("Anything can be here", file=f)
...     print("", file=f)
...     print("\\data\\", file=f)
...     print("ngram 1=2", file=f)
...     print("ngram 2=3", file=f)
...     print("", file=f)  # Ends data section
...     print("\\1-grams:", file=f)
...     print("-0.6931 a", file=f)
...     print("-0.6931 b 0.", file=f)
...     print("", file=f)  # Ends unigram section
...     print("\\2-grams:", file=f)
...     print("-0.6931 a a", file=f)
...     print("-0.6931 a b", file=f)
...     print("-0.6931 b a", file=f)
...     print("", file=f)  # Ends bigram section
...     print("\\end\\", file=f)  # Ends whole file
...     _ = f.seek(0)
...     num_grams, ngrams, backoffs = read_arpa(f)
>>> # The output of read arpa is already formatted right for the query class:
>>> lm = BackoffNgramLM(ngrams, backoffs)
>>> lm.logprob("a", context = tuple())
-0.6931
>>> # Query that requires a backoff:
>>> lm.logprob("b", context = ("b",))
-0.6931

Authors
 * Aku Rouhe 2020
 * Pierre Champion 2023
"""
import collections
import logging
from pathlib import Path
from typing import Union

logger = logging.getLogger(__name__)


def read_arpa(fstream):
    r"""
    Reads an ARPA format N-gram language model from a stream

    Arguments
    ---------
    fstream : TextIO
        Text file stream (as commonly returned by open()) to read the model
        from.

    Returns
    -------
    dict
        Maps N-gram orders to the number of ngrams of that order. Essentially
        the \data\ section of an ARPA format file.
    dict
        The log probabilities (first column) in the ARPA file.
        This is a triply nested dict.
        The first layer is indexed by N-gram order (integer).
        The second layer is indexed by the context (tuple of tokens).
        The third layer is indexed by tokens, and maps to the log prob.
        This format is compatible with `speechbrain.lm.ngram.BackoffNgramLM`
        Example:
        In ARPA format, log(P(fox|a quick red)) = -5.3 is expressed:
            `-5.3 a quick red fox`
        And to access that probability, use:
            `ngrams_by_order[4][('a', 'quick', 'red')]['fox']`
    dict
        The log backoff weights (last column) in the ARPA file.
        This is a doubly nested dict.
        The first layer is indexed by N-gram order (integer).
        The second layer is indexed by the backoff history (tuple of tokens)
        i.e. the context on which the probability distribution is conditioned
        on. This maps to the log weights.
        This format is compatible with `speechbrain.lm.ngram.BackoffNgramLM`
        Example:
        If log(P(fox|a quick red)) is not listed, we find
        log(backoff(a quick red)) = -23.4 which in ARPA format is:
            `<logp> a quick red -23.4`
        And to access that here, use:
            `backoffs_by_order[3][('a', 'quick', 'red')]`

    Raises
    ------
    ValueError
        If no LM is found or the file is badly formatted, including when the
        stream ends before the \end\ marker.
    """
    # Developer's note:
    # This is a long function.
    # It is because we support cases where a new section starts suddenly without
    # an empty line in between.
    #
    # \data\ section:
    _find_data_section(fstream)
    num_ngrams = {}
    for line in fstream:
        line = line.strip()
        if line[:5] == "ngram":
            try:
                lhs, rhs = line.split("=")
                order = int(lhs.split()[1])
                num_ngrams[order] = int(rhs)
            except (IndexError, ValueError) as err:
                # Malformed count line, e.g. "ngram=3" or "ngram 1=x".
                raise ValueError("Not a properly formatted line") from err
        elif not line:  # Normal case, empty line ends section
            ended, order = _next_section_or_end(fstream)
            break  # Good, proceed to next section
        elif _starts_ngrams_section(line):  # No empty line between sections
            ended = False
            order = _parse_order(line)
            break  # Good, proceed to next section
        else:
            raise ValueError("Not a properly formatted line")
    else:
        # The stream ran out inside the \data\ section. Without this check,
        # `ended` would be unbound below.
        raise ValueError("Not a properly formatted ARPA file")
    # At this point:
    # either ended == True (\end\ came right after the \data\ section),
    # or ended == False and type(order) == int (first N-gram section).
    #
    # \N-grams: sections
    # NOTE: This is the section that most time is spent on, so it's been written
    # with processing speed in mind.
    ngrams_by_order = {}
    backoffs_by_order = {}
    while not ended:
        probs = collections.defaultdict(dict)
        backoffs = {}
        # An entry carrying a backoff weight has: logprob, `order` tokens, backoff.
        backoff_line_length = order + 2
        # Use try-except because it is faster than always checking. Any line
        # that is not an N-gram entry (empty line, section header, \end\)
        # fails float() or indexing, and that marks the end of the section.
        try:
            for line in fstream:
                line = line.strip()
                all_parts = tuple(line.split())
                prob = float(all_parts[0])
                if len(all_parts) == backoff_line_length:
                    context = all_parts[1:-2]
                    token = all_parts[-2]
                    backoff = float(all_parts[-1])
                    backoff_context = context + (token,)
                    backoffs[backoff_context] = backoff
                else:
                    context = all_parts[1:-1]
                    token = all_parts[-1]
                probs[context][token] = prob
        except (IndexError, ValueError):
            ngrams_by_order[order] = probs
            backoffs_by_order[order] = backoffs
            if not line:  # Normal case, empty line ends section
                ended, order = _next_section_or_end(fstream)
            elif _starts_ngrams_section(line):  # No empty line between sections
                ended = False
                order = _parse_order(line)
            elif _ends_arpa(line):  # No empty line before End of file
                ended = True
                order = None
            else:
                raise ValueError("Not a properly formatted ARPA file")
        else:
            # The stream was exhausted inside an N-gram section, so there is
            # no \end\ marker. Without this, the while-loop would spin forever
            # on the exhausted stream.
            raise ValueError("Not a properly formatted ARPA file")
    # Got to the \end\. Still have to check whether all promised sections were
    # delivered.
    if num_ngrams.keys() != ngrams_by_order.keys():
        raise ValueError("Not a properly formatted ARPA file")
    return num_ngrams, ngrams_by_order, backoffs_by_order


def _find_data_section(fstream):
    r"""
    Reads (lines) from the stream until the \data\ header is found.

    Raises
    ------
    ValueError
        If the stream ends without a \data\ header.
    """
    for line in fstream:
        if line.startswith("\\data\\"):
            return
    # If we get here, no data header found
    raise ValueError("Not a properly formatted ARPA file")


def _next_section_or_end(fstream):
    r"""
    Skips lines until the next \N-grams: header or the \end\ marker.

    Returns
    -------
    bool
        Whether end was found.
    int
        The order of section that starts (None if the end was found).

    Raises
    ------
    ValueError
        If the stream ends before either is found.
    """
    for line in fstream:
        line = line.strip()
        if _starts_ngrams_section(line):
            order = _parse_order(line)
            return False, order
        if _ends_arpa(line):
            return True, None
    # If we got here, it's not a properly formatted file
    raise ValueError("Not a properly formatted ARPA file")


def _starts_ngrams_section(line):
    r"""Whether ``line`` is an N-gram section header, e.g. \2-grams:"""
    return line.strip().endswith("-grams:")


def _parse_order(line):
    r"""Extracts the order from a stripped section header: \2-grams: -> 2."""
    order = int(line[1:].split("-")[0])
    return order


def _ends_arpa(line):
    r"""Whether the (already stripped) ``line`` is the \end\ marker."""
    return line == "\\end\\"
def arpa_to_fst(
    words_txt: Union[str, Path],
    in_arpa: Union[str, Path],
    out_fst: Union[str, Path],
    ngram_order: int,
    disambig_symbol: str = "#0",
    cache: bool = True,
):
    r"""
    Use kaldilm to convert an ARPA LM to FST. For example, you could use
    speechbrain.lm.train_ngram to create an ARPA LM and then use this function
    to convert it to an FST.

    It is worth noting that if the fst already exists in the output_dir,
    then they will not be converted again (so you may need to delete them
    by hand if you, at any point, change your ARPA model).

    Arguments
    ---------
    words_txt: str | Path
        path to the words.txt file created by prepare_lang.
    in_arpa: str | Path
        Path to an ARPA LM to convert to an FST.
    out_fst: str | Path
        Path to where the fst will be saved.
    ngram_order: int
        ARPA (and FST) ngram order.
    disambig_symbol: str
        the disambiguation symbol to use.
    cache: bool
        If True (default) and ``out_fst`` already exists, return without
        converting again. If False, always re-create ``out_fst``.

    Raises
    ------
    FileNotFoundError
        If ``out_fst`` has to be created but ``in_arpa`` does not exist.
    ImportError
        If ``out_fst`` has to be created but kaldilm is not installed.

    Example
    -------
    >>> from speechbrain.lm.arpa import arpa_to_fst
    >>> # Create a small arpa model
    >>> arpa_file = getfixture('tmpdir').join("bigram.arpa")
    >>> arpa_file.write(
    ...     "Anything can be here\n"
    ...     + "\n"
    ...     + "\\data\\\n"
    ...     + "ngram 1=3\n"
    ...     + "ngram 2=4\n"
    ...     + "\n"
    ...     + "\\1-grams:\n"
    ...     + "0 <s>\n"
    ...     + "-0.6931 a\n"
    ...     + "-0.6931 b 0.\n"
    ...     + "" # Ends unigram section
    ...     + "\\2-grams:\n"
    ...     + "-0.6931 <s> a\n"
    ...     + "-0.6931 a a\n"
    ...     + "-0.6931 a b\n"
    ...     + "-0.6931 b a\n"
    ...     + "\n" # Ends bigram section
    ...     + "\\end\\\n")  # Ends whole file
    >>> # Create words vocab
    >>> vocab = getfixture('tmpdir').join("words.txt")
    >>> vocab.write(
    ...     "a 1\n"
    ...     + "b 2\n"
    ...     + "<s> 3\n"
    ...     + "#0 4")  # Ends whole file
    >>> out = getfixture('tmpdir').join("bigram.txt.fst")
    >>> arpa_to_fst(vocab, arpa_file, out, 2)
    """
    # The signature accepts plain strings, but `.exists()` needs Path objects.
    in_arpa = Path(in_arpa)
    out_fst = Path(out_fst)

    # A cached FST needs neither the ARPA file nor kaldilm, so check it first.
    if cache and out_fst.exists():
        return
    if not in_arpa.exists():
        raise FileNotFoundError(
            f"{in_arpa} not found while trying to create"
            f" the {ngram_order} FST."
        )

    try:
        from kaldilm.arpa2fst import arpa2fst
    except ImportError as err:
        # kaldilm is an optional dependency, needed only to build the FST.
        raise ImportError(
            "Optional dependencies must be installed to use kaldilm.\n"
            "Install using `pip install kaldilm`."
        ) from err

    try:
        logger.info(f"Converting arpa LM '{in_arpa}' to FST")
        s = arpa2fst(
            input_arpa=str(in_arpa),
            disambig_symbol=disambig_symbol,
            read_symbol_table=str(words_txt),
            max_order=ngram_order,
        )
    except Exception:
        logger.error(
            f"Failed to create {ngram_order}-gram FST from input={in_arpa}"
            f", disambig_symbol={disambig_symbol},"
            f" read_symbol_table={words_txt}"
        )
        raise
    logger.info(f"Writing {out_fst}")
    # Explicit encoding: the FST text contains vocabulary symbols, which may
    # be non-ASCII, and the platform default encoding is not always UTF-8.
    with open(out_fst, "w", encoding="utf-8") as f:
        f.write(s)