"""
Utilities for word embeddings
Authors
* Artem Ploujnikov 2021
"""
import torch
def expand_to_chars(emb, seq, seq_len, word_separator):
"""Expands word embeddings to a sequence of character
embeddings, assigning each character the word embedding
of the word to which it belongs
Arguments
---------
emb: torch.Tensor
a tensor of word embeddings
seq: torch.Tensor
a tensor of character embeddings
seq_len: torch.Tensor
a tensor of character embedding lengths
word_separator: torch.Tensor
the word separator being used
Returns
-------
char_word_emb: torch.Tensor
a combined character + word embedding tensor
Example
-------
>>> import torch
>>> emb = torch.tensor(
... [[[1., 2., 3.],
... [3., 1., 2.],
... [0., 0., 0.]],
... [[1., 3., 2.],
... [3., 2., 1.],
... [2., 3., 1.]]]
... )
>>> seq = torch.tensor(
... [[1, 2, 0, 2, 1, 0],
... [1, 0, 1, 2, 0, 2]]
... )
>>> seq_len = torch.tensor([4, 5])
>>> word_separator = 0
>>> expand_to_chars(emb, seq, seq_len, word_separator)
tensor([[[1., 2., 3.],
[1., 2., 3.],
[0., 0., 0.],
[3., 1., 2.],
[3., 1., 2.],
[0., 0., 0.]],
<BLANKLINE>
[[1., 3., 2.],
[0., 0., 0.],
[3., 2., 1.],
[3., 2., 1.],
[0., 0., 0.],
[2., 3., 1.]]])
"""
    # Separator tokens mark the boundaries between words
    word_boundaries = seq == word_separator
    # The running count of separators gives, for each character
    # position, the index of the word it belongs to
    words = word_boundaries.cumsum(dim=-1)
    # TODO: Find a way to vectorize over the batch axis
    char_word_emb = torch.zeros(
        emb.size(0), seq.size(-1), emb.size(-1),
        dtype=emb.dtype, device=emb.device,
    )
    # Convert relative lengths to absolute positions
    seq_len_idx = (seq_len * seq.size(-1)).int()
    for idx, (item, item_length) in enumerate(zip(words, seq_len_idx)):
        # Look up the word embedding for each character position
        char_word_emb[idx] = emb[idx, item]
        # Zero out padding beyond the end of the sequence
        char_word_emb[idx, item_length:, :] = 0
        # Zero out the separator positions themselves
        char_word_emb[idx, word_boundaries[idx], :] = 0
    return char_word_emb
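

# The TODO above asks for a way to vectorize over the batch axis. What
# follows is a minimal sketch of one possible approach using torch.gather
# and masked_fill. `expand_to_chars_vectorized` is a hypothetical helper,
# not part of the library; it assumes the same input conventions as
# expand_to_chars and, like the loop above, that the word indices produced
# by the cumulative sum stay within the word dimension of `emb`.
def expand_to_chars_vectorized(emb, seq, seq_len, word_separator):
    """Batch-vectorized sketch of expand_to_chars; with the doctest
    inputs above it produces the same output tensor."""
    word_boundaries = seq == word_separator
    words = word_boundaries.cumsum(dim=-1)
    # Gather every character's word embedding in a single call:
    # char_word_emb[b, c, :] = emb[b, words[b, c], :]
    index = words.unsqueeze(-1).expand(-1, -1, emb.size(-1))
    char_word_emb = emb.gather(1, index)
    # Build a mask covering separator positions and padding
    positions = torch.arange(seq.size(-1), device=seq.device)
    seq_len_idx = (seq_len * seq.size(-1)).int()
    padding = positions.unsqueeze(0) >= seq_len_idx.unsqueeze(-1)
    mask = (word_boundaries | padding).unsqueeze(-1)
    return char_word_emb.masked_fill(mask, 0.0)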