from .batchnorm import _BatchNorm
from .. import functional as F


class _InstanceNorm(_BatchNorm):
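    # Note: unlike the batch norm modules, affine and track_running_stats
    # default to False here, so instance statistics computed from the input are
    # used in both training and eval unless running-stat tracking is enabled.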
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False):
        super(_InstanceNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

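        # The remaining entries are handled by _BatchNorm's own
        # _load_from_state_dict logic.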
        super(_InstanceNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input):
        self._check_input_dim(input)

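        # Per-instance (input) statistics are used whenever the module is in
        # training mode or running stats are not tracked; the running_mean /
        # running_var buffers are only used for normalization in eval mode
        # with track_running_stats=True.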
        return F.instance_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats, self.momentum, self.eps)


class InstanceNorm1d(_InstanceNorm):
    r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm1d` is applied
        on each channel of channeled data like multidimensional time series, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm1d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)`
        - Output: :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm1d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)
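        >>> # With running statistics tracked: using the default momentum of 0.1,
        >>> # each training-mode forward pass updates the running estimates as
        >>> # running_stat <- 0.9 * running_stat + 0.1 * batch_stat
        >>> m = nn.InstanceNorm1d(100, affine=True, track_running_stats=True)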

    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
        https://arxiv.org/abs/1607.08022
    """

    def _check_input_dim(self, input):
        if input.dim() == 2:
            raise ValueError(
                'InstanceNorm1d returns 0-filled tensor to 2D tensor. '
                'This is because InstanceNorm1d reshapes inputs to '
                '(1, N * C, ...) from (N, C, ...) and this makes '
                'variances 0.'
            )
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'
                             .format(input.dim()))


class InstanceNorm2d(_InstanceNorm):
    r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm2d` is applied
        on each channel of channeled data like RGB images, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm2d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm2d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm2d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
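        >>> # With track_running_stats=True, eval mode normalizes with the
        >>> # accumulated running mean/variance rather than per-instance statistics
        >>> m = nn.InstanceNorm2d(100, track_running_stats=True).eval()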

    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
        https://arxiv.org/abs/1607.08022
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class InstanceNorm3d(_InstanceNorm):
    r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm3d` is applied
        on each channel of channeled data like 3D models with RGB color, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm3d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses instance
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm3d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm3d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
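        >>> # Statistics are computed per (sample, channel): each output[n, c]
        >>> # volume is standardized with its own mean and variance
        >>> per_instance_mean = output.view(20, 100, -1).mean(-1)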

    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
        https://arxiv.org/abs/1607.08022
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))