# pytorch/torch/nn/modules/sparse.py

import torch
from torch.autograd import Variable
from .module import Module


class Embedding(Module):
    """A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using
    indices. The input to the module is a list of indices, and the output is the
    corresponding word embeddings.

    Args:
        num_embeddings: size of the dictionary of embeddings
        embedding_dim: the size of each embedding vector
        padding_idx: If given, pads the output with zeros whenever it encounters
            this index. Default: None
        max_norm: If given, renormalizes the embeddings so that their norm is
            always less than this value. Default: None
        norm_type: The p of the p-norm to compute for the max_norm option
        scale_grad_by_freq: If given, scales gradients by the frequency of the
            words in the dictionary.

    Input Shape: [ *, * ] : a 2D mini-batch LongTensor of m x n indices to
        extract from the embedding dictionary
    Output Shape: [ *, *, * ] : m x n x embedding_dim

    Examples:
        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = Variable(torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]))
        >>> print(embedding(input))
        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = Variable(torch.LongTensor([[0, 2, 0, 5]]))
        >>> print(embedding(input))
"""

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False):
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        super(Embedding, self).__init__(
            weight=torch.Tensor(num_embeddings, embedding_dim)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # initialize weights from N(0, 1); zero out the padding row if one is set
        self.weight.data.normal_(0, 1)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def forward(self, input):
        padding_idx = self.padding_idx
        if padding_idx is None:
            # the backend function uses -1 to signal that no padding index is set
            padding_idx = -1
        return self._backend.Embedding(padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq)(input, self.weight)


# TODO: SparseLinear