"""
A convenience wrapper for word embeddings retrieved out of
HuggingFace transformers (e.g. BERT)
Authors
* Artem Ploujnikov 2021
"""
import numpy as np
import torch
from torch import nn
def _last_n_layers(count):
    """Returns the indexes of the last ``count`` hidden layers,
    e.g. ``_last_n_layers(4)`` -> ``range(-4, 0)``"""
    return range(-count, 0)
class TransformerWordEmbeddings(nn.Module):
"""A wrapper to retrieve word embeddings out of a pretrained Transformer model
from HuggingFace Transformers (e.g. BERT)
Arguments
---------
model: str|nn.Module
the underlying model instance or the name of the model
to download
tokenizer: str|transformers.tokenization_utils_base.PreTrainedTokenizerBase
a pretrained tokenizer - or the identifier to retrieve
one from HuggingFace
    layers: int|list
        a list of layer indexes from which to construct the embedding,
        or the number of last layers to use
device: str
a torch device identifier. If provided, the model
will be transferred onto that device
Example
-------
    NOTE: Doctests are disabled because the dependency on the
    HuggingFace Transformers library is optional.
>>> from transformers import AutoTokenizer, AutoModel # doctest: +SKIP
>>> from speechbrain.wordemb.transformer import TransformerWordEmbeddings
>>> model_name = "bert-base-uncased" # doctest: +SKIP
>>> tokenizer = AutoTokenizer.from_pretrained(
... model_name, return_tensors='pt') # doctest: +SKIP
>>> model = AutoModel.from_pretrained(
... model_name,
... output_hidden_states=True) # doctest: +SKIP
>>> word_emb = TransformerWordEmbeddings(
... model=model,
... layers=4,
... tokenizer=tokenizer
... ) # doctest: +SKIP
>>> embedding = word_emb.embedding(
... sentence="THIS IS A TEST SENTENCE",
... word="TEST"
... ) # doctest: +SKIP
>>> embedding[:8] # doctest: +SKIP
tensor([ 3.4332, -3.6702, 0.5152, -1.9301, 0.9197, 2.1628, -0.2841, -0.3549])
>>> embeddings = word_emb.embeddings("This is cool") # doctest: +SKIP
>>> embeddings.shape # doctest: +SKIP
torch.Size([3, 768])
>>> embeddings[:, :3] # doctest: +SKIP
tensor([[-2.9078, 1.2496, 0.7269],
[-0.9940, -0.6960, 1.4350],
[-1.2401, -3.8237, 0.2739]])
>>> sentences = [
... "This is the first test sentence",
... "This is the second test sentence",
... "A quick brown fox jumped over the lazy dog"
... ]
>>> batch_embeddings = word_emb.batch_embeddings(sentences) # doctest: +SKIP
>>> batch_embeddings.shape # doctest: +SKIP
torch.Size([3, 9, 768])
>>> batch_embeddings[:, :2, :3] # doctest: +SKIP
tensor([[[-5.0935, -1.2838, 0.7868],
[-4.6889, -2.1488, 2.1380]],
[[-4.4993, -2.0178, 0.9369],
[-4.1760, -2.4141, 1.9474]],
[[-1.0065, 1.4227, -2.6671],
[-0.3408, -0.6238, 0.1780]]])
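    >>> # The model and tokenizer may also be specified by name (a sketch
    >>> # only; downloads pretrained weights when actually executed):
    >>> word_emb = TransformerWordEmbeddings("bert-base-uncased") # doctest: +SKIP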
"""
MSG_WORD = "'word' should be either a word or the index of a word"
DEFAULT_LAYERS = 4
def __init__(self, model, tokenizer=None, layers=None, device=None):
super().__init__()
if not layers:
layers = self.DEFAULT_LAYERS
layers = _last_n_layers(layers) if isinstance(layers, int) else layers
self.layers = list(layers)
if isinstance(model, str):
if tokenizer is None:
tokenizer = model
model = _get_model(model)
if isinstance(tokenizer, str):
tokenizer = _get_tokenizer(tokenizer)
        elif tokenizer is None:
            raise ValueError(
                "a tokenizer must be provided when the model is given "
                "as an instance rather than by name"
            )
self.model = model
self.tokenizer = tokenizer
if device is not None:
self.device = device
self.model = self.model.to(device)
else:
self.device = self.model.device
def forward(self, sentence, word=None):
"""Retrieves a word embedding for the specified word within
a given sentence, if a word is provided, or all word embeddings
if only a sentence is given
Arguments
---------
sentence: str
a sentence
word: str|int
a word or a word's index within the sentence. If a word
is given, and it is encountered multiple times in a
sentence, the first occurrence is used
Returns
-------
emb: torch.Tensor
the word embedding
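
        Example
        -------
        NOTE: a sketch only (not executed); it reuses the ``word_emb``
        instance from the class-level example.
        >>> emb = word_emb("THIS IS A TEST SENTENCE", word="TEST") # doctest: +SKIP
        >>> all_embs = word_emb("THIS IS A TEST SENTENCE") # doctest: +SKIP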
"""
        return (
            self.embedding(sentence, word)
            if word is not None
            else self.embeddings(sentence)
        )
def embedding(self, sentence, word):
"""Retrieves a word embedding for the specified word within
a given sentence
Arguments
---------
sentence: str
a sentence
word: str|int
a word or a word's index within the sentence. If a word
is given, and it is encountered multiple times in a
sentence, the first occurrence is used
Returns
-------
emb: torch.Tensor
the word embedding
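
        Example
        -------
        NOTE: a sketch only (not executed); ``word_emb`` is the instance
        from the class-level example. The word may be given by value or
        by its index within the sentence.
        >>> emb = word_emb.embedding("THIS IS A TEST SENTENCE", word="TEST") # doctest: +SKIP
        >>> emb = word_emb.embedding("THIS IS A TEST SENTENCE", word=3) # doctest: +SKIP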
"""
encoded = self.tokenizer.encode_plus(sentence, return_tensors="pt")
with torch.no_grad():
output = self.model(**self._to_device(encoded))
if isinstance(word, str):
idx = self._get_word_idx(sentence, word)
elif isinstance(word, int):
idx = word
else:
raise ValueError(self.MSG_WORD)
states = torch.stack(output.hidden_states)
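        # average over the sub-word tokens that make up the word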
word_embedding = self._get_word_vector(encoded, states, idx).mean(dim=0)
return word_embedding
def embeddings(self, sentence):
"""
Returns the model embeddings for all words
in a sentence
Arguments
---------
sentence: str
a sentence
Returns
-------
emb: torch.Tensor
a tensor of all word embeddings
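
        Example
        -------
        NOTE: a sketch only (not executed); ``word_emb`` is the instance
        from the class-level example. One row is returned per token,
        with special tokens excluded.
        >>> embs = word_emb.embeddings("A quick brown fox") # doctest: +SKIP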
"""
encoded = self.tokenizer.encode_plus(sentence, return_tensors="pt")
with torch.no_grad():
output = self.model(**self._to_device(encoded))
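        # keep only the token positions that belong to actual words
        # (drops special tokens such as [CLS] and [SEP])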
token_ids_word = torch.tensor(
[
idx
for idx, word_id in enumerate(encoded.word_ids())
if word_id is not None
],
device=self.device,
)
states = torch.stack(output.hidden_states)
return self._get_hidden_states(states, token_ids_word)
def batch_embeddings(self, sentences):
"""Returns embeddings for a collection of sentences
Arguments
---------
sentences: List[str]
a list of strings corresponding to a batch of
sentences
Returns
-------
emb: torch.Tensor
a (B x W x E) tensor
            B - the batch dimension (samples)
W - the word dimension
E - the embedding dimension
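
        Example
        -------
        NOTE: a sketch only (not executed); ``word_emb`` is the instance
        from the class-level example. Shorter sentences are padded to the
        longest one in the batch.
        >>> batch = word_emb.batch_embeddings(
        ...     ["This is a test", "A quick brown fox"]
        ... ) # doctest: +SKIP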
"""
encoded = self.tokenizer.batch_encode_plus(
sentences, padding=True, return_tensors="pt"
)
with torch.no_grad():
output = self.model(**self._to_device(encoded))
states = torch.stack(output.hidden_states)
return self._get_hidden_states(states)
def _to_device(self, encoded):
return {
key: self._tensor_to_device(value) for key, value in encoded.items()
}
def _tensor_to_device(self, value):
return (
value.to(self.device) if isinstance(value, torch.Tensor) else value
)
def _get_word_idx(self, sent, word):
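        # naive whitespace tokenization; returns the index of the first
        # occurrence of the word in the sentence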
return sent.split(" ").index(word)
def _get_hidden_states(self, states, token_ids_word=None):
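        # sum the hidden states over the selected layers; if no token index
        # tensor is given, strip the special-token positions at both ends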
output = states[self.layers].sum(0).squeeze()
if token_ids_word is not None:
output = output[token_ids_word]
else:
output = output[:, 1:-1, :]
return output
def _get_word_vector(self, encoded, states, idx):
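        # gather the positions of the sub-word tokens that belong to the
        # word at index `idx`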
token_ids_word = torch.from_numpy(
np.where(np.array(encoded.word_ids()) == idx)[0]
).to(self.device)
return self._get_hidden_states(states, token_ids_word)
def to(self, device):
"""Transfers the model to the specified PyTorch device"""
self.device = device
self.model = self.model.to(device)
return self
class MissingTransformersError(Exception):
"""Thrown when HuggingFace Transformers is not installed"""
MESSAGE = "This module requires HuggingFace Transformers"
def __init__(self):
super().__init__(self.MESSAGE)
def _get_model(identifier):
"""Tries to retrieve a pretrained model from Huggingface"""
try:
from transformers import AutoModel # noqa
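        # hidden states from all layers are requested so that embeddings can
        # be built from any selection of layers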
return AutoModel.from_pretrained(identifier, output_hidden_states=True)
except ImportError:
raise MissingTransformersError()
def _get_tokenizer(identifier):
"""Tries to retrieve a pretrained tokenizer from HuggingFace"""
try:
from transformers import AutoTokenizer # noqa
return AutoTokenizer.from_pretrained(identifier)
except ImportError:
raise MissingTransformersError()