Source code for speechbrain.utils.edit_distance

"""Edit distance and WER computation.

Authors
 * Aku Rouhe 2020
 * Salima Mdhaffar 2021
"""

import collections

EDIT_SYMBOLS = {
    "eq": "=",  # when tokens are equal
    "ins": "I",
    "del": "D",
    "sub": "S",
}


# NOTE: There is a danger in using mutables as default arguments, as they are
# only initialized once, and not every time the function is run. However,
# here the default is not actually ever mutated,
# and simply serves as an empty Counter.
[docs] def accumulatable_wer_stats(refs, hyps, stats=collections.Counter()): """Computes word error rate and the related counts for a batch. Can also be used to accumulate the counts over many batches, by passing the output back to the function in the call for the next batch. Arguments ---------- ref : iterable Batch of reference sequences. hyp : iterable Batch of hypothesis sequences. stats : collections.Counter The running statistics. Pass the output of this function back as this parameter to accumulate the counts. It may be cleanest to initialize the stats yourself; then an empty collections.Counter() should be used. Returns ------- collections.Counter The updated running statistics, with keys: * "WER" - word error rate * "insertions" - number of insertions * "deletions" - number of deletions * "substitutions" - number of substitutions * "num_ref_tokens" - number of reference tokens Example ------- >>> import collections >>> batches = [[[[1,2,3],[4,5,6]], [[1,2,4],[5,6]]], ... [[[7,8], [9]], [[7,8], [10]]]] >>> stats = collections.Counter() >>> for batch in batches: ... refs, hyps = batch ... stats = accumulatable_wer_stats(refs, hyps, stats) >>> print("%WER {WER:.2f}, {num_ref_tokens} ref tokens".format(**stats)) %WER 33.33, 9 ref tokens """ updated_stats = stats + _batch_stats(refs, hyps) if updated_stats["num_ref_tokens"] == 0: updated_stats["WER"] = float("nan") else: num_edits = sum( [ updated_stats["insertions"], updated_stats["deletions"], updated_stats["substitutions"], ] ) updated_stats["WER"] = ( 100.0 * num_edits / updated_stats["num_ref_tokens"] ) return updated_stats
def _batch_stats(refs, hyps): """Internal function which actually computes the counts. Used by accumulatable_wer_stats Arguments ---------- ref : iterable Batch of reference sequences. hyp : iterable Batch of hypothesis sequences. Returns ------- collections.Counter Edit statistics over the batch, with keys: * "insertions" - number of insertions * "deletions" - number of deletions * "substitutions" - number of substitutions * "num_ref_tokens" - number of reference tokens Example ------- >>> from speechbrain.utils.edit_distance import _batch_stats >>> batch = [[[1,2,3],[4,5,6]], [[1,2,4],[5,6]]] >>> refs, hyps = batch >>> print(_batch_stats(refs, hyps)) Counter({'num_ref_tokens': 6, 'substitutions': 1, 'deletions': 1}) """ if len(refs) != len(hyps): raise ValueError( "The reference and hypothesis batches are not of the same size" ) stats = collections.Counter() for ref_tokens, hyp_tokens in zip(refs, hyps): table = op_table(ref_tokens, hyp_tokens) edits = count_ops(table) stats += edits stats["num_ref_tokens"] += len(ref_tokens) return stats
[docs] def op_table(a, b): """Table of edit operations between a and b. Solves for the table of edit operations, which is mainly used to compute word error rate. The table is of size ``[|a|+1, |b|+1]``, and each point ``(i, j)`` in the table has an edit operation. The edit operations can be deterministically followed backwards to find the shortest edit path to from ``a[:i-1] to b[:j-1]``. Indexes of zero (``i=0`` or ``j=0``) correspond to an empty sequence. The algorithm itself is well known, see `Levenshtein distance <https://en.wikipedia.org/wiki/Levenshtein_distance>`_ Note that in some cases there are multiple valid edit operation paths which lead to the same edit distance minimum. Arguments --------- a : iterable Sequence for which the edit operations are solved. b : iterable Sequence for which the edit operations are solved. Returns ------- list List of lists, Matrix, Table of edit operations. Example ------- >>> ref = [1,2,3] >>> hyp = [1,2,4] >>> for row in op_table(ref, hyp): ... print(row) ['=', 'I', 'I', 'I'] ['D', '=', 'I', 'I'] ['D', 'D', '=', 'I'] ['D', 'D', 'D', 'S'] """ # For the dynamic programming algorithm, only two rows are really needed: # the one currently being filled in, and the previous one # The following is also the right initialization prev_row = [j for j in range(len(b) + 1)] curr_row = [0] * (len(b) + 1) # Just init to zero # For the edit operation table we will need the whole matrix. # We will initialize the table with no-ops, so that we only need to change # where an edit is made. table = [ [EDIT_SYMBOLS["eq"] for j in range(len(b) + 1)] for i in range(len(a) + 1) ] # We already know the operations on the first row and column: for i in range(len(a) + 1): table[i][0] = EDIT_SYMBOLS["del"] for j in range(len(b) + 1): table[0][j] = EDIT_SYMBOLS["ins"] table[0][0] = EDIT_SYMBOLS["eq"] # The rest of the table is filled in row-wise: for i, a_token in enumerate(a, start=1): curr_row[0] += 1 # This trick just deals with the first column. for j, b_token in enumerate(b, start=1): # The dynamic programming algorithm cost rules insertion_cost = curr_row[j - 1] + 1 deletion_cost = prev_row[j] + 1 substitution = 0 if a_token == b_token else 1 substitution_cost = prev_row[j - 1] + substitution # Here copying the Kaldi compute-wer comparison order, which in # ties prefers: # insertion > deletion > substitution if ( substitution_cost < insertion_cost and substitution_cost < deletion_cost ): curr_row[j] = substitution_cost # Again, note that if not substitution, the edit table already # has the correct no-op symbol. if substitution: table[i][j] = EDIT_SYMBOLS["sub"] elif deletion_cost < insertion_cost: curr_row[j] = deletion_cost table[i][j] = EDIT_SYMBOLS["del"] else: curr_row[j] = insertion_cost table[i][j] = EDIT_SYMBOLS["ins"] # Move to the next row: prev_row[:] = curr_row[:] return table
[docs] def alignment(table): """Get the edit distance alignment from an edit op table. Walks back an edit operations table, produced by calling ``table(a, b)``, and collects the edit distance alignment of a to b. The alignment shows which token in a corresponds to which token in b. Note that the alignment is monotonic, one-to-zero-or-one. Arguments ---------- table : list Edit operations table from ``op_table(a, b)``. Returns ------- list Schema: ``[(str <edit-op>, int-or-None <i>, int-or-None <j>),]`` List of edit operations, and the corresponding indices to a and b. See the EDIT_SYMBOLS dict for the edit-ops. The i indexes a, j indexes b, and the indices can be None, which means aligning to nothing. Example ------- >>> # table for a=[1,2,3], b=[1,2,4]: >>> table = [['I', 'I', 'I', 'I'], ... ['D', '=', 'I', 'I'], ... ['D', 'D', '=', 'I'], ... ['D', 'D', 'D', 'S']] >>> print(alignment(table)) [('=', 0, 0), ('=', 1, 1), ('S', 2, 2)] """ # The alignment will be the size of the longer sequence. # form: [(op, a_index, b_index)], index is None when aligned to empty alignment = [] # Now we'll walk back the table to get the alignment. i = len(table) - 1 j = len(table[0]) - 1 while not (i == 0 and j == 0): if i == 0: j -= 1 alignment.insert(0, (EDIT_SYMBOLS["ins"], None, j)) elif j == 0: i -= 1 alignment.insert(0, (EDIT_SYMBOLS["del"], i, None)) else: if table[i][j] == EDIT_SYMBOLS["ins"]: j -= 1 alignment.insert(0, (EDIT_SYMBOLS["ins"], None, j)) elif table[i][j] == EDIT_SYMBOLS["del"]: i -= 1 alignment.insert(0, (EDIT_SYMBOLS["del"], i, None)) elif table[i][j] == EDIT_SYMBOLS["sub"]: i -= 1 j -= 1 alignment.insert(0, (EDIT_SYMBOLS["sub"], i, j)) else: i -= 1 j -= 1 alignment.insert(0, (EDIT_SYMBOLS["eq"], i, j)) return alignment
[docs] def count_ops(table): """Count the edit operations in the shortest edit path in edit op table. Walks back an edit operations table produced by table(a, b) and counts the number of insertions, deletions, and substitutions in the shortest edit path. This information is typically used in speech recognition to report the number of different error types separately. Arguments ---------- table : list Edit operations table from ``op_table(a, b)``. Returns ------- collections.Counter The counts of the edit operations, with keys: * "insertions" * "deletions" * "substitutions" NOTE: not all of the keys might appear explicitly in the output, but for the missing keys collections. The counter will return 0. Example ------- >>> table = [['I', 'I', 'I', 'I'], ... ['D', '=', 'I', 'I'], ... ['D', 'D', '=', 'I'], ... ['D', 'D', 'D', 'S']] >>> print(count_ops(table)) Counter({'substitutions': 1}) """ edits = collections.Counter() # Walk back the table, gather the ops. i = len(table) - 1 j = len(table[0]) - 1 while not (i == 0 and j == 0): if i == 0: edits["insertions"] += 1 j -= 1 elif j == 0: edits["deletions"] += 1 i -= 1 else: if table[i][j] == EDIT_SYMBOLS["ins"]: edits["insertions"] += 1 j -= 1 elif table[i][j] == EDIT_SYMBOLS["del"]: edits["deletions"] += 1 i -= 1 else: if table[i][j] == EDIT_SYMBOLS["sub"]: edits["substitutions"] += 1 i -= 1 j -= 1 return edits
def _batch_to_dict_format(ids, seqs): # Used by wer_details_for_batch return dict(zip(ids, seqs))
[docs] def wer_details_for_batch(ids, refs, hyps, compute_alignments=False): """Convenient batch interface for ``wer_details_by_utterance``. ``wer_details_by_utterance`` can handle missing hypotheses, but sometimes (e.g. CTC training with greedy decoding) they are not needed, and this is a convenient interface in that case. Arguments --------- ids : list, torch.tensor Utterance ids for the batch. refs : list, torch.tensor Reference sequences. hyps : list, torch.tensor Hypothesis sequences. compute_alignments : bool, optional Whether to compute alignments or not. If computed, the details will also store the refs and hyps. (default: False) Returns ------- list See ``wer_details_by_utterance`` Example ------- >>> ids = [['utt1'], ['utt2']] >>> refs = [[['a','b','c']], [['d','e']]] >>> hyps = [[['a','b','d']], [['d','e']]] >>> wer_details = [] >>> for ids_batch, refs_batch, hyps_batch in zip(ids, refs, hyps): ... details = wer_details_for_batch(ids_batch, refs_batch, hyps_batch) ... wer_details.extend(details) >>> print(wer_details[0]['key'], ":", ... "{:.2f}".format(wer_details[0]['WER'])) utt1 : 33.33 """ refs = _batch_to_dict_format(ids, refs) hyps = _batch_to_dict_format(ids, hyps) return wer_details_by_utterance( refs, hyps, compute_alignments=compute_alignments, scoring_mode="strict" )
[docs] def wer_details_by_utterance( ref_dict, hyp_dict, compute_alignments=False, scoring_mode="strict" ): """Computes a wealth WER info about each single utterance. This info can then be used to compute summary details (WER, SER). Arguments --------- ref_dict : dict Should be indexable by utterance ids, and return the reference tokens for each utterance id as iterable hyp_dict : dict Should be indexable by utterance ids, and return the hypothesis tokens for each utterance id as iterable compute_alignments : bool Whether alignments should also be saved. This also saves the tokens themselves, as they are probably required for printing the alignments. scoring_mode : {'strict', 'all', 'present'} How to deal with missing hypotheses (reference utterance id not found in hyp_dict). * 'strict': Raise error for missing hypotheses. * 'all': Score missing hypotheses as empty. * 'present': Only score existing hypotheses. Returns ------- list A list with one entry for every reference utterance. Each entry is a dict with keys: * "key": utterance id * "scored": (bool) Whether utterance was scored. * "hyp_absent": (bool) True if a hypothesis was NOT found. * "hyp_empty": (bool) True if hypothesis was considered empty (either because it was empty, or not found and mode 'all'). * "num_edits": (int) Number of edits in total. * "num_ref_tokens": (int) Number of tokens in the reference. * "WER": (float) Word error rate of the utterance. * "insertions": (int) Number of insertions. * "deletions": (int) Number of deletions. * "substitutions": (int) Number of substitutions. * "alignment": If compute_alignments is True, alignment as list, see ``speechbrain.utils.edit_distance.alignment``. If compute_alignments is False, this is None. * "ref_tokens": (iterable) The reference tokens only saved if alignments were computed, else None. * "hyp_tokens": (iterable) the hypothesis tokens, only saved if alignments were computed, else None. Raises ------ KeyError If scoring mode is 'strict' and a hypothesis is not found. """ details_by_utterance = [] for key, ref_tokens in ref_dict.items(): # Initialize utterance_details utterance_details = { "key": key, "scored": False, "hyp_absent": None, "hyp_empty": None, "num_edits": None, "num_ref_tokens": len(ref_tokens), "WER": None, "insertions": None, "deletions": None, "substitutions": None, "alignment": None, "ref_tokens": ref_tokens if compute_alignments else None, "hyp_tokens": None, } if key in hyp_dict: utterance_details.update({"hyp_absent": False}) hyp_tokens = hyp_dict[key] elif scoring_mode == "all": utterance_details.update({"hyp_absent": True}) hyp_tokens = [] elif scoring_mode == "present": utterance_details.update({"hyp_absent": True}) details_by_utterance.append(utterance_details) continue # Skip scoring this utterance elif scoring_mode == "strict": raise KeyError( "Key " + key + " in reference but missing in hypothesis and strict mode on." ) else: raise ValueError("Invalid scoring mode: " + scoring_mode) # Compute edits for this utterance table = op_table(ref_tokens, hyp_tokens) ops = count_ops(table) # Take into account "" outputs as empty if ref_tokens[0] == "" and hyp_tokens[0] == "": num_ref_tokens = 0 else: num_ref_tokens = len(ref_tokens) # Update the utterance-level details if we got this far: utterance_details.update( { "scored": True, "hyp_empty": True if len(hyp_tokens) == 0 else False, # This also works for e.g. torch tensors "num_edits": sum(ops.values()), "num_ref_tokens": num_ref_tokens, "WER": 100.0 * sum(ops.values()) / len(ref_tokens), "insertions": ops["insertions"], "deletions": ops["deletions"], "substitutions": ops["substitutions"], "alignment": alignment(table) if compute_alignments else None, "ref_tokens": ref_tokens if compute_alignments else None, "hyp_tokens": hyp_tokens if compute_alignments else None, } ) details_by_utterance.append(utterance_details) return details_by_utterance
[docs] def wer_summary(details_by_utterance): """ Computes summary stats from the output of details_by_utterance Summary stats like WER Arguments --------- details_by_utterance : list See the output of wer_details_by_utterance Returns ------- dict Dictionary with keys: * "WER": (float) Word Error Rate. * "SER": (float) Sentence Error Rate (percentage of utterances which had at least one error). * "num_edits": (int) Total number of edits. * "num_scored_tokens": (int) Total number of tokens in scored reference utterances (a missing hypothesis might still have been scored with 'all' scoring mode). * "num_erraneous_sents": (int) Total number of utterances which had at least one error. * "num_scored_sents": (int) Total number of utterances which were scored. * "num_absent_sents": (int) Hypotheses which were not found. * "num_ref_sents": (int) Number of all reference utterances. * "insertions": (int) Total number of insertions. * "deletions": (int) Total number of deletions. * "substitutions": (int) Total number of substitutions. NOTE: Some cases lead to ambiguity over number of insertions, deletions and substitutions. We aim to replicate Kaldi compute_wer numbers. """ # Build the summary details: ins = dels = subs = 0 num_scored_tokens = ( num_scored_sents ) = num_edits = num_erraneous_sents = num_absent_sents = num_ref_sents = 0 for dets in details_by_utterance: num_ref_sents += 1 if dets["scored"]: num_scored_sents += 1 num_scored_tokens += dets["num_ref_tokens"] ins += dets["insertions"] dels += dets["deletions"] subs += dets["substitutions"] num_edits += dets["num_edits"] if dets["num_edits"] > 0: num_erraneous_sents += 1 if dets["hyp_absent"]: num_absent_sents += 1 if num_scored_tokens != 0: WER = 100.0 * num_edits / num_scored_tokens else: WER = 0.0 wer_details = { "WER": WER, "SER": 100.0 * num_erraneous_sents / num_scored_sents, "num_edits": num_edits, "num_scored_tokens": num_scored_tokens, "num_erraneous_sents": num_erraneous_sents, "num_scored_sents": num_scored_sents, "num_absent_sents": num_absent_sents, "num_ref_sents": num_ref_sents, "insertions": ins, "deletions": dels, "substitutions": subs, } return wer_details
[docs] def wer_details_by_speaker(details_by_utterance, utt2spk): """Compute word error rate and another salient info grouping by speakers. Arguments --------- details_by_utterance : list See the output of wer_details_by_utterance utt2spk : dict Map from utterance id to speaker id Returns ------- dict Maps speaker id to a dictionary of the statistics, with keys: * "speaker": Speaker id, * "num_edits": (int) Number of edits in total by this speaker. * "insertions": (int) Number insertions by this speaker. * "dels": (int) Number of deletions by this speaker. * "subs": (int) Number of substitutions by this speaker. * "num_scored_tokens": (int) Number of scored reference tokens by this speaker (a missing hypothesis might still have been scored with 'all' scoring mode). * "num_scored_sents": (int) number of scored utterances by this speaker. * "num_erraneous_sents": (int) number of utterance with at least one error, by this speaker. * "num_absent_sents": (int) number of utterances for which no hypotheses was found, by this speaker. * "num_ref_sents": (int) number of utterances by this speaker in total. """ # Build the speakerwise details: details_by_speaker = {} for dets in details_by_utterance: speaker = utt2spk[dets["key"]] spk_dets = details_by_speaker.setdefault( speaker, collections.Counter( { "speaker": speaker, "insertions": 0, "dels": 0, "subs": 0, "num_scored_tokens": 0, "num_scored_sents": 0, "num_edits": 0, "num_erraneous_sents": 0, "num_absent_sents": 0, "num_ref_sents": 0, } ), ) utt_stats = collections.Counter() if dets["hyp_absent"]: utt_stats.update({"num_absent_sents": 1}) if dets["scored"]: utt_stats.update( { "num_scored_sents": 1, "num_scored_tokens": dets["num_ref_tokens"], "insertions": dets["insertions"], "dels": dets["deletions"], "subs": dets["substitutions"], "num_edits": dets["num_edits"], } ) if dets["num_edits"] > 0: utt_stats.update({"num_erraneous_sents": 1}) spk_dets.update(utt_stats) # We will in the end return a list of normal dicts # We want the output to be sortable details_by_speaker_dicts = [] # Now compute speakerwise summary details for speaker, spk_dets in details_by_speaker.items(): spk_dets["speaker"] = speaker if spk_dets["num_scored_sents"] > 0: spk_dets["WER"] = ( 100.0 * spk_dets["num_edits"] / spk_dets["num_scored_tokens"] ) spk_dets["SER"] = ( 100.0 * spk_dets["num_erraneous_sents"] / spk_dets["num_scored_sents"] ) else: spk_dets["WER"] = None spk_dets["SER"] = None details_by_speaker_dicts.append(spk_dets) return details_by_speaker_dicts
[docs] def top_wer_utts(details_by_utterance, top_k=20): """ Finds the k utterances with highest word error rates. Useful for diagnostic purposes, to see where the system is making the most mistakes. Returns results utterances which were not empty i.e. had to have been present in the hypotheses, with output produced Arguments --------- details_by_utterance : list See output of wer_details_by_utterance. top_k : int Number of utterances to return. Returns ------- list List of at most K utterances, with the highest word error rates, which were not empty. The utterance dict has the same keys as details_by_utterance. """ scored_utterances = [ dets for dets in details_by_utterance if dets["scored"] ] utts_by_wer = sorted( scored_utterances, key=lambda d: d["WER"], reverse=True ) top_non_empty = [] top_empty = [] while utts_by_wer and ( len(top_non_empty) < top_k or len(top_empty) < top_k ): utt = utts_by_wer.pop(0) if utt["hyp_empty"] and len(top_empty) < top_k: top_empty.append(utt) elif not utt["hyp_empty"] and len(top_non_empty) < top_k: top_non_empty.append(utt) return top_non_empty, top_empty
[docs] def top_wer_spks(details_by_speaker, top_k=10): """ Finds the K speakers with the highest word error rates. Useful for diagnostic purposes. Arguments --------- details_by_speaker : list See output of wer_details_by_speaker. top_k : int Number of seakers to return. Returns ------- list List of at most K dicts (with the same keys as details_by_speaker) of speakers sorted by WER. """ scored_speakers = [ dets for dets in details_by_speaker if dets["num_scored_sents"] > 0 ] spks_by_wer = sorted(scored_speakers, key=lambda d: d["WER"], reverse=True) if len(spks_by_wer) >= top_k: return spks_by_wer[:top_k] else: return spks_by_wer