from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn.quantized.functional
import torch.nn.intrinsic as nni


class BatchNorm2d(torch.nn.BatchNorm2d):
    r"""Applies Quantized Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with an additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> m = nn.quantized.BatchNorm2d(100)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(quantized_input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm2d, self).__init__(num_features)
        self.eps = eps
        # Output quantization parameters; overwritten by from_float().
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        # The quantized op emits its output at the module's output scale/zero_point.
        return torch.ops.quantized.batch_norm2d(input, self.weight, self.bias, self.running_mean,
                                                self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm2d'

    @classmethod
    def from_float(cls, mod):
        if type(mod) == nni.BNReLU2d:
            # For a fused BN+ReLU module the observer lives on the ReLU;
            # grab it before unwrapping down to the BatchNorm2d itself.
            activation_post_process = mod[1].activation_post_process
            mod = mod[0]
        else:
            activation_post_process = mod.activation_post_process
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod
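
# Note: ``from_float`` above is the hook used by eager-mode quantization.
# Preparing a model (``torch.quantization.prepare``) attaches an
# ``activation_post_process`` observer to the float module (or fuses
# BatchNorm + ReLU into ``nni.BNReLU2d``), and the convert step
# (``torch.quantization.convert``) then swaps the observed float module
# for this class by calling ``from_float``.
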


class BatchNorm3d(torch.nn.BatchNorm3d):
    r"""Applies Quantized Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with an additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> m = nn.quantized.BatchNorm3d(100)
        >>> input = torch.randn(20, 100, 25, 35, 45)
        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(quantized_input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm3d, self).__init__(num_features)
        self.eps = eps
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        return torch.ops.quantized.batch_norm3d(input, self.weight, self.bias, self.running_mean,
                                                self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    @classmethod
    def from_float(cls, mod):
        if type(mod) == nni.BNReLU3d:
            activation_post_process = mod[1].activation_post_process
            mod = mod[0]
        else:
            activation_post_process = mod.activation_post_process

        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod
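

if __name__ == "__main__":
    # Hedged smoke-test sketch, not part of the original module: it assumes the
    # eager-mode observer API (``torch.quantization.MinMaxObserver``) and simply
    # mirrors the docstring examples above to exercise ``from_float`` and
    # ``forward`` end to end.
    import torch.quantization

    # Attach an observer to a float BatchNorm2d and let it see some data so
    # that calculate_qparams() has statistics to work with.
    float_bn = torch.nn.BatchNorm2d(100)
    float_bn.activation_post_process = torch.quantization.MinMaxObserver(dtype=torch.quint8)
    float_bn.activation_post_process(float_bn(torch.randn(4, 100, 8, 8)))

    # Convert to the quantized module and run a quantized tensor through it.
    qbn = BatchNorm2d.from_float(float_bn)
    qx = torch.quantize_per_tensor(torch.randn(4, 100, 8, 8),
                                   scale=1.0, zero_point=0, dtype=torch.quint8)
    qy = qbn(qx)
    print(qy.shape, qy.dtype)  # same shape as the input, torch.quint8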