speechbrain.utils.metric_stats module
The metric_stats
module provides an abstract class for storing
statistics produced over the course of an experiment and summarizing them.
- Authors:
Peter Plantinga 2020
Mirco Ravanelli 2020
Summary
Classes:
BinaryMetricStats – Tracks binary metrics, such as precision, recall, F1, EER, etc.
ClassificationStats – Computes statistics pertaining to multi-label classification tasks, as well as tasks that can be loosely interpreted as such for the purpose of evaluation.
ErrorRateStats – A class for tracking error rates (e.g., WER, PER).
MetricStats – A default class for storing and summarizing arbitrary metrics.
Functions:
EER – Computes the EER (and its threshold).
minDCF – Computes the minDCF metric normally used to evaluate speaker verification systems.
multiprocess_evaluation – Runs metric evaluation in parallel over multiple jobs.
sequence_evaluation – Runs metric evaluation sequentially over the inputs.
Reference
- class speechbrain.utils.metric_stats.MetricStats(metric, n_jobs=1, batch_eval=True)[source]
Bases:
object
A default class for storing and summarizing arbitrary metrics.
More complex metrics can be created by sub-classing this class.
- Parameters
metric (function) – The function to use to compute the relevant metric. Should take at least two arguments (predictions and targets) and can optionally take the relative lengths of either or both arguments. Not usually used in sub-classes.
batch_eval (bool) – When True it feeds the evaluation metric with the batched input. When False and n_jobs=1, it performs metric evaluation one-by-one in a sequential way. When False and n_jobs>1, the evaluation runs in parallel over the different inputs using joblib.
n_jobs (int) – The number of jobs to use for computing the metric. If this is more than one, every sample is processed individually, otherwise the whole batch is passed at once.
Example
>>> import torch
>>> from speechbrain.nnet.losses import l1_loss
>>> loss_stats = MetricStats(metric=l1_loss)
>>> loss_stats.append(
...     ids=["utterance1", "utterance2"],
...     predictions=torch.tensor([[0.1, 0.2], [0.2, 0.3]]),
...     targets=torch.tensor([[0.1, 0.2], [0.1, 0.2]]),
...     reduction="batch",
... )
>>> stats = loss_stats.summarize()
>>> stats['average']
0.050...
>>> stats['max_score']
0.100...
>>> stats['max_id']
'utterance2'
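When the metric cannot consume a whole batch at once, batch_eval=False switches to per-sample evaluation (sequential for n_jobs=1, joblib-parallel otherwise). A minimal sketch with a toy per-sample metric; it assumes the predictions and targets are passed to append as the predict and target keyword arguments so they can be iterated sample by sample:

import torch
from speechbrain.utils.metric_stats import MetricStats

def abs_error(prediction, target):
    # Hypothetical per-utterance metric: mean absolute error for one sample.
    return torch.mean(torch.abs(prediction - target)).item()

stats = MetricStats(metric=abs_error, batch_eval=False)
stats.append(
    ids=["utterance1", "utterance2"],
    predict=torch.tensor([[0.1, 0.2], [0.2, 0.3]]),
    target=torch.tensor([[0.1, 0.2], [0.1, 0.2]]),
)
print(stats.summarize()["average"])  # mean of the two per-utterance scores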
- append(ids, *args, **kwargs)[source]
Store a particular set of metric scores.
- Parameters
ids (list) – List of ids corresponding to utterances.
*args – Arguments to pass to the metric function.
**kwargs – Arguments to pass to the metric function.
- speechbrain.utils.metric_stats.multiprocess_evaluation(metric, predict, target, lengths=None, n_jobs=8)[source]
Runs metric evaluation in parallel over multiple jobs.
- speechbrain.utils.metric_stats.sequence_evaluation(metric, predict, target, lengths=None)[source]
Runs metric evaluation sequentially over the inputs.
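These helpers can also be called directly. A minimal sketch with a toy metric (the abs_error callable below is illustrative, not part of SpeechBrain), assuming the metric is invoked once per zipped (prediction, target) pair:

import torch
from speechbrain.utils.metric_stats import sequence_evaluation, multiprocess_evaluation

def abs_error(prediction, target):
    # Toy per-sample metric: mean absolute error for a single utterance.
    return torch.mean(torch.abs(prediction - target)).item()

predict = torch.tensor([[0.1, 0.2], [0.2, 0.3]])
target = torch.tensor([[0.1, 0.2], [0.1, 0.2]])

scores = sequence_evaluation(metric=abs_error, predict=predict, target=target)
# scores is a list with one value per sample, e.g. [0.0, 0.1]

# The parallel variant exposes the same interface but dispatches the
# per-sample calls over joblib workers.
scores_parallel = multiprocess_evaluation(
    metric=abs_error, predict=predict, target=target, n_jobs=2
)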
- class speechbrain.utils.metric_stats.ErrorRateStats(merge_tokens=False, split_tokens=False, space_token='_')[source]
Bases:
MetricStats
A class for tracking error rates (e.g., WER, PER).
- Parameters
merge_tokens (bool) – Whether to merge successive tokens (used for e.g. creating words out of character tokens). See speechbrain.dataio.dataio.merge_char.
split_tokens (bool) – Whether to split tokens (used for e.g. creating characters out of word tokens). See speechbrain.dataio.dataio.split_word.
space_token (str) – The character to use for boundaries. With merge_tokens, this is the character to split on after the merge. With split_tokens, the sequence is joined with this token in between, and then the whole sequence is split.
Example
>>> import torch
>>> cer_stats = ErrorRateStats()
>>> i2l = {0: 'a', 1: 'b'}
>>> cer_stats.append(
...     ids=['utterance1'],
...     predict=torch.tensor([[0, 1, 1]]),
...     target=torch.tensor([[0, 1, 0]]),
...     target_len=torch.ones(1),
...     ind2lab=lambda batch: [[i2l[int(x)] for x in seq] for seq in batch],
... )
>>> stats = cer_stats.summarize()
>>> stats['WER']
33.33...
>>> stats['insertions']
0
>>> stats['deletions']
0
>>> stats['substitutions']
1
- append(ids, predict, target, predict_len=None, target_len=None, ind2lab=None)[source]
Add stats to the relevant containers.
See MetricStats.append()
- Parameters
ids (list) – List of ids corresponding to utterances.
predict (torch.tensor) – A predicted output, for comparison with the target output.
target (torch.tensor) – The correct reference output, for comparison with the prediction.
predict_len (torch.tensor) – The predictions' relative lengths, used to undo padding if there is padding present in the predictions.
target_len (torch.tensor) – The target outputs’ relative lengths, used to undo padding if there is padding present in the target.
ind2lab (callable) – Callable that maps from indices to labels, operating on batches, for writing alignments.
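As a complement to the character-level example above, a minimal sketch (with an assumed index-to-character mapping) of scoring word-level error rates from character tokens via merge_tokens:

import torch
from speechbrain.utils.metric_stats import ErrorRateStats

# "_" marks word boundaries in the character-level label space.
wer_stats = ErrorRateStats(merge_tokens=True, space_token="_")
i2l = {0: "h", 1: "i", 2: "_", 3: "y", 4: "o", 5: "u"}
wer_stats.append(
    ids=["utterance1"],
    predict=torch.tensor([[0, 1, 2, 3, 4, 5]]),  # "h i _ y o u" -> "hi you"
    target=torch.tensor([[0, 1, 2, 3, 4, 5]]),
    target_len=torch.ones(1),
    ind2lab=lambda batch: [[i2l[int(x)] for x in seq] for seq in batch],
)
print(wer_stats.summarize()["WER"])  # 0.0 for a perfect match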
- class speechbrain.utils.metric_stats.BinaryMetricStats(positive_label=1)[source]
Bases:
MetricStats
Tracks binary metrics, such as precision, recall, F1, EER, etc.
- append(ids, scores, labels)[source]
Appends scores and labels to internal lists.
Does not compute metrics until time of summary, since automatic thresholds (e.g., EER) need full set of scores.
- Parameters
ids (list) – The string ids for the samples.
scores (torch.tensor) – The scores produced by the model for each sample.
labels (torch.tensor) – The corresponding ground-truth binary labels.
- summarize(field=None, threshold=None, max_samples=None, beta=1, eps=1e-08)[source]
Compute statistics using a full set of scores.
- Full set of fields:
TP - True Positive
TN - True Negative
FP - False Positive
FN - False Negative
FAR - False Acceptance Rate
FRR - False Rejection Rate
DER - Detection Error Rate (EER if no threshold passed)
threshold - threshold (EER threshold if no threshold passed)
precision - Precision (positive predictive value)
recall - Recall (sensitivity)
F-score - Balance of precision and recall (equal if beta=1)
MCC - Matthews Correlation Coefficient
- Parameters
field (str) – A key for selecting a single statistic. If not provided, a dict with all statistics is returned.
threshold (float) – If no threshold is provided, equal error rate is used.
max_samples (float) – How many samples to keep for positive/negative scores. If no max_samples is provided, all scores are kept. Only effective when threshold is None.
beta (float) – How much to weight precision vs recall in F-score. Default of 1. is equal weight, while higher values weight recall higher, and lower values weight precision higher.
eps (float) – A small value to avoid dividing by zero.
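A minimal usage sketch with toy verification scores (the ids and values are illustrative); with no threshold given, summarize derives one at the equal error rate:

import torch
from speechbrain.utils.metric_stats import BinaryMetricStats

binary_stats = BinaryMetricStats(positive_label=1)
binary_stats.append(
    ids=["pair1", "pair2", "pair3", "pair4"],
    scores=torch.tensor([0.9, 0.8, 0.3, 0.1]),
    labels=torch.tensor([1.0, 1.0, 0.0, 0.0]),
)
summary = binary_stats.summarize()                  # dict with TP, FAR, precision, ...
f_score = binary_stats.summarize(field="F-score")   # a single statistic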
- speechbrain.utils.metric_stats.EER(positive_scores, negative_scores)[source]
Computes the EER (and its threshold).
- Parameters
positive_scores (torch.tensor) – The scores from entries of the same class.
negative_scores (torch.tensor) – The scores from entries of different classes.
Example
>>> import torch
>>> positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
>>> negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
>>> val_eer, threshold = EER(positive_scores, negative_scores)
>>> val_eer
0.0
- speechbrain.utils.metric_stats.minDCF(positive_scores, negative_scores, c_miss=1.0, c_fa=1.0, p_target=0.01)[source]
Computes the minDCF metric normally used to evaluate speaker verification systems. The minDCF is the minimum of the following C_det function, computed within the defined threshold range:
C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
where p_miss is the miss probability and p_fa is the probability of a false alarm.
- Parameters
positive_scores (torch.tensor) – The scores from entries of the same class.
negative_scores (torch.tensor) – The scores from entries of different classes.
c_miss (float) – Cost assigned to a missing error (default 1.0).
c_fa (float) – Cost assigned to a false alarm (default 1.0).
p_target (float) – Prior probability of having a target (default 0.01).
Example
>>> import torch
>>> positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
>>> negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
>>> val_minDCF, threshold = minDCF(positive_scores, negative_scores)
>>> val_minDCF
0.0
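To make the cost function above concrete, a brute-force sketch (not the library implementation) of evaluating C_det over candidate thresholds:

import torch

def min_dcf_sketch(positive_scores, negative_scores,
                   c_miss=1.0, c_fa=1.0, p_target=0.01):
    # Candidate thresholds: every observed score.
    thresholds = torch.cat([positive_scores, negative_scores]).sort()[0]
    best_cost, best_th = None, None
    for th in thresholds:
        p_miss = (positive_scores < th).float().mean()   # missed targets
        p_fa = (negative_scores >= th).float().mean()    # false alarms
        c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
        if best_cost is None or c_det < best_cost:
            best_cost, best_th = c_det.item(), th.item()
    return best_cost, best_th

positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
print(min_dcf_sketch(positive_scores, negative_scores))  # cost 0.0 at a separating threshold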
- class speechbrain.utils.metric_stats.ClassificationStats[source]
Bases:
MetricStats
Computes statistics pertaining to multi-label classification tasks, as well as tasks that can be loosely interpreted as such for the purpose of evaluation.
Example
>>> import sys
>>> from speechbrain.utils.metric_stats import ClassificationStats
>>> cs = ClassificationStats()
>>> cs.append(
...     ids=["ITEM1", "ITEM2", "ITEM3", "ITEM4"],
...     predictions=[
...         "M EY K AH",
...         "T EY K",
...         "B AE D",
...         "M EY K",
...     ],
...     targets=[
...         "M EY K",
...         "T EY K",
...         "B AE D",
...         "M EY K",
...     ],
...     categories=[
...         "make",
...         "take",
...         "bad",
...         "make"
...     ]
... )
>>> cs.write_stats(sys.stdout)
Overall Accuracy: 75%
Class-Wise Accuracy
-------------------
bad -> B AE D : 1 / 1 (100.00%)
make -> M EY K: 1 / 2 (50.00%)
take -> T EY K: 1 / 1 (100.00%)
Confusion
---------
Target: bad -> B AE D
  -> B AE D : 1 / 1 (100.00%)
Target: make -> M EY K
  -> M EY K : 1 / 2 (50.00%)
  -> M EY K AH: 1 / 2 (50.00%)
Target: take -> T EY K
  -> T EY K : 1 / 1 (100.00%)
>>> summary = cs.summarize()
>>> summary['accuracy']
0.75
>>> summary['classwise_stats'][('bad', 'B AE D')]
{'total': 1.0, 'correct': 1.0, 'accuracy': 1.0}
>>> summary['classwise_stats'][('make', 'M EY K')]
{'total': 2.0, 'correct': 1.0, 'accuracy': 0.5}
>>> summary['keys']
[('bad', 'B AE D'), ('make', 'M EY K'), ('take', 'T EY K')]
>>> summary['predictions']
['B AE D', 'M EY K', 'M EY K AH', 'T EY K']
>>> summary['classwise_total']
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 2.0, ('take', 'T EY K'): 1.0}
>>> summary['classwise_correct']
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 1.0, ('take', 'T EY K'): 1.0}
>>> summary['classwise_accuracy']
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 0.5, ('take', 'T EY K'): 1.0}
- append(ids, predictions, targets, categories=None)[source]
Appends inputs, predictions and targets to internal lists
- Parameters
ids (list) – the string IDs for the samples
predictions (list) – the model’s predictions (human-interpretable, preferably strings)
targets (list) – the ground truths (human-interpretable, preferably strings)
categories (list) – an additional way to classify training samples. If available, the categories will be combined with targets
- summarize(field=None)[source]
Summarize the classification metric scores
The following statistics are computed:
- accuracy: the overall accuracy (# correct / # total)
- confusion_matrix: a dictionary of type {(target, prediction): num_entries} representing the confusion matrix
- classwise_stats: the total number of samples, the number of correct classifications, and the accuracy for each class
- keys: all available class keys, which can be either target classes or (category, target) tuples
- predictions: all available predictions the model has made