
Metrics for Speech Recognition

Estimating the accuracy of a speech recognition model is not a trivial problem. The Word Error Rate (WER) and Character Error Rate (CER) metrics are standard, but some research has tried to develop alternatives that correlate better with human evaluation (such as SemDist).

This tutorial introduces some alternative ASR metrics and shows how flexibly they integrate into SpeechBrain, with copy&paste-ready hyperparameters, to help you research, use or develop new metrics.

SpeechBrain v1.0.1 introduced, via PR #2451, support and tooling for the metrics suggested by Qualitative Evaluation of Language Model Rescoring in Automatic Speech Recognition. We recommend reading that paper, as some of the metrics won’t be explained in detail here.

%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
%pip install spacy
%pip install flair

Some boilerplate and test data downloading follows…

from hyperpyyaml import load_hyperpyyaml
from collections import defaultdict
%%capture
!wget https://raw.githubusercontent.com/thibault-roux/hypereval/main/data/Exemple/refhyp.txt -O refhyp.txt
!head refhyp.txt
bonsoir à tous bienvenue c' est bfm story en direct jusqu' à dix neuf heures à la une	à tous bienvenue c' est bfm story en direct jusqu' à dix neuf heures	_
de bfm story ce soir la zone euro va t elle encore vivre un été meurtrier l' allemagne première économie européenne pourrait perdre son triple a la situation se détériore en espagne	bfm story ce soir la zone euro va t elle encore vive été meurtrier allemagne première économie européenne pourrait perdre son triple a la situation se détériore en espagne	_
pourquoi ces nouvelles tensions nous serons avec un spécialiste de l' espagne et nous serons avec le député socialiste rapporteur du budget en direct de l' assemblée nationale christian eckert	ces nouvelles tensions sont avec un spécialiste de l' espagne et nous serons avec le député socialiste rapporteur du budget de l' assemblée nationale christian eckert	_
à la une également la syrie et les armes chimiques la russie demande au régime de bachar al assad de ne pas utiliser ces armes	la une également la syrie et les armes chimiques la russie demande au régime de bachar el assad ne pas utiliser ses armes	_
de quel arsenal dispose l' armée syrienne	quelle arsenal dispose l' armée syrienne	_
quels dégats pourraient provoquer ces armes chimiques	dégâts pourraient provoquer ses armes chimiques	_
un spécialiste jean pierre daguzan nous répondra sur le plateau de bfm story et puis	spécialistes ont bien accusant nous répondra sur le plateau de bfm story puis	_
après la droite populaire la droite humaniste voici la droite forte deux jeunes pousses de l' ump guillaume peltier et geoffroy didier lancent ce nouveau mouvement pourquoi faire ils sont mes invités ce soir	la droite populaire la droite humaniste voici la droite forte deux jeunes pousses de l' ump guillaume peltier geoffroy didier migaud pour quoi faire ils sont mes invités ce soir	_
et puis c(ette) cette fois ci c' est vraiment la fin la fin de france soir liquidé par le tribunal de commerce nous en parlerons avec son tout dernier rédacteur en chef dominique de montvalon	cette fois ci c' est vraiment la fin à la fin de france soir liquidé par le tribunal de commerce nous en parlerons avec tout dernier rédacteur en chef dominique de montvalon	_
damien gourlet bonsoir avec vous ce qu' il faut retenir ce soir dans l' actualité l' actualité ce sont encore les incendies en espagne	damien gourlet bonsoir olivier avec vous ce qu' il faut retenir ce soir dans l' actualité actualité se sont encore les incendies en espagne	_
refs = []
hyps = []

# Some preprocessing for the example file, then write the uPOSER label mapping to a JSON file

def split_norm_text(s: str):
    # s = s.replace("' ", "'")

    if s != "":
        return s.split(" ")

    return s

with open("refhyp.txt") as f:
    for refhyp in f.read().splitlines():
        if len(refhyp) <= 1:
            continue

        refhyp = refhyp.split("\t")
        refs.append(split_norm_text(refhyp[0]))
        hyps.append(split_norm_text(refhyp[1]))

with open("uposer.json", "w") as wf:
    wf.write("""[
    ["ADJ", "ADJFP", "ADJFS", "ADJMP", "ADJMS"],
    ["NUM", "CHIF"],
    ["CCONJ", "COCO", "COSUB"],
    ["DET", "DETFS", "DETMS", "DINTFS", "DINTMS"],
    ["X", "MOTINC"],
    ["NOUN", "NFP", "NFS", "NMP", "NMS"],
    ["PRON", "PDEMFP", "PDEMFS", "PDEMMP", "PDEMMS", "PINDFP", "PINDFS",
    "PINDMP", "PINDMS", "PPER1S", "PPER2S", "PPER3FP", "PPER3FS", "PPER3MP",
    "PPER3MS", "PPOBJFP", "PPOBJFS", "PPOBJMP", "PPOBJMS", "PREF", "PREFP",
    "PREFS", "PREL", "PRELFP", "PRELFS", "PRELMP", "PRELMS"],
    ["ADP", "PREP"],
    ["VERB", "VPPFP", "VPPFS", "VPPMP", "VPPMS"],
    ["PROPN", "XFAMIL"],
    ["PUNCT", "YPFOR"]
]
""")

Word Error Rate (WER)

This is the usual WER metric, derived from the Levenshtein distance between the words of the reference and the hypothesis (i.e. the ground truth and the prediction, respectively). It is often presented as a percentage, but it can actually exceed 100%, e.g. when there are many insertions.

Of course, the achievable WER depends very heavily on the dataset and, to an extent, on the language. On some easy datasets it can get as low as 1%, while good models on harder datasets can struggle to reach 15%, and results can be even worse in challenging conditions.

The WER is defined as follows (where # means “number of”):

\(\mathrm{WER} = \dfrac{\#\text{insertions} + \#\text{substitutions} + \#\text{deletions}}{\#\text{reference words}}\)

To understand what exactly an insertion, substitution or deletion is, you should be familiar with the Levenshtein distance, which is an edit distance.
Roughly speaking, an insertion is a word your model predicted that does not exist in the reference, a substitution is a word your model got wrong or misspelled, and a deletion is a word your model incorrectly omitted.

A limitation of the WER is that all errors are weighted equally. For example, a typo turning “processing” into “procesing” does not meaningfully alter the meaning, while an error turning “car” into “scar” might change it drastically, yet both count as a single word error and a single character error. This can cause large discrepancies between the WER/CER and human evaluation.
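The edit counts described above can be sketched in plain Python. This is a minimal, self-contained illustration, not SpeechBrain's implementation (which ErrorRateStats wraps): a standard unit-cost Levenshtein dynamic program, plus a backtrace that counts insertions, substitutions and deletions for one optimal alignment. When several alignments tie, the split between edit types may differ from other tools, although the total stays the same. The example reuses the first sentence pair of the test data and the typo examples from the paragraph above.

```python
from typing import Sequence, Tuple


def edit_counts(ref: Sequence[str], hyp: Sequence[str]) -> Tuple[int, int, int]:
    """Return (insertions, substitutions, deletions) for one minimum-cost
    Levenshtein alignment of `hyp` against `ref`, with every edit costing 1."""
    n, m = len(ref), len(hyp)
    # dist[i][j] = edit distance between ref[:i] and hyp[:j]
    dist = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dist[i][0] = i  # delete the first i reference tokens
    for j in range(m + 1):
        dist[0][j] = j  # insert the first j hypothesis tokens
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            mismatch = int(ref[i - 1] != hyp[j - 1])
            dist[i][j] = min(
                dist[i - 1][j - 1] + mismatch,  # match or substitution
                dist[i - 1][j] + 1,  # deletion
                dist[i][j - 1] + 1,  # insertion
            )
    # Walk back from the bottom-right corner, counting each edit type
    ins = sub = dels = 0
    i, j = n, m
    while i > 0 or j > 0:
        mismatch = int(i > 0 and j > 0 and ref[i - 1] != hyp[j - 1])
        if i > 0 and j > 0 and dist[i][j] == dist[i - 1][j - 1] + mismatch:
            sub += mismatch
            i, j = i - 1, j - 1
        elif i > 0 and dist[i][j] == dist[i - 1][j] + 1:
            dels += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    return ins, sub, dels


def wer(ref: Sequence[str], hyp: Sequence[str]) -> float:
    ins, sub, dels = edit_counts(ref, hyp)
    return 100.0 * (ins + sub + dels) / len(ref)


ref = "bonsoir à tous bienvenue c' est bfm story en direct jusqu' à dix neuf heures à la une".split(" ")
hyp = "à tous bienvenue c' est bfm story en direct jusqu' à dix neuf heures".split(" ")
print(edit_counts(ref, hyp), f"{wer(ref, hyp):.3f}%")  # 4 deletions -> 22.222%

# Word level: each typo is one whole-word substitution...
print(edit_counts(["processing"], ["procesing"]), edit_counts(["car"], ["scar"]))
# ...and at character level, each is a single edit
print(edit_counts(list("processing"), list("procesing")), edit_counts(list("car"), list("scar")))
```

The 22.222% obtained here is the same WER reported for that sentence in the comparison section at the end of this tutorial.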

wer_hparams = load_hyperpyyaml("""
wer_stats: !new:speechbrain.utils.metric_stats.ErrorRateStats
""")
wer_hparams["wer_stats"].clear()
wer_hparams["wer_stats"].append(
    ids=list(range(len(refs))),
    predict=hyps,
    target=refs,
)
wer_hparams["wer_stats"].summarize()
{'WER': 15.451152223304122,
 'SER': 90.83899394161924,
 'num_edits': 19042,
 'num_scored_tokens': 123240,
 'num_erroneous_sents': 4948,
 'num_scored_sents': 5447,
 'num_absent_sents': 0,
 'num_ref_sents': 5447,
 'insertions': 1868,
 'deletions': 7886,
 'substitutions': 9288,
 'error_rate': 15.451152223304122}
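As a sanity check, this summary follows the WER formula given earlier: insertions, substitutions and deletions add up to num_edits, which is divided by num_scored_tokens (here, the number of reference words):

\[
\mathrm{WER} = \frac{1868 + 9288 + 7886}{123240} = \frac{19042}{123240} \approx 0.15451 = 15.451\,\%
\]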

Character Error Rate (CER)

The typical CER measure, for reference. The CER works the same way as the WER, but operates at the character level instead of the word or token level.
As a result, the CER penalizes errors differently: a small typo (e.g. a missing accent) counts as a full word substitution with the WER, but only as a single character substitution with the CER. This isn’t necessarily an upside, since single-character errors can still alter the meaning.

This is slower to run, as the edit distance has to be computed over much longer sequences. Note that ErrorRateStats still reports the result under the WER and error_rate keys, even with split_tokens: True.
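The character-level computation can be sketched as follows. This is a minimal sketch, not SpeechBrain's code: it assumes that splitting tokens means re-joining the words with a visible separator (a hypothetical "_" here) and scoring every resulting character, including word boundaries. Whether split_tokens treats separators exactly this way is an assumption.

```python
from typing import List, Sequence


def levenshtein(a: Sequence[str], b: Sequence[str]) -> int:
    """Unit-cost edit distance, keeping only one row of the DP table."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j - 1] + (x != y),  # match / substitution
                           prev[j] + 1,             # deletion
                           cur[j - 1] + 1))         # insertion
        prev = cur
    return prev[-1]


def to_chars(words: List[str], sep: str = "_") -> List[str]:
    # Assumption: words are re-joined with a visible separator, so word
    # boundaries are scored as characters too
    return list(sep.join(words))


def cer(ref_words: List[str], hyp_words: List[str]) -> float:
    ref_chars = to_chars(ref_words)
    return 100.0 * levenshtein(ref_chars, to_chars(hyp_words)) / len(ref_chars)


ref = "ces armes chimiques".split(" ")
hyp = "ses armes chimiques".split(" ")
print(f"WER {100 * levenshtein(ref, hyp) / len(ref):.3f}%")  # 1 word wrong out of 3
print(f"CER {cer(ref, hyp):.3f}%")  # 1 character wrong out of 19
```

The same function works on words or characters because it only compares sequence elements; the longer character sequences are what makes the CER slower to compute.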

cer_hparams = load_hyperpyyaml("""
cer_stats: !new:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
""")
cer_hparams["cer_stats"].clear()
cer_hparams["cer_stats"].append(
    ids=list(range(len(refs))),
    predict=hyps,
    target=refs,
)
cer_hparams["cer_stats"].summarize()
{'WER': 8.728781317403753,
 'SER': 90.83899394161924,
 'num_edits': 57587,
 'num_scored_tokens': 659737,
 'num_erroneous_sents': 4948,
 'num_scored_sents': 5447,
 'num_absent_sents': 0,
 'num_ref_sents': 5447,
 'insertions': 10426,
 'deletions': 36910,
 'substitutions': 10251,
 'error_rate': 8.728781317403753}

Part-of-speech Error Rate (POSER)

poser_hparams = load_hyperpyyaml("""
wer_stats_dposer: !new:speechbrain.utils.metric_stats.ErrorRateStats

uposer_dict: !apply:speechbrain.utils.dictionaries.SynonymDictionary.from_json_path
    path: ./uposer.json
wer_stats_uposer: !new:speechbrain.utils.metric_stats.ErrorRateStats
    equality_comparator: !ref <uposer_dict>

pos_tagger: !apply:speechbrain.lobes.models.flair.FlairSequenceTagger.from_hf
    source: "qanastek/pos-french"
    save_path: ./pretrained_models/
""")
2024-03-28 16:28:03,311 SequenceTagger predicts: Dictionary with 69 tags: <unk>, O, DET, NFP, ADJFP, AUX, VPPMS, ADV, PREP, PDEMMS, NMS, COSUB, PINDMS, PPOBJMS, VERB, DETFS, NFS, YPFOR, VPPFS, PUNCT, DETMS, PROPN, ADJMS, PPER3FS, ADJFS, COCO, NMP, PREL, PPER1S, ADJMP, VPPMP, DINTMS, PPER3MS, PPER3MP, PREF, ADJ, DINTFS, CHIF, XFAMIL, PRELFS, SYM, NOUN, MOTINC, PINDFS, PPOBJMP, NUM, PREFP, PDEMFS, VPPFP, PPER3FP
refs_poser = poser_hparams["pos_tagger"](refs)
hyps_poser = poser_hparams["pos_tagger"](hyps)
print(" ".join(refs_poser[0]))
print(" ".join(hyps_poser[0]))
INTJ PREP DET NFS PDEMMS AUX PROPN XFAMIL PREP NMS PREP PREP CHIF CHIF NFP PREP DETFS NFS
PREP DET NFS PDEMMS AUX PROPN XFAMIL PREP NMS PREP PREP CHIF CHIF NFP

dPOSER

Instead of computing the WER over the input words, we extract the part-of-speech tags of (preferably all) the words of each sentence, then compute the WER over the resulting label sequences.
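To make this concrete, the sketch below applies a plain edit distance to the two tag sequences printed above for the first sentence pair. It is an illustration of the metric, not SpeechBrain's code path (which runs ErrorRateStats over the tag lists).

```python
from typing import Sequence


def levenshtein(a: Sequence[str], b: Sequence[str]) -> int:
    """Unit-cost edit distance between two label sequences."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j - 1] + (x != y), prev[j] + 1, cur[j - 1] + 1))
        prev = cur
    return prev[-1]


# POS tags printed above for the first reference / hypothesis pair
ref_tags = ("INTJ PREP DET NFS PDEMMS AUX PROPN XFAMIL PREP NMS PREP PREP "
            "CHIF CHIF NFP PREP DETFS NFS").split()
hyp_tags = "PREP DET NFS PDEMMS AUX PROPN XFAMIL PREP NMS PREP PREP CHIF CHIF NFP".split()

dposer = 100.0 * levenshtein(ref_tags, hyp_tags) / len(ref_tags)
print(f"dPOSER: {dposer:.3f}%")  # 4 missing tags out of 18 -> 22.222%
```

This matches the dPOSER reported for that sentence in the comparison section: the four deleted words map to four deleted tags.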

poser_hparams["wer_stats_dposer"].clear()
poser_hparams["wer_stats_dposer"].append(
    ids=list(range(len(refs))),
    predict=hyps_poser,
    target=refs_poser,
)
poser_hparams["wer_stats_dposer"].summarize()
{'WER': 14.70402051648298,
 'SER': 88.87460987699652,
 'num_edits': 18118,
 'num_scored_tokens': 123218,
 'num_erroneous_sents': 4841,
 'num_scored_sents': 5447,
 'num_absent_sents': 0,
 'num_ref_sents': 5447,
 'insertions': 2064,
 'deletions': 8076,
 'substitutions': 7978,
 'error_rate': 14.70402051648298}

uPOSER

The cited paper proposes a variant (uPOSER) with broad POS categories, for when the POS model in use has very fine-grained categories. This is easily implemented with a synonym dictionary that groups equivalent labels together: here, uposer.json is loaded as a SynonymDictionary and passed as the equality_comparator of wer_stats_uposer.
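The grouping idea can be sketched without SpeechBrain by mapping every fine-grained label to a representative of its group before computing the edit distance. This is an equivalent reformulation under the assumption that each label belongs to at most one group; SpeechBrain instead compares labels through the SynonymDictionary during alignment. The groups below are a subset of uposer.json, and the tag sequences are hypothetical.

```python
from typing import Dict, List, Optional, Sequence

# A few of the groups written to uposer.json above; labels in the same
# group are considered equivalent
UPOSER_GROUPS: List[List[str]] = [
    ["ADJ", "ADJFP", "ADJFS", "ADJMP", "ADJMS"],
    ["DET", "DETFS", "DETMS", "DINTFS", "DINTMS"],
    ["NOUN", "NFP", "NFS", "NMP", "NMS"],
]


def canonical_map(groups: List[List[str]]) -> Dict[str, str]:
    # Map every label to the first (broad) label of its group
    return {label: group[0] for group in groups for label in group}


def levenshtein(a: Sequence[str], b: Sequence[str]) -> int:
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j - 1] + (x != y), prev[j] + 1, cur[j - 1] + 1))
        prev = cur
    return prev[-1]


def poser(ref_tags: List[str], hyp_tags: List[str],
          canon: Optional[Dict[str, str]] = None) -> float:
    if canon is not None:
        # Labels outside every group are kept unchanged
        ref_tags = [canon.get(t, t) for t in ref_tags]
        hyp_tags = [canon.get(t, t) for t in hyp_tags]
    return 100.0 * levenshtein(ref_tags, hyp_tags) / len(ref_tags)


# Hypothetical example: only gender/number differs in each tag
ref_tags = ["DETMS", "NMS", "ADJMS"]
hyp_tags = ["DETFS", "NFS", "ADJFS"]
canon = canonical_map(UPOSER_GROUPS)
print(poser(ref_tags, hyp_tags))         # dPOSER-style: every tag differs
print(poser(ref_tags, hyp_tags, canon))  # uPOSER-style: all tags match their group
```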

poser_hparams["wer_stats_uposer"].clear()
poser_hparams["wer_stats_uposer"].append(
    ids=list(range(len(refs))),
    predict=hyps_poser,
    target=refs_poser,
)
poser_hparams["wer_stats_uposer"].summarize()
{'WER': 12.26687659270561,
 'SER': 86.50633376170369,
 'num_edits': 15115,
 'num_scored_tokens': 123218,
 'num_erroneous_sents': 4712,
 'num_scored_sents': 5447,
 'num_absent_sents': 0,
 'num_ref_sents': 5447,
 'insertions': 2089,
 'deletions': 8101,
 'substitutions': 4925,
 'error_rate': 12.26687659270561}

Lemma Error Rate (LER)

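Instead of computing the WER over words, we compute it over lemmatized words. Note that the spaCy pipeline also retokenizes the text (e.g. c' becomes c and '), so the number of scored tokens differs from the plain WER.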
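The principle can be illustrated with a hypothetical toy lemma table standing in for spaCy; its entries mirror lemmas visible in the spaCy output of this section (e.g. heures becomes heure). This sketch ignores spaCy's retokenization and is only meant to show why inflection errors stop counting.

```python
from typing import Dict, List, Sequence

# Hypothetical toy lemma table; a real setup uses a lemmatizer such as spaCy
TOY_LEMMAS: Dict[str, str] = {
    "tous": "tout", "est": "être", "heures": "heure", "la": "le", "une": "un",
}


def lemmatize(words: List[str]) -> List[str]:
    # Unknown words are kept as-is
    return [TOY_LEMMAS.get(w, w) for w in words]


def levenshtein(a: Sequence[str], b: Sequence[str]) -> int:
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j - 1] + (x != y), prev[j] + 1, cur[j - 1] + 1))
        prev = cur
    return prev[-1]


def error_rate(ref: List[str], hyp: List[str]) -> float:
    return 100.0 * levenshtein(ref, hyp) / len(ref)


ref = "dix neuf heures".split()
hyp = "dix neuf heure".split()
wer = error_rate(ref, hyp)                        # plural/singular mismatch counts
ler = error_rate(lemmatize(ref), lemmatize(hyp))  # both lemmatize to "heure"
print(f"WER {wer:.3f}%, LER {ler:.3f}%")
```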

%%capture
!spacy download fr_core_news_md
ler_hparams = load_hyperpyyaml("""
ler_model: !apply:speechbrain.lobes.models.spacy.SpacyPipeline.from_name
    name: fr_core_news_md
    exclude: ["tagger", "parser", "ner", "textcat"]

wer_stats_ler: !new:speechbrain.utils.metric_stats.ErrorRateStats
""")
refs_ler = ler_hparams["ler_model"].lemmatize(refs)
hyps_ler = ler_hparams["ler_model"].lemmatize(hyps)
print(" ".join(refs_ler[0]))
print(" ".join(hyps_ler[0]))
bonsoir à tout bienvenue c ' être bfm story en direct jusqu ' à dix neuf heure à le un
à tout bienvenue c ' être bfm story en direct jusqu ' à dix neuf heure
ler_hparams["wer_stats_ler"].clear()
ler_hparams["wer_stats_ler"].append(
    ids=list(range(len(refs))),
    predict=hyps_ler,
    target=refs_ler,
)
ler_hparams["wer_stats_ler"].summarize()
{'WER': 14.426271595988885,
 'SER': 88.61758766293373,
 'num_edits': 19105,
 'num_scored_tokens': 132432,
 'num_erroneous_sents': 4827,
 'num_scored_sents': 5447,
 'num_absent_sents': 0,
 'num_ref_sents': 5447,
 'insertions': 2160,
 'deletions': 10219,
 'substitutions': 6726,
 'error_rate': 14.426271595988885}

Embedding Error Rate (EmbER)

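This is the typical WER calculation, except that the penalty of a word substitution is weighted down when the two words are deemed similar enough (with the configuration below, a substitution between words whose embedding similarity is above threshold: 0.4 costs 0.1 instead of 1.0). This reduces the impact of e.g. minor spelling errors that do not alter the meaning much.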

The setup is slightly more involved, but the gist of it is that you need:

  • A regular ErrorRateStats object which you will .append() to,

  • The embeddings that you will be using, e.g. using the FlairEmbeddings wrapper,

  • The EmbER configuration, which will point to the embedding (here binding to ember_embeddings.embed_word),

  • The WeightedErrorRateStats, which piggybacks on the base ErrorRateStats and plugs into the EmbER similarity function defined just above.
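The weighting can be sketched on top of an already computed alignment. This is a minimal sketch under stated assumptions, not the EmbeddingErrorRateSimilarity code: insertions and deletions always cost 1.0, a substitution costs high_similarity_weight when the cosine similarity is strictly above threshold (the exact boundary handling is assumed) and low_similarity_weight otherwise, and the weighted edit count is divided by the alignment length (reference words plus insertions). That normalizer is inferred from the summary above, where 15295.3 / (123240 + 1868) ≈ 12.2257%, rather than from documentation. The 2-D vectors are hypothetical stand-ins for FastText embeddings.

```python
from typing import Dict, List, Optional, Tuple

import numpy as np

# (operation, reference word, hypothesis word); operation is one of
# "eq", "ins", "del", "sub"
Edit = Tuple[str, Optional[str], Optional[str]]


def cosine(u: np.ndarray, v: np.ndarray) -> float:
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))


def ember_cost(edit: Edit, emb: Dict[str, np.ndarray], low: float = 1.0,
               high: float = 0.1, threshold: float = 0.4) -> float:
    op, ref_word, hyp_word = edit
    if op == "eq":
        return 0.0
    if op in ("ins", "del"):
        return 1.0  # insertions and deletions are never discounted
    # Substitution: discounted only if both words have embeddings and
    # their cosine similarity is above the threshold
    if ref_word in emb and hyp_word in emb and cosine(emb[ref_word], emb[hyp_word]) > threshold:
        return high
    return low


def ember_rate(alignment: List[Edit], emb: Dict[str, np.ndarray]) -> float:
    weighted_edits = sum(ember_cost(edit, emb) for edit in alignment)
    # Assumption: normalize by alignment length (reference words + insertions)
    return 100.0 * weighted_edits / len(alignment)


# Hypothetical 2-D "embeddings": "ce"/"se" are similar, "car"/"scar" are not
emb = {"ce": np.array([1.0, 0.2]), "se": np.array([0.9, 0.3]),
       "car": np.array([1.0, 0.0]), "scar": np.array([0.0, 1.0])}

# Shape of the last example sentence in the comparison section:
# 22 correct words, 1 insertion, 1 deletion, 1 substitution
alignment = [("eq", "w", "w")] * 22 + [
    ("ins", None, "olivier"), ("del", "l'", None), ("sub", "ce", "se")]
print(f"{ember_rate(alignment, emb):.3f}%")  # (1 + 1 + 0.1) / 25 -> 8.400%
print(ember_cost(("sub", "car", "scar"), emb))  # dissimilar -> full cost 1.0
```

With these made-up vectors, the result equals the 8.400% EmbER reported for the last sentence in the comparison section. That suggests, but does not prove, that FastText rates ce/se above the 0.4 threshold.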

ember_hparams = load_hyperpyyaml("""
wer_stats: !new:speechbrain.utils.metric_stats.ErrorRateStats

ember_embeddings: !apply:speechbrain.lobes.models.flair.embeddings.FlairEmbeddings.from_hf
    embeddings_class: !name:flair.embeddings.FastTextEmbeddings
    source: facebook/fasttext-fr-vectors
    save_path: ./pretrained_models/

ember_metric: !new:speechbrain.utils.metric_stats.EmbeddingErrorRateSimilarity
    embedding_function: !name:speechbrain.lobes.models.flair.embeddings.FlairEmbeddings.embed_word
        - !ref <ember_embeddings>
    low_similarity_weight: 1.0
    high_similarity_weight: 0.1
    threshold: 0.4

weighted_wer_stats: !new:speechbrain.utils.metric_stats.WeightedErrorRateStats
    base_stats: !ref <wer_stats>
    cost_function: !ref <ember_metric>
    weight_name: ember
""")
ember_hparams["wer_stats"].clear()
ember_hparams["wer_stats"].append(
    ids=list(range(len(refs))),
    predict=hyps,
    target=refs,
)
ember_hparams["weighted_wer_stats"].clear()
ember_hparams["weighted_wer_stats"].summarize()
WARNING:gensim.models.fasttext:could not extract any ngrams from '()', returning origin vector
{'ember_wer': 12.225677015059036,
 'ember_insertions': 1868.0,
 'ember_substitutions': 5541.300000000059,
 'ember_deletions': 7886.0,
 'ember_num_edits': 15295.30000000006}

BERTScore

In a nutshell, BERTScore works by computing the cosine similarity between every target token embedding and every predicted token embedding, as obtained from a BERT-like LM encoder, and matching each token to its most similar counterpart. This works rather well because the embeddings are trained to encode information from their context.

For the details, refer to the code and documentation of the metric itself.
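The greedy matching can be sketched with NumPy, given token embeddings as matrices. This follows the basic BERTScore formulation without IDF weighting or baseline rescaling; how BERTScoreStats handles special tokens or weighting is not reproduced here, and the one-hot "embeddings" are hypothetical.

```python
import numpy as np


def bertscore(ref_emb: np.ndarray, hyp_emb: np.ndarray):
    """ref_emb: (n_ref, dim), hyp_emb: (n_hyp, dim) contextual token embeddings."""
    ref = ref_emb / np.linalg.norm(ref_emb, axis=1, keepdims=True)
    hyp = hyp_emb / np.linalg.norm(hyp_emb, axis=1, keepdims=True)
    sim = ref @ hyp.T  # (n_ref, n_hyp) cosine similarities
    recall = float(sim.max(axis=1).mean())     # each ref token -> best hyp token
    precision = float(sim.max(axis=0).mean())  # each hyp token -> best ref token
    f1 = 2 * precision * recall / (precision + recall)
    return recall, precision, f1


# Hypothetical case: the hypothesis drops the last of three reference tokens
ref_emb = np.eye(3)
hyp_emb = np.eye(3)[:2]
recall, precision, f1 = bertscore(ref_emb, hyp_emb)
print(recall, precision, f1)  # the deletion hurts recall, not precision
```

A deleted reference token has no good match, which lowers recall but not precision. This is consistent with the summary above (recall 0.9033 < precision 0.9237) on a dataset with far more deletions than insertions, and the reported F1 of 0.9134 is the harmonic mean of those two values.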

bertscore_hparams = load_hyperpyyaml("""
bertscore_model_name: camembert/camembert-large
bertscore_model_device: cuda

bertscore_stats: !new:speechbrain.utils.bertscore.BERTScoreStats
    lm: !new:speechbrain.lobes.models.huggingface_transformers.TextEncoder
        source: !ref <bertscore_model_name>
        save_path: pretrained_models/
        device: !ref <bertscore_model_device>
        num_layers: 8
""")
bertscore_hparams["bertscore_stats"].clear()
bertscore_hparams["bertscore_stats"].append(
    ids=list(range(len(refs))),
    predict=hyps,
    target=refs,
)
bertscore_hparams["bertscore_stats"].summarize()
{'bertscore-recall': tensor(0.9033),
 'bertscore-precision': tensor(0.9237),
 'bertscore-f1': tensor(0.9134)}

Sentence Semantic Distance: SemDist

SemDist is estimated from the cosine similarity between a single embedding per sentence, e.g. obtained by averaging the LM embeddings over all tokens (method: meanpool).

Here, lower is better. By default, the score is scaled by a factor of 1000 for readability.
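A minimal sketch of the computation, assuming the usual SemDist definition of one minus the cosine similarity between mean-pooled sentence embeddings, scaled by 1000; that SemDistStats uses exactly this formula is an assumption, and the token embeddings below are hypothetical.

```python
import numpy as np


def semdist(ref_tokens: np.ndarray, hyp_tokens: np.ndarray, scale: float = 1000.0) -> float:
    """ref_tokens: (n_ref, dim), hyp_tokens: (n_hyp, dim) token embeddings."""
    ref = ref_tokens.mean(axis=0)  # mean pooling: one vector per sentence
    hyp = hyp_tokens.mean(axis=0)
    cos = float(np.dot(ref, hyp) / (np.linalg.norm(ref) * np.linalg.norm(hyp)))
    return (1.0 - cos) * scale  # 0 for identical embeddings, larger is worse


tokens = np.array([[1.0, 2.0], [3.0, 4.0]])
print(semdist(tokens, tokens))  # approximately 0
print(semdist(np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]])))  # orthogonal -> 1000.0
```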

semdist_hparams = load_hyperpyyaml("""
semdist_model_name: camembert/camembert-large
semdist_model_device: cuda

semdist_stats: !new:speechbrain.utils.semdist.SemDistStats
    lm: !new:speechbrain.lobes.models.huggingface_transformers.TextEncoder
        source: !ref <semdist_model_name>
        save_path: pretrained_models/
        device: !ref <semdist_model_device>
    method: meanpool
""")
semdist_hparams["semdist_stats"].clear()
semdist_hparams["semdist_stats"].append(
    ids=list(range(len(refs))),
    predict=hyps,
    target=refs,
)
semdist_hparams["semdist_stats"].summarize()
{'semdist': 41.13104248046875}
semdist_hparams["semdist_stats"].scores[:5]
[{'key': 0, 'semdist': 11.317432403564453},
 {'key': 1, 'semdist': 14.37997817993164},
 {'key': 2, 'semdist': 8.182466506958008},
 {'key': 3, 'semdist': 7.842123508453369},
 {'key': 4, 'semdist': 13.874173164367676}]

Some comparisons

This was a bit thrown together; if you’ve run everything without running out of RAM, congratulations :)

for i in range(10):
    ref = " ".join(refs[i])
    hyp = " ".join(hyps[i])

    print(f"""\
=== REF: {ref}
=== HYP: {hyp}
WER:                  {wer_hparams['wer_stats'].scores[i]['WER']:.3f}%
CER:                  {cer_hparams['cer_stats'].scores[i]['WER']:.3f}%
dPOSER:               {poser_hparams['wer_stats_dposer'].scores[i]['WER']:.3f}%
uPOSER:               {poser_hparams['wer_stats_uposer'].scores[i]['WER']:.3f}%
EmbER:                {ember_hparams['weighted_wer_stats'].scores[i]['WER']:.3f}%
BERTScore recall:     {bertscore_hparams['bertscore_stats'].scores[i]['recall']:.5f}
BERTScore precision:  {bertscore_hparams['bertscore_stats'].scores[i]['precision']:.5f}
SemDist mean (x1000): {semdist_hparams['semdist_stats'].scores[i]['semdist']:.5f}
""")
=== REF: bonsoir à tous bienvenue c' est bfm story en direct jusqu' à dix neuf heures à la une
=== HYP: à tous bienvenue c' est bfm story en direct jusqu' à dix neuf heures
WER:                  22.222%
CER:                  20.000%
dPOSER:               22.222%
uPOSER:               22.222%
EmbER:                22.222%
BERTScore recall:     0.87673
BERTScore precision:  0.96040
SemDist mean (x1000): 11.31743

=== REF: de bfm story ce soir la zone euro va t elle encore vivre un été meurtrier l' allemagne première économie européenne pourrait perdre son triple a la situation se détériore en espagne
=== HYP: bfm story ce soir la zone euro va t elle encore vive été meurtrier allemagne première économie européenne pourrait perdre son triple a la situation se détériore en espagne
WER:                  12.500%
CER:                  5.525%
dPOSER:               15.625%
uPOSER:               15.625%
EmbER:                12.500%
BERTScore recall:     0.91836
BERTScore precision:  0.91983
SemDist mean (x1000): 14.37998

=== REF: pourquoi ces nouvelles tensions nous serons avec un spécialiste de l' espagne et nous serons avec le député socialiste rapporteur du budget en direct de l' assemblée nationale christian eckert
=== HYP: ces nouvelles tensions sont avec un spécialiste de l' espagne et nous serons avec le député socialiste rapporteur du budget de l' assemblée nationale christian eckert
WER:                  16.667%
CER:                  14.062%
dPOSER:               16.667%
uPOSER:               16.667%
EmbER:                13.667%
BERTScore recall:     0.92581
BERTScore precision:  0.96108
SemDist mean (x1000): 8.18247

=== REF: à la une également la syrie et les armes chimiques la russie demande au régime de bachar al assad de ne pas utiliser ces armes
=== HYP: la une également la syrie et les armes chimiques la russie demande au régime de bachar el assad ne pas utiliser ses armes
WER:                  16.000%
CER:                  5.556%
dPOSER:               12.000%
uPOSER:               12.000%
EmbER:                8.800%
BERTScore recall:     0.95685
BERTScore precision:  0.95836
SemDist mean (x1000): 7.84212

=== REF: de quel arsenal dispose l' armée syrienne
=== HYP: quelle arsenal dispose l' armée syrienne
WER:                  28.571%
CER:                  12.195%
dPOSER:               28.571%
uPOSER:               14.286%
EmbER:                28.571%
BERTScore recall:     0.93197
BERTScore precision:  0.93909
SemDist mean (x1000): 13.87417

=== REF: quels dégats pourraient provoquer ces armes chimiques
=== HYP: dégâts pourraient provoquer ses armes chimiques
WER:                  42.857%
CER:                  15.094%
dPOSER:               14.286%
uPOSER:               14.286%
EmbER:                30.000%
BERTScore recall:     0.76464
BERTScore precision:  0.85932
SemDist mean (x1000): 46.58437

=== REF: un spécialiste jean pierre daguzan nous répondra sur le plateau de bfm story et puis
=== HYP: spécialistes ont bien accusant nous répondra sur le plateau de bfm story puis
WER:                  40.000%
CER:                  23.810%
dPOSER:               40.000%
uPOSER:               33.333%
EmbER:                40.000%
BERTScore recall:     0.70336
BERTScore precision:  0.73710
SemDist mean (x1000): 48.69765

=== REF: après la droite populaire la droite humaniste voici la droite forte deux jeunes pousses de l' ump guillaume peltier et geoffroy didier lancent ce nouveau mouvement pourquoi faire ils sont mes invités ce soir
=== HYP: la droite populaire la droite humaniste voici la droite forte deux jeunes pousses de l' ump guillaume peltier geoffroy didier migaud pour quoi faire ils sont mes invités ce soir
WER:                  20.588%
CER:                  17.391%
dPOSER:               23.529%
uPOSER:               17.647%
EmbER:                20.588%
BERTScore recall:     0.88929
BERTScore precision:  0.92400
SemDist mean (x1000): 11.49768

=== REF: et puis c(ette) cette fois ci c' est vraiment la fin la fin de france soir liquidé par le tribunal de commerce nous en parlerons avec son tout dernier rédacteur en chef dominique de montvalon
=== HYP: cette fois ci c' est vraiment la fin à la fin de france soir liquidé par le tribunal de commerce nous en parlerons avec tout dernier rédacteur en chef dominique de montvalon
WER:                  14.286%
CER:                  11.518%
dPOSER:               14.286%
uPOSER:               14.286%
EmbER:                13.889%
BERTScore recall:     0.87325
BERTScore precision:  0.95048
SemDist mean (x1000): 8.85153

=== REF: damien gourlet bonsoir avec vous ce qu' il faut retenir ce soir dans l' actualité l' actualité ce sont encore les incendies en espagne
=== HYP: damien gourlet bonsoir olivier avec vous ce qu' il faut retenir ce soir dans l' actualité actualité se sont encore les incendies en espagne
WER:                  12.500%
CER:                  8.955%
dPOSER:               12.500%
uPOSER:               8.333%
EmbER:                8.400%
BERTScore recall:     0.97822
BERTScore precision:  0.94830
SemDist mean (x1000): 9.74524

Citing SpeechBrain

If you use SpeechBrain in your research or business, please cite it using the following BibTeX entries:

@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}