pytorch/torch/nn/quantized/modules/batchnorm.py
mattip f10fbcc820 Split up documentation into subpages and clean up some warnings (#37419)
Summary:
xref gh-32838, gh-34032

This is a major refactor of parts of the documentation to split it up using sphinx's `autosummary` feature which will build out `autofunction` and `autoclass` stub files and link to them. The end result is that the top module pages like torch.nn.rst and torch.rst are now more like table-of-contents to the actual single-class or single-function documentations pages.

Along the way, I modified many of the docstrings to eliminate sphinx warnings when building. I think the only thing I changed from a non-documentation perspective is to add names to `__all__` when adding them to `globals()` in `torch.__init__.py`

I do not know the CI system: are the documentation build artifacts available after the build, so reviewers can preview before merging?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37419

Differential Revision: D21337640

Pulled By: ezyang

fbshipit-source-id: d4ad198780c3ae7a96a9f22651e00ff2d31a0c0f
2020-05-04 09:39:22 -07:00

132 lines
5.2 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn.quantized.functional
import torch.nn.intrinsic as nni
class BatchNorm2d(torch.nn.BatchNorm2d):
    r"""Applies Quantized Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    The forward pass normalizes with the stored ``running_mean`` / ``running_var``
    and passes ``scale`` and ``zero_point`` to the quantized kernel (presumably as
    the output quantization parameters).

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). It is stored on the module but is not used
            by the quantized forward pass. Default: 0.1

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> m = nn.quantized.BatchNorm2d(100)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(quantized_input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # Forward eps and momentum to the float base class so the module's
        # attributes actually reflect the constructor arguments.
        super(BatchNorm2d, self).__init__(num_features, eps=eps, momentum=momentum)
        # Quantization parameters handed to the quantized kernel in forward();
        # from_float() overwrites them with the observer's qparams.
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        """Apply quantized 2d batch norm to the quantized tensor ``input``."""
        return torch.ops.quantized.batch_norm2d(
            input, self.weight, self.bias, self.running_mean, self.running_var,
            self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm2d'

    @classmethod
    def from_float(cls, mod):
        """Create a quantized module from an observed float module.

        Args:
            mod: a float ``BatchNorm2d`` carrying an ``activation_post_process``
                observer, or a fused ``nni.BNReLU2d`` whose ``mod[1]`` carries it.

        Returns:
            An instance of ``cls`` sharing ``mod``'s weight, bias and running
            statistics, with ``scale`` / ``zero_point`` taken from the observer's
            ``calculate_qparams()``.
        """
        if type(mod) is nni.BNReLU2d:
            # Fused case: the observer is taken from mod[1] (the ReLU), while
            # mod[0] (the BatchNorm) supplies the parameters and statistics.
            activation_post_process = mod[1].activation_post_process
            mod = mod[0]
        else:
            activation_post_process = mod.activation_post_process
        # Use the observer selected above. ``mod`` may have been rebound to
        # mod[0], so ``mod.activation_post_process`` would be the wrong observer.
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod
class BatchNorm3d(torch.nn.BatchNorm3d):
    r"""Applies Quantized Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    The forward pass normalizes with the stored ``running_mean`` / ``running_var``
    and passes ``scale`` and ``zero_point`` to the quantized kernel (presumably as
    the output quantization parameters).

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). It is stored on the module but is not used
            by the quantized forward pass. Default: 0.1

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> m = nn.quantized.BatchNorm3d(100)
        >>> input = torch.randn(20, 100, 25, 35, 45)
        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(quantized_input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # Forward eps and momentum to the float base class so the module's
        # attributes actually reflect the constructor arguments.
        super(BatchNorm3d, self).__init__(num_features, eps=eps, momentum=momentum)
        # Quantization parameters handed to the quantized kernel in forward();
        # from_float() overwrites them with the observer's qparams.
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        """Apply quantized 3d batch norm to the quantized tensor ``input``."""
        return torch.ops.quantized.batch_norm3d(
            input, self.weight, self.bias, self.running_mean, self.running_var,
            self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    @classmethod
    def from_float(cls, mod):
        """Create a quantized module from an observed float module.

        Args:
            mod: a float ``BatchNorm3d`` carrying an ``activation_post_process``
                observer, or a fused ``nni.BNReLU3d`` whose ``mod[1]`` carries it.

        Returns:
            An instance of ``cls`` sharing ``mod``'s weight, bias and running
            statistics, with ``scale`` / ``zero_point`` taken from the observer's
            ``calculate_qparams()``.
        """
        if type(mod) is nni.BNReLU3d:
            # Fused case: the observer is taken from mod[1] (the ReLU), while
            # mod[0] (the BatchNorm) supplies the parameters and statistics.
            activation_post_process = mod[1].activation_post_process
            mod = mod[0]
        else:
            activation_post_process = mod.activation_post_process
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod