Summary: Problems with SN and DP after #12671:

1. In eval mode, `weight_orig` does not get the correct gradient (#12737). Fix: keep the `v` vector around as a buffer and always compute `W = W_orig / (u @ W_orig @ v)`, even in eval mode.
2. In training mode, the `weight` buffer of the parallelized module is never updated if someone touches `weight_orig` and/or `weight` so that they no longer share storage, so in eval mode the weight that is used is wrong. Fix: make `weight` no longer a buffer and always compute it as above.
3. #12671 changed SN to update `u` in-place so that DP works correctly, but that breaks backward through two forwards (e.g., the common GAN loss `D(real) - D(fake)`) because the vectors needed to backprop the first forward are changed by the second forward. Fix: this PR clones `u` and `v` before using them.

To maintain BC, I added a hook interface for producing and loading state_dict. This is ugly and we should really have a better interface for spectral_norm, but I make this patch to fix the issue at hand. Even with a better interface, a BC mechanism for loading legacy state_dicts would still be needed.

cc crcrpar

Pull Request resolved: https://github.com/pytorch/pytorch/pull/13350
Differential Revision: D12931044
Pulled By: SsnL
fbshipit-source-id: 8be6f934eaa62414d76d2c644dedd7e1b7eb31ef
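The fix described above boils down to always recomputing the normalized weight from `weight_orig` and cloned power-iteration vectors instead of caching it. A minimal sketch of that computation, assuming an already-2D weight and hypothetical precomputed `u`/`v` (the real `torch.nn.utils.spectral_norm` additionally runs power iteration to refresh `u` and `v` in training mode):

    import torch

    # Hypothetical stand-ins for the parameter and the power-iteration vectors.
    weight_orig = torch.randn(16, 8, requires_grad=True)
    u = torch.randn(16)
    v = torch.randn(8)

    # Clone u and v so that a second forward does not overwrite the tensors
    # needed to backprop through the first forward (the D(real) - D(fake) case).
    u_, v_ = u.clone(), v.clone()

    # sigma = u^T @ W_orig @ v estimates the largest singular value;
    # the spectrally normalized weight is W_orig / sigma.
    sigma = torch.dot(u_, torch.mv(weight_orig, v_))
    weight = weight_orig / sigma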
import torch
from .module import Module
from torch.nn.parameter import Parameter
from .. import functional as F
from .. import init


# TODO: check contiguous in THNN
# TODO: use separate backend functions?
class _BatchNorm(Module):
    _version = 2

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()
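
    # Editorial note (a sketch, not part of the original file): the running
    # statistics are registered as buffers, so they live in the state_dict and
    # follow .to()/.cuda() but are not returned by .parameters(); weight and
    # bias are the only learnable Parameters, e.g.
    #
    #     >>> bn = nn.BatchNorm1d(4)
    #     >>> sorted(name for name, _ in bn.named_parameters())
    #     ['bias', 'weight']
    #     >>> sorted(bn.state_dict().keys())
    #     ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
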
    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.uniform_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
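
    # Editorial note (a sketch, not part of the original file): with the factor
    # computed above, F.batch_norm updates the running statistics roughly as
    #
    #     running_mean = (1 - factor) * running_mean + factor * batch_mean
    #     running_var  = (1 - factor) * running_var  + factor * batch_var
    #
    # so momentum=None (factor = 1 / num_batches_tracked) yields a cumulative
    # moving average, while the default momentum=0.1 yields an exponential one.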

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
            'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_BatchNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
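
    # Editorial note (a sketch, not part of the original file): checkpoints saved
    # before version 2 have no 'num_batches_tracked' entry, so the hook above
    # inserts a zero tensor for it and loading with strict=True still succeeds:
    #
    #     >>> bn = nn.BatchNorm1d(4)
    #     >>> old_sd = {'weight': torch.ones(4), 'bias': torch.zeros(4),
    #     ...           'running_mean': torch.zeros(4), 'running_var': torch.ones(4)}
    #     >>> bn.load_state_dict(old_sd)  # 'num_batches_tracked' defaults to 0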


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
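
# Editorial note (a sketch, not part of the original file): with
# track_running_stats=False no running buffers are registered, so the layer
# normalizes with the current batch's statistics even in eval mode:
#
#     >>> m = nn.BatchNorm1d(100, track_running_stats=False)
#     >>> m.eval()
#     >>> output = m(torch.randn(20, 100))  # uses this batch's mean/var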


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
    from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))