import torch
from torch.autograd import Variable

from .module import Module


class Embedding(Module):
    """A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings: size of the dictionary of embeddings
        embedding_dim: the size of each embedding vector
        padding_idx: If given, pads the output with zeros whenever it encounters this index. Default: None
        max_norm: If given, renormalizes the embeddings to always have a norm less than this value. Default: None
        norm_type: The p of the p-norm to compute for the max_norm option
        scale_grad_by_freq: If given, this will scale gradients by the inverse of the frequency of the words in the mini-batch.

    Input Shape: [ *, * ] : a 2D mini-batch LongTensor of m x n indices to extract from the embedding dictionary
    Output Shape: [ *, *, * ] : output of shape m x n x embedding_dim

    Examples:
        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = Variable(torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]))
        >>> print(embedding(input))
        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = Variable(torch.LongTensor([[0, 2, 0, 5]]))
        >>> print(embedding(input))
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False):
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        super(Embedding, self).__init__(
            weight=torch.Tensor(num_embeddings, embedding_dim)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize the embedding table from N(0, 1) and zero out the padding row.
        self.weight.data.normal_(0, 1)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def forward(self, input):
        # The backend function uses -1 to signal "no padding index".
        padding_idx = self.padding_idx
        if padding_idx is None:
            padding_idx = -1
        return self._backend.Embedding(padding_idx, self.max_norm,
                                       self.norm_type,
                                       self.scale_grad_by_freq)(input, self.weight)


# TODO: SparseLinear
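

# A minimal usage sketch, not part of the original module: it assumes this file is
# imported as part of the package so that ``_backend`` is wired up as documented
# above, and the index values below are illustrative only.
if __name__ == '__main__':
    # Dictionary of 10 embeddings of size 3; any looked-up row whose L2 norm
    # exceeds 1.0 is renormalized because of max_norm/norm_type.
    embedding = Embedding(10, 3, max_norm=1.0, norm_type=2)
    input = Variable(torch.LongTensor([[1, 2, 4, 5]]))
    print(embedding(input))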