pytorch/torch/nn/modules/distance.py
Edward Yang eace053398 Move all torch.nn.modules type annotations inline (#38211)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38211

Just because the annotations are inline doesn't mean the files type
check; most of the newly annotated files have type errors and I
added exclusions for them in mypy.ini.  The payoff of moving
all of these modules inline is I can delete the relevant code
generation logic for the pyi files (which added ignore annotations
that weren't actually relevant anymore).

For the most part the translation was completely mechanical, but there
were two hairy issues.  First, I needed to work around a Python 3.6 and
earlier bug where Generic has a nontrivial metaclass.  This fix is in
torch/jit/__init__.py.  Second, in module.py, we need to apply the same
fix for avoiding contravariance checks that the pyi file used to have;
this is done by declaring forward as a variable (rather than a
function), which appears to be sufficient to get mypy to not
contravariantly check input arguments.
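The forward-as-variable trick can be illustrated with a minimal sketch. It assumes a simplified stand-in for `Module`. The names `BaseModule`, `AddOne`, and `_forward_unimplemented` are illustrative and are not copied from module.py. `BaseModule.forward` is declared as an attribute of type `Callable[..., Any]`. As a result, mypy compares the subclass's narrower `forward` against that permissive callable type rather than against a concrete parameter list, so it should not report an argument-contravariance error. At runtime, the attribute still behaves like an ordinary method.

```python
from typing import Any, Callable


def _forward_unimplemented(self: Any, *args: Any) -> None:
    # Hypothetical default for ``forward``: calling it means the subclass
    # never defined its own forward().
    raise NotImplementedError


class BaseModule:
    # Declared as a variable of type Callable[..., Any] instead of with ``def``,
    # so subclasses can narrow the signature without mypy checking their
    # parameters contravariantly against the base declaration.
    forward: Callable[..., Any] = _forward_unimplemented

    def __call__(self, *args: Any) -> Any:
        # A plain function stored on the class binds like a method,
        # so self.forward(...) passes ``self`` as the first argument.
        return self.forward(*args)


class AddOne(BaseModule):
    # Narrower signature than the base declaration (one int argument).
    def forward(self, x: int) -> int:
        return x + 1
```

Suppose `BaseModule` had instead defined `def forward(self, *args: Any) -> Any`. Mypy would then be expected to flag `AddOne.forward` as incompatible with its supertype, because the override accepts fewer arguments.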

Because we aren't actually typechecking these modules in most
cases, it is inevitable that some of these type annotations are wrong.
I slavishly copied the old annotations from the pyi files unless there
was an obvious correction I could make.  These annotations will probably
need fixing up later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497397

Pulled By: ezyang

fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
2020-06-11 15:59:57 -07:00


from .module import Module
from .. import functional as F

from torch import Tensor


class PairwiseDistance(Module):
    r"""
    Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2`
    as the p-norm of their difference, with :attr:`eps` added to each element:

    .. math ::
        \text{dist}(v_1, v_2) = \Vert v_1 - v_2 + \epsilon \Vert _p, \qquad
        \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.

    Args:
        p (float): the norm degree. Default: 2
        eps (float, optional): Small value added to the difference to avoid division by zero.
            Default: 1e-6
        keepdim (bool, optional): Determines whether or not to keep the vector dimension.
            Default: False
    Shape:
        - Input1: :math:`(N, D)` where `D = vector dimension`
        - Input2: :math:`(N, D)`, same shape as the Input1
        - Output: :math:`(N)`. If :attr:`keepdim` is ``True``, then :math:`(N, 1)`.
    Examples::
        >>> pdist = nn.PairwiseDistance(p=2)
        >>> input1 = torch.randn(100, 128)
        >>> input2 = torch.randn(100, 128)
        >>> output = pdist(input1, input2)
    """
    __constants__ = ['norm', 'eps', 'keepdim']
    norm: float
    eps: float
    keepdim: bool

    def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False) -> None:
        super(PairwiseDistance, self).__init__()
        self.norm = p
        self.eps = eps
        self.keepdim = keepdim

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)


class CosineSimilarity(Module):
    r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.

    .. math ::
        \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.

    Args:
        dim (int, optional): Dimension where cosine similarity is computed. Default: 1
        eps (float, optional): Small value to avoid division by zero.
            Default: 1e-8
    Shape:
        - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`
        - Input2: :math:`(\ast_1, D, \ast_2)`, same shape as the Input1
        - Output: :math:`(\ast_1, \ast_2)`
    Examples::
        >>> input1 = torch.randn(100, 128)
        >>> input2 = torch.randn(100, 128)
        >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        >>> output = cos(input1, input2)
    """
    __constants__ = ['dim', 'eps']
    dim: int
    eps: float

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        super(CosineSimilarity, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        return F.cosine_similarity(x1, x2, self.dim, self.eps)
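The sketch below cross-checks the two docstring formulas without depending on torch. It assumes 2-D inputs of shape (N, D), reduces over the last dimension (dim=1), and handles only finite p > 0 with keepdim=False. The helpers `pairwise_distance_ref` and `cosine_similarity_ref` are hypothetical and are not part of torch. The `+ eps` inside the norm mirrors how `F.pairwise_distance` applies `eps` to the difference `x1 - x2`. The cosine helper follows the `max(||x1|| * ||x2||, eps)` denominator exactly as written in the `CosineSimilarity` docstring.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def pairwise_distance_ref(x1: Matrix, x2: Matrix,
                          p: float = 2.0, eps: float = 1e-6) -> List[float]:
    # Hypothetical reference: row-wise p-norm of (x1 - x2 + eps), shape (N,).
    out = []
    for a, b in zip(x1, x2):
        diffs = [abs(ai - bi + eps) for ai, bi in zip(a, b)]
        out.append(sum(d ** p for d in diffs) ** (1.0 / p))
    return out


def cosine_similarity_ref(x1: Matrix, x2: Matrix,
                          eps: float = 1e-8) -> List[float]:
    # Hypothetical reference: row-wise cosine similarity with the product of
    # the L2 norms clamped from below by eps, as in the docstring formula.
    out = []
    for a, b in zip(x1, x2):
        dot = sum(ai * bi for ai, bi in zip(a, b))
        norm_a = math.sqrt(sum(ai * ai for ai in a))
        norm_b = math.sqrt(sum(bi * bi for bi in b))
        out.append(dot / max(norm_a * norm_b, eps))
    return out
```

Setting `eps=0.0` recovers the plain Euclidean distance, e.g. `pairwise_distance_ref([[0.0, 0.0]], [[3.0, 4.0]], eps=0.0)` gives `[5.0]`. With the default eps, two identical rows give a distance of roughly 1.41e-6 rather than exactly 0, which is the effect of adding eps to the difference.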