Source code for speechbrain.utils.semdist

"""Provides a metrics class for the SemDist metric.

Authors
* Sylvain de Langen 2024
"""

from typing import Callable, List, Literal

import torch

from speechbrain.utils.metric_stats import MetricStats


class BaseSemDistStats(MetricStats):
    """Base class to implement the SemDist metric, for the variants that
    estimate a single cosine similarity per pair of target and predicted
    texts.

    The SemDist metrics are described by the paper `Evaluating User Perception
    of Speech Recognition System Quality with Semantic Distance Metric
    <https://arxiv.org/abs/2110.05376>`_.

    Arguments
    ---------
    embed_function : Callable[[List[str]], torch.Tensor]
        Given a list of sentences, return their summarized embedding using
        the method of your choice (e.g. mean pooling).
    scale : float, optional
        The `α` scale applied to the semantic distance (one minus the cosine
        similarity) for readability. The default is `1000`, in order to match
        the authors' recommendation.
    batch_size : int, optional
        How many pairs of utterances should be considered at once. Higher is
        faster but may result in OOM.
    """

    def __init__(
        self,
        embed_function: Callable[[List[str]], torch.Tensor],
        scale: float = 1000.0,
        batch_size: int = 64,
    ):
        self.clear()
        self.embed_function = embed_function
        self.scale = scale
        self.batch_size = batch_size
    def clear(self):
        """Clears the collected metrics"""
        self.ids = []
        self.predictions = []
        self.targets = []
        self.scores = []
        self.summary = {}
    def append(self, ids, predict, target):
        """Appends IDs, predictions and targets to internal lists.

        Arguments
        ---------
        ids : list
            the string IDs for the samples
        predict : list
            the model's predictions in tokenizable format
        target : list
            the ground truths in tokenizable format
        """
        self.ids.extend(ids)
        self.predictions.extend(predict)
        self.targets.extend(target)
    def summarize(self, field=None):
        """Summarize the SemDist metric scores. Performs the actual embedding
        function call and SemDist calculation.

        Full set of fields:

        - `semdist`: The average SemDist over all utterances, multiplied by
          the scale optionally specified at initialization.

        Additionally, a `scores` list is populated by this function for each
        pair of sentences. Each entry of that list is a dict, with the fields:

        - `key`: the ID of the utterance.
        - `semdist`: The SemDist of the utterance, multiplied by the scale.

        Arguments
        ---------
        field : str, optional
            The field to return, if you are only interested in one of them.
            If specified, a single `float` is returned; otherwise, a dict is
            returned.

        Returns
        -------
        dict from str to float, if `field is None`
            A dictionary of the fields documented above.
        float, if `field is not None`
            The single field selected by `field`.
        """
        with torch.no_grad():
            self._update_summary()

        if field is not None:
            return self.summary[field]

        return self.summary
    def _update_summary(self):
        """Performs the actual inference and SemDist estimation, updating the
        `summary` field. Automatically called by `summarize`."""
        semdist_sum = 0.0

        for chunk_idx in range(0, len(self.predictions), self.batch_size):
            ids = self.ids[chunk_idx : chunk_idx + self.batch_size]
            ref_text = self.targets[chunk_idx : chunk_idx + self.batch_size]
            hyp_text = self.predictions[
                chunk_idx : chunk_idx + self.batch_size
            ]

            ref_emb = self.embed_function(ref_text).cpu()
            hyp_emb = self.embed_function(hyp_text).cpu()

            similarity = torch.nn.functional.cosine_similarity(
                ref_emb, hyp_emb, dim=-1
            )
            chunk_semdist = (1.0 - similarity) * self.scale

            for i, utt_id in enumerate(ids):
                self.scores.append(
                    {"key": utt_id, "semdist": chunk_semdist[i].item()}
                )

            semdist_sum += chunk_semdist.sum()

        semdist = (semdist_sum / len(self.predictions)).item()
        self.summary["semdist"] = semdist
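For clarity, here is a brief usage sketch of `BaseSemDistStats` (not part of the module source). The `toy_embed` function below is invented purely for illustration; any callable mapping a list of sentences to a 2-D embedding tensor works. Per utterance, the reported score is `scale * (1 - cosine_similarity(ref_emb, hyp_emb))`.

import torch

from speechbrain.utils.semdist import BaseSemDistStats


def toy_embed(sentences):
    """Toy embedding for illustration only: hashed bag-of-words vectors."""
    embs = torch.zeros(len(sentences), 16)
    for i, sent in enumerate(sentences):
        for word in sent.split():
            embs[i, hash(word) % 16] += 1.0
    return embs


stats = BaseSemDistStats(embed_function=toy_embed, scale=1000.0)
stats.append(
    ids=["utt1", "utt2"],
    predict=["the cat sat on the mat", "hello word"],
    target=["the cat sat on a mat", "hello world"],
)
# Returns e.g. {'semdist': ...}; per-utterance scores end up in stats.scores.
print(stats.summarize())
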
class SemDistStats(BaseSemDistStats):
    """Computes the SemDist metric with a provided HuggingFace Transformers
    text encoder.

    Arguments
    ---------
    lm : speechbrain.lobes.models.huggingface_transformers.TextEncoder
        HF Transformers tokenizer and text encoder wrapper to use as a LM.
    method : "meanpool" or "cls"
        - `"meanpool"` (default): Computes the mean of all contextualized
          embeddings, excluding padding tokens.
        - `"cls"`: Exclusively uses the first contextualized embedding, which
          with BERT-like tokenizers is the `[CLS]` token, typically intended
          to capture classification information.
    *args
        Extra positional arguments passed to the base constructor.
    **kwargs
        Extra keyword arguments passed to the base constructor.
    """

    def __init__(
        self,
        lm,
        method: Literal["meanpool", "cls"] = "meanpool",
        *args,
        **kwargs,
    ):
        super().__init__(embed_function=self._embed, *args, **kwargs)
        self.lm = lm
        self.method = method

    def _embed(self, sentences: List[str]) -> torch.Tensor:
        """Computes the LM embedding of a batch of independent sentences,
        according to the pooling method chosen at initialization.

        Arguments
        ---------
        sentences : list
            List of sentences to encode, each given as a sequence of words
            that gets whitespace-joined before tokenization.

        Returns
        -------
        torch.Tensor
            Embedding of the LM encoder.
        """
        sentences = [" ".join(sent) for sent in sentences]

        tokens, hidden = self.lm(sentences, return_tokens=True)
        mask = tokens["attention_mask"].cpu()

        if self.method == "meanpool":
            masked_hidden = hidden.cpu() * mask.unsqueeze(-1)
            nonmasked_counts = torch.sum(mask, dim=-1)  # shape: [batch_size]

            return torch.sum(
                masked_hidden, dim=-2
            ) / nonmasked_counts.unsqueeze(-1)
        elif self.method == "cls":
            return hidden[:, 0, :].cpu()  # the first token
        else:
            raise ValueError(
                f"Specified SemDist method {self.method} is invalid"
            )
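A hedged usage sketch for `SemDistStats` follows (also not part of the module). The import path is taken from the type annotation in the docstring above, but the `TextEncoder` constructor arguments (`source`, `save_path`) and the model name shown are assumptions modeled on SpeechBrain's other HF wrappers; check the wrapper's documentation before relying on them. Note that `append` expects tokenized text (lists of words), which `_embed` whitespace-joins before encoding.

from speechbrain.lobes.models.huggingface_transformers import TextEncoder
from speechbrain.utils.semdist import SemDistStats

# Constructor arguments below are assumed, not verified against the wrapper.
lm = TextEncoder(source="bert-base-uncased", save_path="pretrained_models/")

stats = SemDistStats(lm=lm, method="meanpool", scale=1000.0, batch_size=32)
stats.append(
    ids=["utt1"],
    predict=[["the", "cat", "sat", "on", "the", "mat"]],
    target=[["the", "cat", "sat", "on", "a", "mat"]],
)
summary = stats.summarize()
print(summary["semdist"])  # scaled average semantic distance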