import torch
from torch.autograd import Variable

from .module import Module


class Embedding(Module):
"""A simple lookup table that stores embeddings of a fixed dictionary and size
|
|
This module is often used to store word embeddings and retrieve them using indices.
|
|
The input to the module is a list of indices, and the output is the corresponding
|
|
word embeddings.
|
|
Args:
|
|
num_embeddings: size of the dictionary of embeddings
|
|
embedding_dim: the size of each embedding vector
|
|
padding_idx: If given, pads the output with zeros whenever it encounters the index. Default: None
|
|
max_norm: If given, will renormalize the embeddings to always have a norm lesser than this Default: None
|
|
norm_type: The p of the p-norm to compute for the max_norm option
|
|
scale_grad_by_freq: if given, this will scale gradients by the frequency of the words in the dictionary.
|
|
Input Shape: [ *, * ] : Input is a 2D mini_batch LongTensor of m x n indices to extract from the Embedding dictionary
|
|
Output Shape:[ * , *, * ] : Output shape = m x n x embedding_dim
|
|
Examples:
|
|
>>> # an Embedding module containing 10 tensors of size 3
|
|
>>> embedding = nn.Embedding(10, 3)
|
|
>>> # a batch of 2 samples of 4 indices each
|
|
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
|
>>> print(embedding(input))
|
|
>>> # example with padding_idx
|
|
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
|
>>> input = torch.LongTensor([[0,2,0,5]])
|
|
>>> print(embedding(input))
|
|
"""

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False):
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        super(Embedding, self).__init__(
            weight=torch.Tensor(num_embeddings, embedding_dim)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weights from N(0, 1) and zero the padding row, if one was given.
        self.weight.data.normal_(0, 1)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def forward(self, input):
        # The backend op expects -1 as the sentinel for "no padding index".
        padding_idx = self.padding_idx
        if padding_idx is None:
            padding_idx = -1
        return self._backend.Embedding(padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq)(input, self.weight)


# TODO: SparseLinear
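

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module itself: it exercises the
    # behaviour documented in the class docstring through the public torch.nn
    # API, and assumes a torch build where plain tensors can be passed to
    # modules directly (no explicit Variable wrapping).
    import torch.nn as nn

    # A dictionary of 10 embeddings, each of size 3.
    embedding = nn.Embedding(10, 3)

    # A 2 x 4 LongTensor of indices yields a 2 x 4 x 3 output, matching the
    # Input/Output Shape description in the docstring.
    indices = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    print(embedding(indices).size())

    # padding_idx=0: positions that look up index 0 come back as zero vectors.
    padded = nn.Embedding(10, 3, padding_idx=0)
    print(padded(torch.LongTensor([[0, 2, 0, 5]])))

    # max_norm=1.0: returned embeddings are renormalized so that their 2-norm
    # (the default norm_type) never exceeds 1.0.
    clipped = nn.Embedding(10, 3, max_norm=1.0)
    print(clipped(indices).norm(p=2, dim=-1))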