Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38211

Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors, and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is that I can delete the relevant code generation logic for the pyi files (which was adding ignore annotations that weren't actually relevant anymore).

For the most part the translation was completely mechanical, but there were two hairy issues. First, I needed to work around a Python 3.6 and earlier bug where Generic has a nontrivial metaclass. This fix is in torch/jit/__init__.py. Second, in module.py, we need to apply the same fix for avoiding contravariance checks that the pyi file used to have; this is done by declaring forward as a variable (rather than a function), which appears to be sufficient to get mypy to not contravariantly check input arguments.

Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497397

Pulled By: ezyang

fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
import warnings

from typing import Tuple, Optional

import torch
from torch import Tensor
from .linear import _LinearWithBias
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F

class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['threshold', 'value', 'inplace']

    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super(Threshold, self).__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    def forward(self, input: Tensor) -> Tensor:
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'threshold={}, value={}{}'.format(
            self.threshold, self.value, inplace_str
        )

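# Illustrative usage sketch (not part of the upstream module): a concrete
# Threshold(0.1, 20) call matching the semantics documented above. The helper
# name below is hypothetical.
def _example_threshold_usage():
    m = Threshold(0.1, 20)
    x = torch.tensor([-1.0, 0.05, 0.5])
    # Elements greater than 0.1 pass through; everything else becomes 20.
    return m(x)  # tensor([20.0000, 20.0000, 0.5000])
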
class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input), m(-input)))
    """
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(ReLU, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str

class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:

    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from the uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """
    __constants__ = ['lower', 'upper', 'inplace']

    lower: float
    upper: float
    inplace: bool

    def __init__(
        self,
        lower: float = 1. / 8,
        upper: float = 1. / 3,
        inplace: bool = False
    ):
        super(RReLU, self).__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)

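# Illustrative sketch (not part of the upstream module): RReLU behaves
# differently in train vs. eval mode because forward() passes self.training
# to F.rrelu. During training the negative slope is sampled from
# U(lower, upper); in eval mode a fixed slope is used instead. The helper
# name below is hypothetical.
def _example_rrelu_train_vs_eval():
    m = RReLU(0.1, 0.3)
    x = torch.tensor([-1.0, 2.0])
    m.train()
    y_train = m(x)  # negative entry scaled by a randomly sampled slope
    m.eval()
    y_eval = m(x)   # negative entry scaled by a deterministic slope
    return y_train, y_eval
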
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            1 & \text{ if } x > 1 \\
            -1 & \text{ if } x < -1 \\
            x & \text{ otherwise } \\
        \end{cases}

    The range of the linear region :math:`[-1, 1]` can be adjusted using
    :attr:`min_val` and :attr:`max_val`.

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['min_val', 'max_val', 'inplace']

    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None
    ) -> None:
        super(Hardtanh, self).__init__()
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and renamed to max_val")
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        assert self.max_val > self.min_val

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'min_val={}, max_val={}{}'.format(
            self.min_val, self.max_val, inplace_str
        )

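# Illustrative sketch (not part of the upstream module): Hardtanh with a
# widened linear region, clamping values to [min_val, max_val]. The helper
# name below is hypothetical.
def _example_hardtanh_usage():
    m = Hardtanh(-2., 2.)
    x = torch.tensor([-3.0, 0.5, 4.0])
    return m(x)  # tensor([-2.0000, 0.5000, 2.0000])
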
class ReLU6(Hardtanh):
    r"""Applies the element-wise function:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super(ReLU6, self).__init__(0., 6., inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str

class Sigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.sigmoid(input)

class Hardsigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input)

class Tanh(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.tanh(input)

class Hardswish(Module):
    r"""Applies the hardswish function, element-wise, as described in the paper:

    `Searching for MobileNetV3`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) / 6 & \text{otherwise}
        \end{cases}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Searching for MobileNetV3`:
        https://arxiv.org/abs/1905.02244
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input)

class ELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super(ELU, self).__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'alpha={}{}'.format(self.alpha, inplace_str)

class CELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """
    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super(CELU, self).__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'alpha={}{}'.format(self.alpha, inplace_str)

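# Illustrative sketch (not part of the upstream module): with alpha == 1 the
# CELU formula above reduces to the ELU formula, so the two modules agree
# element-wise. The helper name below is hypothetical.
def _example_celu_matches_elu_for_unit_alpha():
    x = torch.randn(5)
    assert torch.allclose(CELU(alpha=1.)(x), ELU(alpha=1.)(x))
    return x
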
class SELU(Module):
    r"""Applied element-wise, as:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super(SELU, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str

class GLU(Module):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return 'dim={}'.format(self.dim)

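# Illustrative sketch (not part of the upstream module): GLU halves the chosen
# dimension; the first half is gated by the sigmoid of the second half. The
# helper name below is hypothetical.
def _example_glu_shapes():
    m = GLU(dim=-1)
    x = torch.randn(4, 6)
    a, b = x.chunk(2, dim=-1)
    y = m(x)
    assert y.shape == (4, 3)
    assert torch.allclose(y, a * torch.sigmoid(b))
    return y
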
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for the Gaussian Distribution.

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input)

class Hardshrink(Module):
    r"""Applies the hard shrinkage function element-wise:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super(Hardshrink, self).__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return '{}'.format(self.lambd)

class LeakyReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super(LeakyReLU, self).__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)

class LogSigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.logsigmoid(input)

class Softplus(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['beta', 'threshold']
    beta: int
    threshold: int

    def __init__(self, beta: int = 1, threshold: int = 20) -> None:
        super(Softplus, self).__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)

class Softshrink(Module):
    r"""Applies the soft shrinkage function elementwise:

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super(Softshrink, self).__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __annotations__ = {
        'bias_k': torch._jit_internal.Optional[torch.Tensor],
        'bias_v': torch._jit_internal.Optional[torch.Tensor],
    }

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero
              positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)

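# Illustrative sketch (not part of the upstream module): demonstrates the
# (L, N, E) / (S, N, E) input layout and the output shapes documented in
# MultiheadAttention.forward above. All sizes and the helper name below are
# hypothetical.
def _example_multihead_attention_shapes():
    embed_dim, num_heads = 16, 4
    mha = MultiheadAttention(embed_dim, num_heads)
    query = torch.randn(5, 2, embed_dim)        # (L, N, E)
    key = value = torch.randn(7, 2, embed_dim)  # (S, N, E)
    attn_output, attn_output_weights = mha(query, key, value)
    assert attn_output.shape == (5, 2, embed_dim)   # (L, N, E)
    assert attn_output_weights.shape == (2, 5, 7)   # (N, L, S)
    return attn_output, attn_output_weights
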
class PReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels of the input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: ../scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['num_parameters']
    num_parameters: int

    def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
        self.num_parameters = num_parameters
        super(PReLU, self).__init__()
        self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return 'num_parameters={}'.format(self.num_parameters)

class Softsign(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.softsign(input)

class Tanhshrink(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.tanhshrink(input)

class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor,
    rescaling it so that the elements of the n-dimensional output Tensor
    lie in the range `[0, 1]` and sum to 1.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Arguments:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(Softmin, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return 'dim={dim}'.format(dim=self.dim)

class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor,
    rescaling it so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecified
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Arguments:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(Softmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

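# Illustrative sketch (not part of the upstream module): Softmin over a
# dimension equals Softmax applied to the negated input, which follows
# directly from the two formulas above. The helper name below is hypothetical.
def _example_softmin_is_softmax_of_negated_input():
    x = torch.randn(2, 3)
    assert torch.allclose(Softmin(dim=1)(x), Softmax(dim=1)(-x))
    return x
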
class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
        return F.softmax(input, 1, _stacklevel=5)

class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Arguments:
        dim (int): A dimension along which LogSoftmax will be computed.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)

    Examples::

        >>> m = nn.LogSoftmax()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(LogSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return 'dim={dim}'.format(dim=self.dim)

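# Illustrative sketch (not part of the upstream module): the Softmax docstring
# above recommends LogSoftmax when training with a negative log-likelihood
# objective; F.nll_loss consumes the log-probabilities directly. Sizes and the
# helper name below are hypothetical.
def _example_log_softmax_with_nll_loss():
    logits = torch.randn(3, 5)          # (batch, classes)
    target = torch.tensor([1, 0, 4])    # class indices
    log_probs = LogSoftmax(dim=1)(logits)
    return F.nll_loss(log_probs, target)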