Source code for speechbrain.integrations.nlp.flair_embeddings

"""Wrappers for Flair embedding classes

Authors
* Sylvain de Langen 2024
"""

from typing import List, Union

import torch

try:
    import flair
    from flair.data import Sentence
    from flair.embeddings import Embeddings
except ImportError as e:
    raise ImportError(
        f"Failed to import flair: {e}\n"
        f"Please install flair e.g. using `pip install flair`.\n"
        f"For more details, see https://github.com/flairNLP/flair"
    ) from e


[docs] class FlairEmbeddings: """ Simple wrapper for generic Flair embeddings. Arguments --------- embeddings : Embeddings The Flair embeddings object. If you do not have one initialized, use :meth:`~FlairEmbeddings.from_hf` instead. Example ------- >>> from speechbrain.utils.metric_stats import EmbeddingErrorRateSimilarity >>> from speechbrain.utils.metric_stats import WeightedErrorRateStats >>> from speechbrain.utils.metric_stats import ErrorRateStats >>> ember = FlairEmbeddings.from_hf( ... embeddings_class=flair.embeddings.TransformerWordEmbeddings, ... source="google-bert/bert-base-uncased", ... ) >>> ember_metric = EmbeddingErrorRateSimilarity( ... embedding_function=lambda x: FlairEmbeddings.embed_word(ember, x), ... low_similarity_weight=1.0, ... high_similarity_weight=0.1, ... threshold=0.4, ... ) >>> weighted_wer = WeightedErrorRateStats( ... base_stats=ErrorRateStats(), ... cost_function=ember_metric, ... weight_name="ember", ... ) >>> weighted_wer.base_stats.append(["id"], ["hi friend"], ["hi buddy"]) >>> weighted_wer.summarize() {'ember_wer': 16.6..., 'ember_insertions': 1.0, 'ember_substitutions': 0.5, 'ember_deletions': 0.0, 'ember_num_edits': 1.5} """ def __init__(self, embeddings: Embeddings) -> None: self.embeddings = embeddings
[docs] @staticmethod def from_hf(embeddings_class, source, *args, **kwargs) -> "FlairEmbeddings": """Fetches and load flair embeddings. Arguments --------- embeddings_class : class The class to use to initialize the model, e.g. `FastTextEmbeddings`. source : str The location of the model (a directory or HF repo, for instance). *args Extra positional arguments to pass to the flair class constructor **kwargs Extra keyword arguments to pass to the flair class constructor Returns ------- FlairEmbeddings """ return FlairEmbeddings(embeddings_class(source, *args, **kwargs))
[docs] def __call__( self, inputs: Union[List[str], List[List[str]]], pad_tensor: torch.Tensor = torch.zeros((1,)), ) -> torch.Tensor: """Extract embeddings for a batch of sentences. Arguments --------- inputs : list of sentences (str or list of tokens) Sentences to embed, in the form of batches of lists of tokens (list of str) or a str. In the case of token lists, tokens do *not* need to be already tokenized for this specific sequence tagger. However, a token may be considered as a single word. Similarly, out-of-vocabulary handling depends on the underlying embedding class. pad_tensor : torch.Tensor, optional What embedding tensor (of shape `[]`, living on the same device as the embeddings to insert as padding. Returns ------- torch.Tensor Batch of shape `[len(inputs), max_len, embed_size]` """ if isinstance(inputs, str): raise ValueError("Expected a list of sentences, not a single str") sentences = [Sentence(sentence) for sentence in inputs] self.embeddings.embed(sentences) # migrate pad to device & broadcast if it's just a scalar pad_tensor = pad_tensor.to(flair.device) pad_tensor = pad_tensor.broadcast_to( self.embeddings.embedding_length ).unsqueeze(0) sentence_embs = [ torch.stack([token.embedding for token in sentence]) for sentence in sentences ] longest_emb = max(emb.size(0) for emb in sentence_embs) sentence_embs = [ torch.cat( [emb, pad_tensor.repeat(longest_emb - emb.size(0), 1)], dim=0 ) for emb in sentence_embs ] return torch.stack(sentence_embs)
[docs] def embed_word(self, word: str) -> torch.Tensor: """Embeds a single word. Arguments --------- word : str Word to embed. Out-of-vocabulary handling depends on the underlying embedding class. Returns ------- torch.Tensor Embedding for a single word, of shape `[embed_size]` """ return self([word])[0, 0, :]