import torch
import numbers
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
from .. import init


class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal composed
    of several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: number of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)

    """
    __constants__ = ['size', 'alpha', 'beta', 'k']

    def __init__(self, size, alpha=1e-4, beta=0.75, k=1.):
        super(LocalResponseNorm, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input):
        return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                     self.k)

    def extra_repr(self):
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
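

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hedged check of the LRN formula in the docstring above: it
# computes b_c = a_c * (k + alpha/n * sum a_{c'}^2)^(-beta) by hand and
# compares the result with F.local_response_norm. The channel window here
# follows the zero-padded, fixed-length window that F.local_response_norm
# uses internally, which differs slightly from the docstring's symmetric
# clipped window. The name `_lrn_reference_check` is hypothetical and exists
# only for illustration.
def _lrn_reference_check(size=2, alpha=1e-4, beta=0.75, k=1.):
    a = torch.randn(8, 5, 4, 4)
    num_channels = a.size(1)
    b = torch.empty_like(a)
    for c in range(num_channels):
        # Window [c - size//2, c + (size-1)//2]; out-of-range channels are
        # zero-padded and so contribute nothing to the sum.
        lo = max(0, c - size // 2)
        hi = min(num_channels, c + (size - 1) // 2 + 1)
        sq_sum = (a[:, lo:hi] ** 2).sum(dim=1)
        b[:, c] = a[:, c] / (k + alpha / size * sq_sum) ** beta
    return torch.allclose(b, F.local_response_norm(a, size, alpha, beta, k),
                          atol=1e-6)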


class CrossMapLRN2d(Module):

    def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
        super(CrossMapLRN2d, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input):
        # Dispatches to the legacy backend's cross-map LRN implementation.
        return self._backend.CrossMapLRN2d(self.size, self.alpha, self.beta,
                                           self.k)(input)

    def extra_repr(self):
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)


class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated separately over the last
    certain number of dimensions, which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """
    __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
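

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hedged check of the LayerNorm formula above: it computes
# y = (x - E[x]) / sqrt(Var[x] + eps) by hand over the last dimensions given
# by `normalized_shape` (with no affine parameters) and compares the result
# with F.layer_norm. The name `_layer_norm_reference_check` is hypothetical
# and exists only for illustration.
def _layer_norm_reference_check(eps=1e-5):
    x = torch.randn(20, 5, 10, 10)
    normalized_shape = x.shape[1:]  # normalize over the last three dimensions
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    # Var[x] is the biased estimator, matching the formula in the docstring.
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    y = (x - mean) / torch.sqrt(var + eps)
    expected = F.layer_norm(x, normalized_shape, eps=eps)
    return torch.allclose(y, expected, atol=1e-6)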


class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs as described in
    the paper `Group Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent to LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Group Normalization`: https://arxiv.org/abs/1803.08494
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine', 'weight',
                     'bias']

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(GroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_channels))
            self.bias = Parameter(torch.Tensor(num_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self):
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)
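

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hedged check of the GroupNorm computation described above: the
# channels are reshaped into `num_groups` groups, each group is standardized
# with its own mean and (biased) variance, and the result is compared with
# F.group_norm (with no affine parameters). The name
# `_group_norm_reference_check` is hypothetical and exists only for
# illustration.
def _group_norm_reference_check(num_groups=3, eps=1e-5):
    x = torch.randn(20, 6, 10, 10)
    n = x.size(0)
    # Collapse each group of num_channels // num_groups channels (and all
    # spatial positions) into one axis, so statistics are computed per
    # (sample, group) pair.
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, unbiased=False, keepdim=True)
    y = ((g - mean) / torch.sqrt(var + eps)).reshape_as(x)
    expected = F.group_norm(x, num_groups, eps=eps)
    return torch.allclose(y, expected, atol=1e-6)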


# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d