# Source code for speechbrain.dataio.wer

"""WER print functions.

The functions here are used to print the computed statistics
with human-readable formatting.
They have a file argument, but you can also just use
contextlib.redirect_stdout, which may give a nicer syntax.

Authors
 * Aku Rouhe 2020
"""
import sys
from speechbrain.utils import edit_distance










# The following internal functions are used to
# print out more specific things
def _print_top_wer_utts(top_non_empty, top_empty, file=None):
    """Print the highest-WER utterances, grouped by hypothesis emptiness.

    Arguments
    ---------
    top_non_empty : list
        Dicts with at least the keys "key" and "WER"; printed under the
        "Non-empty hypotheses" heading.
    top_empty : list
        Dicts with at least the keys "key" and "WER"; printed under the
        "Empty hypotheses" heading.
    file : stream
        Where to write. None (the default) makes print() use whatever
        sys.stdout is at call time, so contextlib.redirect_stdout works.
        A default of ``sys.stdout`` would instead be bound once, at
        definition time, and ignore later redirection.
    """
    print("=" * 80, file=file)
    print("UTTERANCES WITH HIGHEST WER", file=file)
    # Both groups share the same layout: a heading followed by one line per
    # utterance, or a fallback message when the group is empty.
    sections = (
        (
            top_non_empty,
            "Non-empty hypotheses -- utterances for which output was produced:",
            "No utterances which had produced output!",
        ),
        (
            top_empty,
            "Empty hypotheses -- utterances for which no output was produced:",
            "No utterances which had not produced output!",
        ),
    )
    for utterances, heading, fallback in sections:
        if not utterances:
            print(fallback, file=file)
            continue
        print(heading, file=file)
        for dets in utterances:
            print("{key} %WER {WER:.2f}".format(**dets), file=file)


def _print_top_wer_spks(spks_by_wer, file=None):
    """Print one "<speaker> %WER <value>" line per speaker entry.

    Arguments
    ---------
    spks_by_wer : list
        Dicts with at least the keys "speaker" and "WER"; printed in the
        order given.
    file : stream
        Where to write. None (the default) makes print() use whatever
        sys.stdout is at call time, so contextlib.redirect_stdout works.
        A default of ``sys.stdout`` would instead be bound once, at
        definition time, and ignore later redirection.
    """
    print("=" * 80, file=file)
    print("SPEAKERS WITH HIGHEST WER", file=file)
    for dets in spks_by_wer:
        print("{speaker} %WER {WER:.2f}".format(**dets), file=file)


def _print_alignment(
    alignment, a, b, empty_symbol="<eps>", separator=" ; ", file=None
):
    """Print an alignment as three column-aligned lines: a, ops, b.

    Arguments
    ---------
    alignment : list
        Triples ``(op, i, j)`` where ``i`` indexes ``a`` and ``j`` indexes
        ``b``; either index may be None, in which case ``empty_symbol`` is
        printed in that position.
    a : list
        Tokens printed on the first line.
    b : list
        Tokens printed on the third line.
    empty_symbol : str
        Placeholder shown where an index is None.
    separator : str
        String placed between columns.
    file : stream
        Where to write. None (the default) makes print() use whatever
        sys.stdout is at call time, so contextlib.redirect_stdout works.
        A default of ``sys.stdout`` would instead be bound once, at
        definition time, and ignore later redirection.
    """
    # First, get equal length text for all:
    a_padded = []
    b_padded = []
    ops_padded = []
    for op, i, j in alignment:  # i indexes a, j indexes b
        op_string = str(op)
        a_string = empty_symbol if i is None else str(a[i])
        b_string = empty_symbol if j is None else str(b[j])
        # NOTE: the padding does not actually compute printed length,
        # but hopefully we can assume that printed length is
        # at most the str len
        pad_length = max(len(op_string), len(a_string), len(b_string))
        a_padded.append(a_string.center(pad_length))
        b_padded.append(b_string.center(pad_length))
        ops_padded.append(op_string.center(pad_length))
    # Then print, in the order Ref, op, Hyp
    for row in (a_padded, ops_padded, b_padded):
        print(separator.join(row), file=file)


def _print_alignments_global_header(
    empty_symbol="<eps>", separator=" ; ", file=None
):
    """Print the "ALIGNMENTS" section header with a worked format example.

    The example is rendered by _print_alignment itself on artificial data,
    so it always reflects the actual layout, empty symbol and separator.

    Arguments
    ---------
    empty_symbol : str
        Placeholder forwarded to _print_alignment.
    separator : str
        Column separator forwarded to _print_alignment.
    file : stream
        Where to write. None (the default) makes print() use whatever
        sys.stdout is at call time, so contextlib.redirect_stdout works.
        A default of ``sys.stdout`` would instead be bound once, at
        definition time, and ignore later redirection.
    """
    print("=" * 80, file=file)
    print("ALIGNMENTS", file=file)
    print("", file=file)
    print("Format:", file=file)
    print("<utterance-id>, WER DETAILS", file=file)
    # Print the format with the actual
    # print_alignment function, using artificial data:
    a = ["reference", "on", "the", "first", "line"]
    b = ["and", "hypothesis", "on", "the", "third"]
    symbols = edit_distance.EDIT_SYMBOLS
    alignment = [
        (symbols["ins"], None, 0),
        (symbols["sub"], 0, 1),
        (symbols["eq"], 1, 2),
        (symbols["eq"], 2, 3),
        (symbols["sub"], 3, 4),
        (symbols["del"], 4, None),
    ]
    _print_alignment(
        alignment,
        a,
        b,
        file=file,
        empty_symbol=empty_symbol,
        separator=separator,
    )


def _print_alignment_header(wer_details, file=None):
    """Print the per-utterance header line shown above an alignment.

    Arguments
    ---------
    wer_details : dict
        Must provide the keys "key", "WER", "num_edits", "num_ref_tokens",
        "insertions", "deletions" and "substitutions"; extra keys are
        ignored by the format call.
    file : stream
        Where to write. None (the default) makes print() use whatever
        sys.stdout is at call time, so contextlib.redirect_stdout works.
        A default of ``sys.stdout`` would instead be bound once, at
        definition time, and ignore later redirection.
    """
    # Adjacent literals are concatenated at compile time; the resulting
    # template is the same single-line text, just wrapped for readability.
    template = (
        "{key}, %WER {WER:.2f} [ {num_edits} / {num_ref_tokens}, "
        "{insertions} ins, {deletions} del, {substitutions} sub ]"
    )
    print("=" * 80, file=file)
    print(template.format(**wer_details), file=file)