Summary: This PR does three things:
1. Export the BatchNorm functional and module, rewriting some of the components to stay aligned with the currently supported JIT features.
2. As part of the export, add the necessary compiler support for in-place augmented-assignment ops.
3. Change the test_jit behavior in add_module_test to use a single RNG state during module initialization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14016
Differential Revision: D13112064
Pulled By: wanchaol
fbshipit-source-id: 31e3aee5fbb509673c781e7dbb6d8884cfa55d91
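As a rough illustration of item 2 (a minimal sketch, not part of this patch; the Counter module below is hypothetical, and only the weak_module/weak_script_method decorators already imported by the file itself are assumed), the compiler change allows in-place augmented assignment on a registered buffer inside a script method, which is the pattern _BatchNorm.forward uses via self.num_batches_tracked += 1:

import torch
from torch.nn import Module
from torch._jit_internal import weak_module, weak_script_method

@weak_module
class Counter(Module):
    # Hypothetical example module, for illustration only.
    def __init__(self):
        super(Counter, self).__init__()
        self.register_buffer('steps', torch.tensor(0, dtype=torch.long))

    @weak_script_method
    def forward(self, x):
        # In-place augmented assignment on a buffer; when this method is
        # compiled by the JIT, the statement below relies on the aug-assign
        # support added in this PR.
        self.steps += 1
        return x

A weak module is only compiled when it is used from a script module; in plain eager mode the sketch above runs as ordinary Python.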
from __future__ import division

import torch
from .module import Module
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init
from ..._jit_internal import weak_module, weak_script_method


# TODO: check contiguous in THNN
# TODO: use separate backend functions?
@weak_module
class _BatchNorm(Module):
    _version = 2
    __constants__ = ['training', 'track_running_stats', 'momentum', 'eps']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.uniform_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    @weak_script_method
    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_BatchNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


@weak_module
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


@weak_module
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


@weak_module
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))