from .batchnorm import _NormBase
from .. import functional as F

from torch import Tensor


class _InstanceNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False
    ) -> None:
        super(_InstanceNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super(_InstanceNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        return F.instance_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats, self.momentum, self.eps)
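

# A minimal sketch (hypothetical helper, not part of the original module or the
# public API): forward() above passes `self.training or not
# self.track_running_stats` as the use_input_stats flag to F.instance_norm, so
# with the default track_running_stats=False the layer recomputes instance
# statistics from the input in eval mode as well. Imports are local so the
# sketch does not change the module's import-time behaviour.
def _sketch_eval_uses_instance_stats():
    import torch
    import torch.nn as nn

    m = nn.InstanceNorm2d(3)        # affine=False, track_running_stats=False by default
    x = torch.randn(2, 3, 8, 8)
    y_train = m(x)                  # training mode: statistics from the input
    m.eval()
    y_eval = m(x)                   # eval mode: still statistics from the input
    assert torch.allclose(y_train, y_eval)
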
class InstanceNorm1d(_InstanceNorm):
    r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm1d` is applied
        on each channel of channeled data like multidimensional time series, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm1d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)`
        - Output: :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm1d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() == 2:
            raise ValueError(
                'InstanceNorm1d returns a 0-filled tensor for 2D input. '
                'This is because InstanceNorm1d reshapes inputs from '
                '(N, C, ...) to (1, N * C, ...) and this makes the '
                'variances 0.'
            )
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'
                             .format(input.dim()))
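

# A minimal sketch (hypothetical helper, not part of the original module)
# checking the formula in the docstring above: each (sample, channel) slice is
# normalized with its own mean and biased variance taken over the length
# dimension, with the default eps of 1e-5 and no affine transform.
def _sketch_manual_instance_norm1d():
    import torch
    import torch.nn as nn

    m = nn.InstanceNorm1d(100)
    x = torch.randn(20, 100, 40)
    eps = 1e-5

    mean = x.mean(dim=2, keepdim=True)                # per-(N, C) mean over L
    var = x.var(dim=2, unbiased=False, keepdim=True)  # biased variance over L
    manual = (x - mean) / torch.sqrt(var + eps)

    assert torch.allclose(m(x), manual, atol=1e-5)
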
class InstanceNorm2d(_InstanceNorm):
    r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm2d` is applied
        on each channel of channeled data like RGB images, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm2d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm2d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm2d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
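

# A minimal sketch (hypothetical helper, not part of the original module) of
# the running-statistics update rule described in the docstring note above:
# with track_running_stats=True, a training-mode forward folds the batch of
# per-instance means into running_mean with the given momentum. The running
# variance follows a similar rule using per-instance variances.
def _sketch_running_mean_update():
    import torch
    import torch.nn as nn

    momentum = 0.1
    m = nn.InstanceNorm2d(3, momentum=momentum, track_running_stats=True)
    x = torch.randn(4, 3, 8, 8)

    old = m.running_mean.clone()     # zeros right after construction
    m(x)                             # training-mode forward updates the buffers
    expected = (1 - momentum) * old + momentum * x.mean(dim=(0, 2, 3))
    assert torch.allclose(m.running_mean, expected, atol=1e-5)
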
class InstanceNorm3d(_InstanceNorm):
    r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm3d` is applied
        on each channel of channeled data like 3D models with RGB color, but
        :class:`LayerNorm` is usually applied on the entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm3d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm3d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm3d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
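

# A minimal sketch (hypothetical helper, not part of the original module) of
# the InstanceNorm / LayerNorm contrast drawn in the docstring notes above:
# InstanceNorm3d normalizes each (sample, channel) slice over its spatial
# extent only, while LayerNorm normalizes over all dimensions given by
# normalized_shape within a single sample.
def _sketch_instance_vs_layer_norm():
    import torch
    import torch.nn as nn

    x = torch.randn(2, 3, 4, 4, 4)                  # (N, C, D, H, W)

    inst = nn.InstanceNorm3d(3)(x)                  # stats over (D, H, W) per (N, C)
    layer = nn.LayerNorm(x.shape[1:])(x)            # stats over (C, D, H, W) per N

    # Each (sample, channel) slice is zero-mean after InstanceNorm3d ...
    assert inst.mean(dim=(2, 3, 4)).abs().max() < 1e-4
    # ... while LayerNorm (at initialization) centers each whole sample instead.
    assert layer.reshape(2, -1).mean(dim=1).abs().max() < 1e-4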