pytorch/torch/ao/nn/quantized/modules/utils.py
zaf c92e5ac95b [quant][ao_migration] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules` (#78713)
Context: In order to avoid cluttering the `torch.nn` namespace,
the quantized modules namespace is moved to `torch.ao.nn`.

The list of the `nn.quantized` files that are being migrated:

- [ ] `torch.nn.quantized` → `torch.ao.nn.quantized`
    - [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
    - [X] [Current PR] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
    - [ ] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
    - [ ] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [ ] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [ ] `torch.nn.qat` → `torch.ao.nn.qat`
    - [ ] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
    - [ ] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
    - [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
    - [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
    - [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
        - [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
        - [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`
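
For downstream code the change is an import-path move only. A minimal sketch of the old vs. new location (`Linear` is used purely as an illustration; the old path is expected to keep resolving during the transition):

```python
# Pre-migration import location:
from torch.nn.quantized.modules import Linear

# New location targeted by this migration stack:
from torch.ao.nn.quantized.modules import Linear
```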

The majority of the files are simply moved to the new location.
However, the following files need to be double-checked:

- Documentation @vkuzo
  - docs/source/conf.py
  - docs/source/quantization.rst
- [quantize_fx](torch/ao/quantization/quantize_fx.py) @jerryzh168
- [common test routine](test/quantization/ao_migration/common.py) @HDCharles
- JIT stuff @jamesr66a
  - torch/csrc/jit/passes/hoist_conv_packed_params.cpp
  - torch/csrc/jit/passes/quantization/helper.h
  - torch/csrc/jit/serialization/import_source.cpp

Differential Revision: [D38926012](https://our.internmc.facebook.com/intern/diff/D38926012/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78713
Approved by: https://github.com/jerryzh168
2022-08-25 16:50:33 +00:00


import abc
import torch
import itertools
import collections
from torch.nn.modules.module import _addindent


class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
    """Wrapper for quantized modules that can be lowered from reference modules."""

    @classmethod
    @abc.abstractmethod
    def from_reference(cls, ref_module, output_scale, output_zero_point):
        raise NotImplementedError
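

# A hedged sketch of how a concrete subclass is expected to satisfy the
# interface above (illustrative only; names such as `MyQuantizedLinear` and
# `get_quantized_weight` are not defined in this file):
#
#   class MyQuantizedLinear(WeightedQuantizedModule):
#       @classmethod
#       def from_reference(cls, ref_module, output_scale, output_zero_point):
#           qlinear = cls(ref_module.in_features, ref_module.out_features)
#           qlinear.set_weight_bias(ref_module.get_quantized_weight(), ref_module.bias)
#           qlinear.scale = float(output_scale)
#           qlinear.zero_point = int(output_zero_point)
#           return qlinear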


def _get_weight_observer(observer):
    # FakeQuantize observer
    if hasattr(observer, "activation_post_process"):
        observer = observer.activation_post_process
    # UniformQuantizationObserverBase observer
    return observer


def _needs_weight_clamping(observer, dtype):
    observer = _get_weight_observer(observer)
    if dtype in [torch.qint8, torch.quint8, torch.qint32]:
        info = torch.iinfo(dtype)
        return observer.quant_min > info.min or observer.quant_max < info.max
    return False
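

# Example: torch.iinfo(torch.qint8) spans [-128, 127], so an observer that is
# configured with the common restricted symmetric range [-127, 127] needs
# clamping, while one that uses the full dtype range does not. A hedged
# doctest-style sketch (the observer construction is illustrative):
#
#   >>> obs = torch.ao.quantization.MinMaxObserver(
#   ...     dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
#   ...     quant_min=-127, quant_max=127)
#   >>> _needs_weight_clamping(obs, torch.qint8)
#   True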


def _clamp_weights(qweight, observer, scale, zp):
    if not _needs_weight_clamping(observer, qweight.dtype):
        return qweight

    observer = _get_weight_observer(observer)
    min_, max_ = observer.quant_min, observer.quant_max

    # Doing this because can't use torch.ops.quantized.clamp() with per_channel qscheme yet.
    qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
    qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
    qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)

    if observer.qscheme in [torch.per_tensor_symmetric,
                            torch.per_tensor_affine]:
        qweight = torch._make_per_tensor_quantized_tensor(qw_int, scale.item(), zp.item())
    elif observer.qscheme in [torch.per_channel_symmetric,
                              torch.per_channel_affine,
                              torch.per_channel_affine_float_qparams]:
        qweight = torch._make_per_channel_quantized_tensor(qw_int, scale, zp, axis=observer.ch_axis)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight
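

# Example: a qint8 tensor quantized over the full [-128, 127] range may contain
# -128 in its int representation; if the observer was configured with
# quant_min=-127, the helper above rebuilds the tensor with that value clamped
# to -127 while reusing the given scale/zero_point. Sketch (inputs assumed to
# come from observer-driven code such as `_quantize_weight` below):
#
#   >>> clamped = _clamp_weights(qweight, obs, wt_scale, wt_zp)
#   >>> int(clamped.int_repr().min()) >= obs.quant_min
#   True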


def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt,
            float(wt_scale), int(wt_zp), torch.qint8)
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, observer.dtype)
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight
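

# Usage sketch (hedged; the observer choice below is illustrative): observe a
# float weight with the default symmetric per-tensor weight observer, then
# quantize it with the helper above.
#
#   >>> w = torch.randn(8, 4)
#   >>> obs = torch.ao.quantization.default_weight_observer()
#   >>> _ = obs(w)                    # record min/max statistics
#   >>> qw = _quantize_weight(w, obs)
#   >>> qw.dtype
#   torch.qint8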


def _ntuple_from_first(n):
    """Converts the argument to a tuple of size n
    with the first element repeated."""
    def parse(x):
        while isinstance(x, collections.abc.Sequence):
            if len(x) == n:
                break
            x = x[0]
        return tuple(itertools.repeat(x, n))
    return parse
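

# Examples (`_pair_from_first` at the bottom of this file is
# `_ntuple_from_first(2)`):
#
#   >>> _ntuple_from_first(2)(3)
#   (3, 3)
#   >>> _ntuple_from_first(2)([3])
#   (3, 3)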


def hide_packed_params_repr(self, params):
    # We don't want to show `PackedParams` children, hence custom
    # `__repr__`. This is the same as nn.Module.__repr__, except the check
    # for the `params` module.
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        if isinstance(module, params):
            continue
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str
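

# Usage sketch for `hide_packed_params_repr` (hypothetical module; the concrete
# quantized Linear/Conv modules pair this helper with their own packed-params
# classes, e.g. `LinearPackedParams` from the quantized linear module):
#
#   class _MyQuantizedLinear(torch.nn.Module):
#       def __repr__(self):
#           # Skip children that are instances of the packed-params class.
#           return hide_packed_params_repr(self, LinearPackedParams)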


_pair_from_first = _ntuple_from_first(2)