mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18735 ghimport-source-id: d81bef54dafd7167d2451250d7be478d3c013920 Reviewed By: cpuhrsch Differential Revision: D14851415 Pulled By: zou3519 fbshipit-source-id: cea6039e760ad571b90f0a536e420498f34be325
340 lines
16 KiB
Python
340 lines
16 KiB
Python
import torch
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from .module import Module
|
|
from .. import functional as F
|
|
from .. import init
|
|
from torch._jit_internal import weak_module, weak_script_method
|
|
|
|
|
|
@weak_module
class Embedding(Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                         (initialized to zeros) whenever it encounters the index.
                                         A negative value is counted from the end of the dictionary,
                                         i.e. it is stored as ``num_embeddings + padding_idx``.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)

    .. note::
        With :attr:`padding_idx` set, the embedding vector at
        :attr:`padding_idx` is initialized to all zeros. However, note that this
        vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output. The gradient for this vector from :class:`~torch.nn.Embedding`
        is always zero.

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0,2,0,5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])
    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
                     'norm_type', 'scale_grad_by_freq', 'sparse', '_weight']

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                 sparse=False, _weight=None):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Validate padding_idx against the dictionary size and map a negative
        # index onto its non-negative equivalent (Python-style wraparound).
        if padding_idx is not None:
            if padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
            elif padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
        # Must be assigned before reset_parameters(), which reads it.
        self.padding_idx = padding_idx

        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if _weight is not None:
            # Caller-supplied weights are wrapped as-is (no re-initialization).
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)
        else:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        self.sparse = sparse

    def reset_parameters(self):
        """Re-draw ``weight`` from N(0, 1) and zero the padding row, if any."""
        init.normal_(self.weight)
        if self.padding_idx is None:
            return
        # In-place write on a leaf Parameter must not be recorded by autograd.
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

    @weak_script_method
    def forward(self, input):
        return F.embedding(input, self.weight, self.padding_idx,
                           self.max_norm, self.norm_type,
                           self.scale_grad_by_freq, self.sparse)

    def extra_repr(self):
        # The sizes are always shown; optional settings appear only when they
        # differ from their constructor defaults.
        pieces = ['{num_embeddings}, {embedding_dim}']
        if self.padding_idx is not None:
            pieces.append(', padding_idx={padding_idx}')
        if self.max_norm is not None:
            pieces.append(', max_norm={max_norm}')
        if self.norm_type != 2:
            pieces.append(', norm_type={norm_type}')
        if self.scale_grad_by_freq is not False:
            pieces.append(', scale_grad_by_freq={scale_grad_by_freq}')
        if self.sparse is not False:
            pieces.append(', sparse=True')
        return ''.join(pieces).format(**self.__dict__)

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
                        sparse=False):
        r"""Creates Embedding instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the Embedding.
                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
            padding_idx (int, optional): See module initialization documentation.
            max_norm (float, optional): See module initialization documentation.
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
            sparse (bool, optional): See module initialization documentation.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embedding = nn.Embedding.from_pretrained(weight)
            >>> # Get embeddings for index 1
            >>> input = torch.LongTensor([1])
            >>> embedding(input)
            tensor([[ 4.0000,  5.1000,  6.3000]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        num_rows, num_cols = embeddings.shape
        module = cls(num_embeddings=num_rows,
                     embedding_dim=num_cols,
                     _weight=embeddings,
                     padding_idx=padding_idx,
                     max_norm=max_norm,
                     norm_type=norm_type,
                     scale_grad_by_freq=scale_grad_by_freq,
                     sparse=sparse)
        module.weight.requires_grad = not freeze
        return module
|
|
|
|
|
|
@weak_module
class EmbeddingBag(Module):
    r"""Computes sums, means or maxes of 'bags' of embeddings, without instantiating the
    intermediate embeddings.

    For bags of constant length and no :attr:`per_sample_weights`, this class

        * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
        * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
        * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.

    However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
    operations.

    EmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
                                                Note: this option is not supported when ``mode="max"``.
        mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
                                 into consideration. ``"mean"`` computes the average of the values
                                 in the bag, ``"max"`` computes the max value over each bag.
                                 Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
                                 Notes for more details regarding sparse gradients. Note: this option is not
                                 supported when ``mode="max"``.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                         initialized from :math:`\mathcal{N}(0, 1)`.

    Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
        :attr:`per_sample_weights` (Tensor, optional)

        - If :attr:`input` is 2D of shape `(B, N)`,

          it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
          this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
          :attr:`offsets` is ignored and required to be ``None`` in this case.

        - If :attr:`input` is 1D of shape `(N)`,

          it will be treated as a concatenation of multiple bags (sequences).
          :attr:`offsets` is required to be a 1D tensor containing the
          starting index positions of each bag in :attr:`input`. Therefore,
          for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
          having ``B`` bags. Empty bags (i.e., having 0-length) will have
          returned vectors filled by zeros.

        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.


    Output shape: `(B, embedding_dim)`

    Examples::

        >>> # an EmbeddingBag module containing 10 tensors of size 3
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
        >>> offsets = torch.LongTensor([0,4])
        >>> embedding_sum(input, offsets)
        tensor([[-0.8861, -5.4350, -0.0523],
                [ 1.1306, -2.5798, -1.0044]])
    """
    # Attribute names exposed to TorchScript as constants. Each name must be
    # its own list entry (previously 'num_embeddings' and 'embedding_dim' were
    # fused into the single string 'num_embeddings, embedding_dim').
    __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
                     'scale_grad_by_freq', 'mode', 'sparse', '_weight']

    def __init__(self, num_embeddings, embedding_dim,
                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                 mode='mean', sparse=False, _weight=None):
        super(EmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            # Caller-supplied weights are wrapped as-is (no re-initialization).
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)
        self.mode = mode
        self.sparse = sparse

    def reset_parameters(self):
        """Re-draw ``weight`` in place from N(0, 1)."""
        init.normal_(self.weight)

    @weak_script_method
    def forward(self, input, offsets=None, per_sample_weights=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        return F.embedding_bag(input, self.weight, offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights)

    def extra_repr(self):
        # Sizes and mode are always shown; other settings only when they
        # differ from their constructor defaults.
        s = '{num_embeddings}, {embedding_dim}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        s += ', mode={mode}'
        return s.format(**self.__dict__)

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, max_norm=None,
                        norm_type=2., scale_grad_by_freq=False,
                        mode='mean', sparse=False):
        r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
                First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
            max_norm (float, optional): See module initialization documentation. Default: ``None``
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
            mode (string, optional): See module initialization documentation. Default: ``"mean"``
            sparse (bool, optional): See module initialization documentation. Default: ``False``.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
            >>> # Get the mean of the embeddings for indices 1 and 0 (one bag)
            >>> input = torch.LongTensor([[1, 0]])
            >>> embeddingbag(input)
            tensor([[ 2.5000,  3.7000,  4.6500]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        rows, cols = embeddings.shape
        embeddingbag = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse)
        embeddingbag.weight.requires_grad = not freeze
        return embeddingbag
|