mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38211 Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is that I can delete the relevant code generation logic for the pyi files (which added ignore annotations that weren't actually relevant anymore). For the most part the translation was completely mechanical, but there were two hairy issues. First, I needed to work around a bug in Python 3.6 and earlier where Generic has a nontrivial metaclass. This fix is in torch/jit/__init__.py. Second, in module.py, we need to apply the same fix for avoiding contravariance checks that the pyi file used to have; this is done by declaring forward as a variable (rather than a function), which appears to be sufficient to get mypy to not contravariantly check input arguments. Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D21497397 Pulled By: ezyang fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
363 lines
18 KiB
Python
363 lines
18 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from .module import Module
|
|
from .. import functional as F
|
|
from .. import init
|
|
|
|
|
|
class Embedding(Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                         (initialized to zeros) whenever it encounters the index.
                                         A negative value counts from the end of the dictionary and is
                                         stored as ``num_embeddings + padding_idx``.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)

    .. note::
        With :attr:`padding_idx` set, the embedding vector at
        :attr:`padding_idx` is initialized to all zeros. This zeroing happens only
        when the weight is randomly initialized; an explicitly supplied weight
        (e.g. via :meth:`from_pretrained`) is used as-is. Note also that this
        vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output. The gradient for this vector from :class:`~torch.nn.Embedding`
        is always zero.

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0,2,0,5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])
    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
                     'norm_type', 'scale_grad_by_freq', 'sparse']

    num_embeddings: int
    embedding_dim: int
    # None when no padding index was requested; otherwise a non-negative row index
    # (negative inputs are normalized in __init__).
    padding_idx: Optional[int]
    # None means no max-norm renormalization is requested.
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    sparse: bool

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                # Store negative indices in their equivalent non-negative form so the
                # rest of the module (and F.embedding) only ever sees a real row index.
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            # Random init (and padding-row zeroing) only happens on this path.
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            # The supplied tensor is wrapped without an explicit copy and without
            # resetting the padding row.
            self.weight = Parameter(_weight)
        self.sparse = sparse

    def reset_parameters(self) -> None:
        """Re-draw the weights from N(0, 1) and zero the padding row, if any."""
        init.normal_(self.weight)
        if self.padding_idx is not None:
            # In-place write to a parameter; keep it out of the autograd graph.
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up the embedding rows for ``input`` via ``F.embedding``."""
        return F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

    def extra_repr(self) -> str:
        """Describe the configuration; optional settings appear only when non-default."""
        s = '{num_embeddings}, {embedding_dim}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        return s.format(**self.__dict__)

    @classmethod
    def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, padding_idx: Optional[int] = None,
                        max_norm: Optional[float] = None, norm_type: float = 2.,
                        scale_grad_by_freq: bool = False, sparse: bool = False) -> 'Embedding':
        r"""Creates Embedding instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the Embedding.
                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
            padding_idx (int, optional): See module initialization documentation. Note that the row at
                :attr:`padding_idx` is taken as-is from :attr:`embeddings` and is not reset to zeros.
            max_norm (float, optional): See module initialization documentation.
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
            sparse (bool, optional): See module initialization documentation.

        Returns:
            Embedding: a module whose ``weight`` wraps :attr:`embeddings`.

        Raises:
            AssertionError: if :attr:`embeddings` is not 2-dimensional.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embedding = nn.Embedding.from_pretrained(weight)
            >>> # Get embeddings for index 1
            >>> input = torch.LongTensor([1])
            >>> embedding(input)
            tensor([[ 4.0000,  5.1000,  6.3000]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        rows, cols = embeddings.shape
        embedding = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse)
        embedding.weight.requires_grad = not freeze
        return embedding
|
|
|
|
|
|
class EmbeddingBag(Module):
    r"""Computes sums, means or maxes of 'bags' of embeddings, without instantiating the
    intermediate embeddings.

    For bags of constant length and no :attr:`per_sample_weights`, this class

        * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
        * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
        * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=0)``.

    However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
    operations.

    EmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
                                                Note: this option is not supported when ``mode="max"``.
        mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
                                 into consideration. ``"mean"`` computes the average of the values
                                 in the bag, ``"max"`` computes the max value over each bag.
                                 Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
                                 Notes for more details regarding sparse gradients. Note: this option is not
                                 supported when ``mode="max"``.
        include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element
                                      is equivalent to the size of `indices`. This matches the CSR format. Note:
                                      this option is currently only supported when ``mode="sum"``.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                         initialized from :math:`\mathcal{N}(0, 1)`.

    Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
        :attr:`per_sample_weights` (Tensor, optional)

        - If :attr:`input` is 2D of shape `(B, N)`,

          it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
          this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
          :attr:`offsets` is ignored and required to be ``None`` in this case.

        - If :attr:`input` is 1D of shape `(N)`,

          it will be treated as a concatenation of multiple bags (sequences).
          :attr:`offsets` is required to be a 1D tensor containing the
          starting index positions of each bag in :attr:`input`. Therefore,
          for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
          having ``B`` bags. Empty bags (i.e., having 0-length) will have
          returned vectors filled by zeros.

        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.


    Output shape: `(B, embedding_dim)`

    Examples::

        >>> # an EmbeddingBag module containing 10 tensors of size 3
        >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
        >>> offsets = torch.LongTensor([0,4])
        >>> embedding_sum(input, offsets)
        tensor([[-0.8861, -5.4350, -0.0523],
                [ 1.1306, -2.5798, -1.0044]])
    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
                     'scale_grad_by_freq', 'mode', 'sparse', 'include_last_offset']

    num_embeddings: int
    embedding_dim: int
    # None means no max-norm renormalization is requested.
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    mode: str
    sparse: bool
    include_last_offset: bool

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False) -> None:
        super(EmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            # The supplied tensor is wrapped without an explicit copy.
            self.weight = Parameter(_weight)
        # mode is not validated here; F.embedding_bag is relied on to reject bad values.
        self.mode = mode
        self.sparse = sparse
        self.include_last_offset = include_last_offset

    def reset_parameters(self) -> None:
        """Re-draw the weights from N(0, 1)."""
        init.normal_(self.weight)

    def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
        """Reduce each bag of embeddings in ``input`` via ``F.embedding_bag``."""
        return F.embedding_bag(input, self.weight, offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset)

    def extra_repr(self) -> str:
        """Describe the configuration; optional settings appear only when non-default.

        ``mode`` is always shown. ``sparse`` and ``include_last_offset`` are reported
        when enabled, matching how :class:`Embedding` reports ``sparse``.
        """
        s = '{num_embeddings}, {embedding_dim}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        s += ', mode={mode}'
        if self.sparse is not False:
            s += ', sparse=True'
        if self.include_last_offset is not False:
            s += ', include_last_offset=True'
        return s.format(**self.__dict__)

    @classmethod
    def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Optional[float] = None,
                        norm_type: float = 2., scale_grad_by_freq: bool = False,
                        mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False) -> 'EmbeddingBag':
        r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.

        Args:
            embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
                First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
                Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
            max_norm (float, optional): See module initialization documentation. Default: ``None``
            norm_type (float, optional): See module initialization documentation. Default ``2``.
            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
            mode (string, optional): See module initialization documentation. Default: ``"mean"``
            sparse (bool, optional): See module initialization documentation. Default: ``False``.
            include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.

        Returns:
            EmbeddingBag: a module whose ``weight`` wraps :attr:`embeddings`.

        Raises:
            AssertionError: if :attr:`embeddings` is not 2-dimensional.

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
            >>> # Get the mean embedding of a single bag containing indices 1 and 0
            >>> input = torch.LongTensor([[1, 0]])
            >>> embeddingbag(input)
            tensor([[ 2.5000,  3.7000,  4.6500]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        rows, cols = embeddings.shape
        embeddingbag = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse,
            include_last_offset=include_last_offset)
        embeddingbag.weight.requires_grad = not freeze
        return embeddingbag
|