import warnings
import torch
from . import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
from ..._jit_internal import weak_module, weak_script_method


@weak_module
class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['threshold', 'value', 'inplace']

    def __init__(self, threshold, value, inplace=False):
        super(Threshold, self).__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    @weak_script_method
    def forward(self, input):
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'threshold={}, value={}{}'.format(
            self.threshold, self.value, inplace_str
        )


@weak_module
class ReLU(Threshold):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x)= \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input), m(-input)))
    """

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0., 0., inplace)

    def extra_repr(self):
        inplace_str = 'inplace' if self.inplace else ''
        return inplace_str


@weak_module
class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:

    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from the uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """
    __constants__ = ['lower', 'upper', 'inplace']

    def __init__(self, lower=1. / 8, upper=1. / 3, inplace=False):
        super(RReLU, self).__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)


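# Usage sketch (illustrative, not part of the upstream class): RReLU draws the
# negative slope ``a`` from U(lower, upper) on every training-mode call and
# uses the fixed average (lower + upper) / 2 in eval mode, so the two modes can
# return different outputs for the same input:
#
#     >>> m = nn.RReLU(0.1, 0.3)
#     >>> x = torch.tensor([-1.0, 1.0])
#     >>> m.train()(x)   # negative entry scaled by a random slope in [0.1, 0.3]
#     >>> m.eval()(x)    # tensor([-0.2000,  1.0000]); slope fixed at 0.2

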
@weak_module
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            1 & \text{ if } x > 1 \\
            -1 & \text{ if } x < -1 \\
            x & \text{ otherwise } \\
        \end{cases}

    The range of the linear region :math:`[-1, 1]` can be adjusted using
    :attr:`min_val` and :attr:`max_val`.

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['min_val', 'max_val', 'inplace']

    def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
        super(Hardtanh, self).__init__()
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and renamed to max_val")
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        assert self.max_val > self.min_val

    @weak_script_method
    def forward(self, input):
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'min_val={}, max_val={}{}'.format(
            self.min_val, self.max_val, inplace_str
        )


@weak_module
class ReLU6(Hardtanh):
    r"""Applies the element-wise function:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace=False):
        super(ReLU6, self).__init__(0., 6., inplace)

    def extra_repr(self):
        inplace_str = 'inplace' if self.inplace else ''
        return inplace_str


@weak_module
class Sigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    @weak_script_method
    def forward(self, input):
        return torch.sigmoid(input)


@weak_module
class Tanh(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{e^x - e^{-x}} {e^x + e^{-x}}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    @weak_script_method
    def forward(self, input):
        return torch.tanh(input)


@weak_module
class ELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['alpha', 'inplace']

    def __init__(self, alpha=1., inplace=False):
        super(ELU, self).__init__()
        self.alpha = alpha
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'alpha={}{}'.format(self.alpha, inplace_str)


@weak_module
class CELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """
    __constants__ = ['alpha', 'inplace']

    def __init__(self, alpha=1., inplace=False):
        super(CELU, self).__init__()
        self.alpha = alpha
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'alpha={}{}'.format(self.alpha, inplace_str)


@weak_module
class SELU(Module):
    r"""Applied element-wise, as:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """
    __constants__ = ['inplace']

    def __init__(self, inplace=False):
        super(SELU, self).__init__()
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
        return F.selu(input, self.inplace)

    def extra_repr(self):
        inplace_str = 'inplace' if self.inplace else ''
        return inplace_str


@weak_module
class GLU(Module):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    __constants__ = ['dim']

    def __init__(self, dim=-1):
        super(GLU, self).__init__()
        self.dim = dim

    @weak_script_method
    def forward(self, input):
        return F.glu(input, self.dim)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)


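# Worked sketch of the split GLU performs (illustrative shapes): with the
# default dim=-1, a (4, 2) input is split into two (4, 1) halves a and b, and
# the output a * sigmoid(b) therefore has half the size along that dimension.
#
#     >>> m = nn.GLU()
#     >>> input = torch.randn(4, 2)
#     >>> m(input).shape
#     torch.Size([4, 1])

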
@weak_module
class Hardshrink(Module):
    r"""Applies the hard shrinkage function element-wise:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['lambd']

    def __init__(self, lambd=0.5):
        super(Hardshrink, self).__init__()
        self.lambd = lambd

    @weak_script_method
    def forward(self, input):
        return F.hardshrink(input, self.lambd)

    def extra_repr(self):
        return '{}'.format(self.lambd)


@weak_module
class LeakyReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace', 'negative_slope']

    def __init__(self, negative_slope=1e-2, inplace=False):
        super(LeakyReLU, self).__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace' if self.inplace else ''
        return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)


@weak_module
class LogSigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    @weak_script_method
    def forward(self, input):
        return F.logsigmoid(input)


@weak_module
class Softplus(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    for inputs above a certain value.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['beta', 'threshold']

    def __init__(self, beta=1, threshold=20):
        super(Softplus, self).__init__()
        self.beta = beta
        self.threshold = threshold

    @weak_script_method
    def forward(self, input):
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self):
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)


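# Sketch of the numerical-stability behaviour described in the docstring
# above (values are illustrative): once beta * x exceeds `threshold`, Softplus
# returns x itself instead of log(1 + exp(beta * x)) / beta, which would
# overflow for large inputs.
#
#     >>> m = nn.Softplus(beta=1, threshold=20)
#     >>> m(torch.tensor([0.0, 50.0]))
#     tensor([ 0.6931, 50.0000])

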
@weak_module
class Softshrink(Module):
    r"""Applies the soft shrinkage function elementwise:

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['lambd']

    def __init__(self, lambd=0.5):
        super(Softshrink, self).__init__()
        self.lambd = lambd

    @weak_script_method
    def forward(self, input):
        return F.softshrink(input, self.lambd)

    def extra_repr(self):
        return str(self.lambd)


@weak_module
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: parallel attention layers, or heads

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        xavier_uniform_(self.in_proj_weight[:self.embed_dim, :])
        xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :])
        xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :])

        xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    @weak_script_method
    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False, attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: if provided, padding elements in the key
                are ignored by the attention
            incremental_state: if provided, previous time steps are cached
            need_weights: output attn_output_weights
            static_kv: key and value are static

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self._in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self._in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self._in_proj_kv(key)
        else:
            q = self._in_proj_q(query)
            k = self._in_proj_k(key)
            v = self._in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights.float(), dim=-1,
            dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

    def _in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def _in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def _in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def _in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def _in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)


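# Usage sketch for the module above, with small illustrative sizes: inputs are
# shaped [seq_len, batch, embed_dim], the same tensor can be passed as query,
# key and value for self-attention, and embed_dim must be divisible by
# num_heads.
#
#     >>> embed_dim, num_heads = 8, 2
#     >>> mha = nn.MultiheadAttention(embed_dim, num_heads)
#     >>> x = torch.randn(5, 3, embed_dim)    # [tgt_len=5, bsz=3, embed_dim=8]
#     >>> out, weights = mha(x, x, x)         # self-attention
#     >>> out.shape, weights.shape
#     (torch.Size([5, 3, 8]), torch.Size([3, 5, 5]))

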
@weak_module
class PReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.


    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels at input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, num_parameters=1, init=0.25):
        self.num_parameters = num_parameters
        super(PReLU, self).__init__()
        self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))

    @weak_script_method
    def forward(self, input):
        return F.prelu(input, self.weight)

    def extra_repr(self):
        return 'num_parameters={}'.format(self.num_parameters)


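# Sketch of the per-channel variant mentioned in the docstring above
# (illustrative sizes): nn.PReLU(C) learns one slope per channel, where the
# channel dimension is dim 1 of the input, so the weight has shape (3,) for a
# (N, 3, H, W) input.
#
#     >>> m = nn.PReLU(3)
#     >>> m.weight.shape
#     torch.Size([3])
#     >>> m(torch.randn(2, 3, 4, 4)).shape
#     torch.Size([2, 3, 4, 4])

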
@weak_module
class Softsign(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    @weak_script_method
    def forward(self, input):
        return F.softsign(input)


@weak_module
class Tanhshrink(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanhshrink}(x) = x - \text{Tanh}(x)

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    @weak_script_method
    def forward(self, input):
        return F.tanhshrink(input)


@weak_module
class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor,
    rescaling it so that the elements of the n-dimensional output Tensor
    lie in the range `[0, 1]` and sum to 1.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Arguments:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']

    def __init__(self, dim=None):
        super(Softmin, self).__init__()
        self.dim = dim

    @weak_script_method
    def forward(self, input):
        return F.softmin(input, self.dim, _stacklevel=5)


@weak_module
class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor,
    rescaling it so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Arguments:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']

    def __init__(self, dim=None):
        super(Softmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    @weak_script_method
    def forward(self, input):
        return F.softmax(input, self.dim, _stacklevel=5)


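# Example of the `dim` argument discussed above: with dim=1 every row of a 2-D
# input is normalized independently, so each row of the output sums to 1.
#
#     >>> m = nn.Softmax(dim=1)
#     >>> out = m(torch.randn(2, 3))
#     >>> out.sum(dim=1)
#     tensor([1.0000, 1.0000])

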
@weak_module
class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    @weak_script_method
    def forward(self, input):
        assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
        return F.softmax(input, 1, _stacklevel=5)


@weak_module
class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Arguments:
        dim (int): A dimension along which LogSoftmax will be computed.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)

    Examples::

        >>> m = nn.LogSoftmax()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']

    def __init__(self, dim=None):
        super(LogSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    @weak_script_method
    def forward(self, input):
        return F.log_softmax(input, self.dim, _stacklevel=5)


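# Sketch of the LogSoftmax + NLLLoss pairing referenced in the Softmax note:
# NLLLoss expects log-probabilities, so LogSoftmax (rather than Softmax) is the
# natural companion, and the combination matches CrossEntropyLoss.
#
#     >>> log_probs = nn.LogSoftmax(dim=1)(torch.randn(2, 3))
#     >>> loss = nn.NLLLoss()(log_probs, torch.tensor([0, 2]))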