"""Library implementing embedding.
Authors
* Abdelwahab Heba 2020
"""
import torch
import logging
import torch.nn as nn
logger = logging.getLogger(__name__)
class Embedding(nn.Module):
"""Computes an embedding x = wx.
Arguments
---------
num_embeddings : int
Size of the dictionary of embeddings.
embedding_dim : int
It is the dim of embedding (i.e, the dimensionality of the output).
consider_as_one_hot : bool
Create non-trainable one-hot vector.
blank_id : int
If consider_as_one_hot == True: consider the embedding as one_hot
and use blank_index as zero one_hot vector.
Example
-------
>>> from speechbrain.nnet.embedding import Embedding
>>> import torch
>>> emb = Embedding(
... num_embeddings=40,
... embedding_dim=39,
... consider_as_one_hot=True,
... blank_id=39
... )
>>> inputs = torch.Tensor([10,5,2,0,39]).long()
>>> output = emb(inputs)
>>> output.shape
torch.Size([5, 39])
>>> output
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.]])
>>> emb = Embedding(num_embeddings=5, embedding_dim=3, consider_as_one_hot=False)
>>> e = emb(torch.LongTensor([[0, 1, 2], [3, 4, 2]]))
>>> e.shape
torch.Size([2, 3, 3])
"""
    def __init__(
        self,
        num_embeddings,
        embedding_dim=128,
        consider_as_one_hot=False,
        blank_id=0,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.consider_as_one_hot = consider_as_one_hot
        if self.consider_as_one_hot:
            self.embedding_dim = self.num_embeddings - 1
        else:
            self.embedding_dim = embedding_dim
        self.blank_id = blank_id

        if self.consider_as_one_hot:
            # The blank symbol is represented by the all-zero vector, so only
            # embedding_dim = num_embeddings - 1 dimensions are needed.
            # padding_idx fixes the blank_id row of the weight matrix to zeros.
            self.Embedding = nn.Embedding(
                self.num_embeddings,
                self.embedding_dim,
                padding_idx=self.blank_id,
            )
            one_hot = torch.eye(self.embedding_dim)
            # Fill the non-blank rows with one-hot vectors, skipping over the
            # blank_id row (which stays all-zero).
            if self.blank_id + 1 != self.num_embeddings:
                self.Embedding.weight.data[self.blank_id + 1 :] = one_hot[
                    self.blank_id :
                ]
            if self.blank_id != 0:
                self.Embedding.weight.data[: self.blank_id] = one_hot[
                    : self.blank_id
                ]
            self.Embedding.weight.requires_grad = False
        else:
            self.Embedding = nn.Embedding(
                self.num_embeddings, self.embedding_dim
            )

    def forward(self, x):
        """Returns the embedding of the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to embed.
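
        Example
        -------
        A small shape check with illustrative sizes:

        >>> emb = Embedding(num_embeddings=10, embedding_dim=4)
        >>> emb(torch.tensor([[1, 2, 3]])).shape
        torch.Size([1, 3, 4])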
"""
# pytorch embedding layer only accept long dtype
return self.Embedding(x.long())
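

if __name__ == "__main__":
    # Minimal sanity-check sketch with illustrative sizes (num_embeddings=4,
    # blank_id=0 are assumptions, not defaults): in one-hot mode, index
    # i > blank_id maps to one-hot row i - 1 and blank_id itself maps to the
    # all-zero vector.
    emb = Embedding(num_embeddings=4, consider_as_one_hot=True, blank_id=0)
    out = emb(torch.arange(4))
    expected = torch.cat([torch.zeros(1, 3), torch.eye(3)], dim=0)
    assert torch.equal(out, expected)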