torch.quantization conversion utilities, observers for eager mode quantization (#22010)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22010

torch.quantization module with observers and conversion routines

Reviewed By: zafartahirov
Differential Revision: D15554183
fbshipit-source-id: 05a3fabe28dd701978b8ecebf5bfc3a4c044ba5c
This commit is contained in:
parent 073fa6f411, commit 5040d52a5a
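The utilities and tests below implement a three-step eager-mode flow: `prepare` attaches observers (and quant/dequant stubs when a qconfig_dict is given), a calibration function runs representative data through the model, and `convert` swaps each observed module for its quantized counterpart. A minimal sketch of that flow, where `MyModel` and `calib_batches` are hypothetical stand-ins for a float model and its calibration inputs:

import torch
from torch.quantization import default_qconfig, default_eval_fn, prepare, convert

model = MyModel()                              # hypothetical float nn.Module
model = prepare(model, {'': default_qconfig})  # attach observers + quant/dequant stubs
default_eval_fn(model, calib_batches)          # calibration: run sample inputs through the model
convert(model)                                 # swap observed modules to nnq counterparts in place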
test/test_quantization.py (new file, 381 lines)

@@ -0,0 +1,381 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized as nnq
import torch.quantization as tq
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
    default_eval_fn, QConfig, default_qconfig, default_observer, quantize, \
    prepare, convert

from common_utils import TestCase, run_tests

class SingleLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super(SingleLayerLinearModel, self).__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super(TwoLayerLinearModel, self).__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class LinearReluModel(torch.nn.Module):
    def __init__(self):
        super(LinearReluModel, self).__init__()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

class NestedModel(torch.nn.Module):
    def __init__(self):
        super(NestedModel, self).__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class InnerModule(torch.nn.Module):
    def __init__(self):
        super(InnerModule, self).__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))

class WrappedModel(torch.nn.Module):
    def __init__(self):
        super(WrappedModel, self).__init__()
        self.qconfig = default_qconfig
        self.sub = QuantWrapper(InnerModule())
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        # don't quantize this fc
        self.fc.qconfig = None

    def forward(self, x):
        return self.fc(self.sub(x))

class ManualQuantModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self):
        super(ManualQuantModel, self).__init__()
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]

class ModelQuantizeAPITest(TestCase):

    def checkNoPrepModules(self, module):
        r"""Checks that the module does not contain child modules for
        quantization preparation, e.g. quant, dequant and observer
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkHasPrepModules(self, module):
        r"""Checks that the module contains child modules for quantization
        preparation, e.g. quant, dequant and observer
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module):
        if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
            self.assertTrue(hasattr(module, 'observer'))
        for child in module.children():
            self.checkObservers(child)

    def checkQuantDequant(self, mod):
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

    def checkQuantizedLinear(self, mod):
        self.assertEqual(type(mod.module), nnq.Linear)
        self.assertEqual(mod.module.bias.dtype, torch.qint32)
        self.checkQuantDequant(mod)

    def checkLinear(self, mod):
        self.assertEqual(type(mod), torch.nn.Linear)

    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel, which has one Linear module, and
        make sure it is swapped to nnq.Linear, the quantized version of the module
        """
        model = SingleLayerLinearModel()
        qconfig_dict = {
            '': default_qconfig
        }
        model = prepare(model, qconfig_dict)
        # Check that observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkHasPrepModules(model.fc1)
            self.checkQuantizedLinear(model.fc1)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(SingleLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules, but we only quantize
        the second one, `fc2`; `fc1` is not quantized
        """
        model = TwoLayerLinearModel()
        qconfig_dict = {
            'fc2': default_qconfig
        }
        model = prepare(model, qconfig_dict)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkQuantizedLinear(model.fc2)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(TwoLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for a nested model: top-level `fc3` and `fc1` of
        submodule `sub2` are quantized; `sub2.fc2` is not quantized
        """
        model = NestedModel()
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2.fc1': default_qconfig
        }

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkNoPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        model = prepare(model, qconfig_dict)
        checkPrepModules(model, True)
        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.checkQuantizedLinear(model.fc3)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested2(self):
        r"""Another test case for quantize: we quantize all submodules of
        submodule sub2. This introduces redundant quant/dequant; to remove
        them we need to manually use QuantWrapper or insert
        QuantStub/DeQuantStub, see `test_quant_wrapper` and `test_manual`
        """
        model = NestedModel()
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig
        }
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides
        the parent qconfig
        """
        model = NestedModel()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_observer(),
                                 activation=default_observer(**custom_options))
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            'sub2.fc1': custom_qconfig
        }
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_quant_wrapper(self):
        r"""The user needs to modify the original code with QuantWrapper and
        call the quantization utility functions.
        """
        model = WrappedModel()

        # since we didn't provide a qconfig_dict, the model is modified in place,
        # but we can do `model = prepare(model)` as well
        prepare(model)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkLinear(model.fc)
            self.checkQuantDequant(model.sub)
            self.assertEqual(type(model.sub.module.fc1), nnq.Linear)
            self.assertEqual(type(model.sub.module.fc2), nnq.Linear)
            self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(WrappedModel(), default_eval_fn, calib_data, {})
        checkQuantized(model)

    def test_manual(self):
        r"""The user inserts QuantStub and DeQuantStub in the model code and
        calls the quantization utility functions.
        """
        model = ManualQuantModel()
        # propagate the qconfig of parents to children, model is changed
        # in place
        prepare(model)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(ManualQuantModel(), default_eval_fn, calib_data)
        checkQuantized(model)

if __name__ == '__main__':
    run_tests()
@@ -4,9 +4,9 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from .. import functional as F
-from ...modules.module import Module
+from ...modules.activation import ReLU as NNReLU
 
-class ReLU(Module):
+class ReLU(NNReLU):
     r"""Applies quantized rectified linear unit function element-wise:
 
     :math:`\text{ReLU}(x) = \max(x_0, x)`, where :math:`x_0` is the zero point.
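The `max(x_0, x)` form follows from affine quantization: a real value r is stored as q = round(r / scale) + zero_point, so real 0 corresponds to the integer zero point x_0, and clamping at x_0 in the quantized domain is exactly ReLU in the real domain. A small self-contained check with illustrative numbers (not taken from the PR):

scale, zero_point = 0.1, 128                 # illustrative quint8 qparams

def quantize_value(r):
    return round(r / scale) + zero_point

assert max(zero_point, quantize_value(-3.0)) == quantize_value(0.0)  # negatives clamp to the zero point
assert max(zero_point, quantize_value(2.5)) == quantize_value(2.5)   # positives pass through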
@@ -35,7 +35,9 @@ class Quantize(Module):
 
     @staticmethod
     def from_float(mod):
-        return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
+        assert hasattr(mod, 'observer')
+        qparams = mod.observer.calculate_qparams()
+        return Quantize(qparams[0].item(), qparams[1].item(), mod.observer.dtype)
 
 class DeQuantize(Module):
     r"""Dequantizes an incoming tensor
@@ -136,3 +138,30 @@ class Linear(NNLinear):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                       missing_keys, unexpected_keys, error_msgs)
         return
+
+    # TODO: support initializing from quantization parameters when Quantizer is
+    # exposed in python
+    @staticmethod
+    def from_float(mod):
+        r"""Create a quantized module from a float module or qparams_dict
+
+        Args: `mod` a float module, either produced by torch.quantization
+        utilities or provided directly by the user
+        """
+        assert type(mod) == NNLinear, 'nnq.Linear.from_float only works for nn.Linear'
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert hasattr(mod, 'observer'), 'Input float module must have observer attached'
+        activation_observer = mod.observer
+        act_qparams = activation_observer.calculate_qparams()
+        weight_observer = mod.qconfig.weight()
+        weight_observer(mod.weight)
+        wt_qparams = weight_observer.calculate_qparams()
+        bias_scale = (wt_qparams[0] * act_qparams[0]).float()
+        qweight = torch.quantize_linear(mod.weight.float(), wt_qparams[0], wt_qparams[1].long().item(), torch.qint8)
+        qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
+        qlinear = Linear(mod.in_features, mod.out_features)
+        qlinear._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(qweight)
+        qlinear.bias = qbias
+        qlinear.out_scale = torch.tensor([act_qparams[0]])
+        qlinear.out_zero_point = torch.tensor([act_qparams[1]])
+        return qlinear
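The `bias_scale = wt_qparams[0] * act_qparams[0]` line reflects how quantized GEMM accumulates: each int32 product q_w * q_x represents (q_w * s_w) * (q_x * s_x), so the accumulator, and hence the qint32 bias added into it, has effective scale s_w * s_x with zero point 0 (as the call above shows). A quick numeric check with illustrative scales (not from the PR):

s_w, s_x = 0.02, 0.1                      # illustrative weight and input scales
bias_real = 0.5
q_bias = round(bias_real / (s_w * s_x))   # quantize bias with scale s_w * s_x, zero_point 0
assert q_bias == 250 and abs(q_bias * s_w * s_x - bias_real) < s_w * s_x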
torch/quantization/QConfig.py (new file, 9 lines)

@@ -0,0 +1,9 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
from .observer import *

QConfig = namedtuple('QConfig',
                     ['weight', 'activation'])

default_qconfig = QConfig(default_weight_observer(),
                          default_observer())
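Note that the two QConfig fields hold observer factories rather than observer instances: `default_observer(**kwargs)` returns a `functools.partial` (see observer.py below), and call sites such as `mod.qconfig.activation()` invoke it to get a fresh Observer per module. A hedged sketch of building a custom config on top of this:

import torch
from torch.quantization import QConfig, default_observer, default_weight_observer

# Each field is a callable; invoking it yields a new, independent Observer.
my_qconfig = QConfig(weight=default_weight_observer(),
                     activation=default_observer(dtype=torch.quint8))
act_observer = my_qconfig.activation()  # fresh activation Observer for one module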
torch/quantization/__init__.py (new file, 28 lines)

@@ -0,0 +1,28 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import *  # noqa: F401
from .observer import *  # noqa: F401
from .QConfig import *  # noqa: F401

def default_eval_fn(model, calib_data):
    r"""
    Default evaluation function: takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
    """
    for data in calib_data:
        model(data)

__all__ = [
    'QuantWrapper', 'QuantStub', 'DeQuantStub', 'DEFAULT_MODULE_MAPPING',
    # Top level API for quantizing a float model
    'quantize',
    # Sub functions called by quantize
    'prepare', 'convert',
    # Sub functions for `prepare` and `swap_module`
    'propagate_qconfig', 'add_quant_dequant', 'add_observer', 'swap_module',
    'default_eval_fn',
    # Observers
    'Observer', 'WeightObserver', 'observer', 'default_observer',
    'default_weight_observer',
    # QConfig
    'QConfig', 'default_qconfig'
]
torch/quantization/observer.py (new file, 73 lines)

@@ -0,0 +1,73 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch
from functools import partial

class Observer(nn.Module):
    r"""Default Observer Module
    A default implementation of the observer module that only works for the
    `per_tensor_affine` quantization scheme.
    The module records the running min and max values of the observed Tensor,
    and `calculate_qparams` computes the scale and zero_point.

    Other types of Observers should follow the same API: they can take an
    arbitrary number of keyword arguments. In forward, they update the
    statistics of the observed Tensor, and they should provide a
    `calculate_qparams` function that computes the quantization parameters
    given the collected statistics.
    TODO: Maybe add an abstract Observer class that enforces these rules?
    """
    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine):
        super(Observer, self).__init__()
        self.dtype = dtype
        self.qscheme = qscheme
        assert self.qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric), \
            'Default Observer only works for per_tensor_affine and ' \
            'per_tensor_symmetric quantization schemes'
        assert self.dtype in (torch.qint8, torch.quint8), \
            'Default Observer only works for qint8 and quint8 data types'
        self.min_val = None
        self.max_val = None

    def forward(self, x):
        if self.min_val is None or self.max_val is None:
            self.min_val = torch.min(x)
            self.max_val = torch.max(x)
        else:
            self.min_val = torch.min(torch.min(x), self.min_val)
            self.max_val = torch.max(torch.max(x), self.max_val)

    def calculate_qparams(self):
        if self.dtype == torch.qint8:
            qmin, qmax = -128, 127
        else:
            qmin, qmax = 0, 255
        n_levels = 255.0
        if self.max_val is None or self.min_val is None:
            raise Exception('must run observer before calling calculate_qparams!')
        max_val, min_val = self.max_val.item(), self.min_val.item()
        if max_val == min_val:
            scale = 1.0
            zero_point = 0
        else:
            if self.qscheme == torch.per_tensor_symmetric:
                max_val = max(-min_val, max_val)
                scale = max_val / 127.0
                zero_point = 0 if self.dtype == torch.qint8 else 128
            else:
                scale = (max_val - min_val) / n_levels
                zero_point = qmin - round(min_val / scale)
                zero_point = max(qmin, zero_point)
                zero_point = min(qmax, zero_point)

        return torch.tensor([scale, zero_point])

def observer(observer_cls, **kwargs):
    return partial(observer_cls, **kwargs)

def default_observer(**kwargs):
    return observer(Observer, **kwargs)

def default_weight_observer(**kwargs):
    kwargs.setdefault('dtype', torch.qint8)
    kwargs.setdefault('qscheme', torch.per_tensor_symmetric)
    return observer(Observer, **kwargs)
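To make the affine branch of `calculate_qparams` concrete, here is a standalone re-derivation in plain Python (no torch dependency), mirroring the formulas above; the observed range [-1.0, 3.0] is just an illustrative example:

def affine_qparams(min_val, max_val, dtype='quint8'):
    # Mirrors Observer.calculate_qparams for torch.per_tensor_affine.
    qmin, qmax = (-128, 127) if dtype == 'qint8' else (0, 255)
    if max_val == min_val:
        return 1.0, 0
    scale = (max_val - min_val) / 255.0
    zero_point = qmin - round(min_val / scale)
    return scale, min(qmax, max(qmin, zero_point))

scale, zp = affine_qparams(-1.0, 3.0)
print(scale, zp)  # ~0.01569, 64: real 0.0 maps to the integer 64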
torch/quantization/quantize.py (new file, 255 lines)

@@ -0,0 +1,255 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch.nn.quantized as nnq
import torch

def propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix=''):
    r"""This is a helper function for `propagate_qconfig`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from the name of a submodule to its
            quantization configuration
        qconfig_parent: quantization config of the parent module; we fall back
            to this config when there is no config specified for the current
            module
        prefix: corresponding prefix of the current module, used as a key into
            qconfig_dict

    Return:
        None, module is modified in place with qconfig attached
    """
    if not hasattr(module, 'qconfig'):
        module.qconfig = None
        if qconfig_dict and prefix in qconfig_dict:
            module.qconfig = qconfig_dict[prefix]
        else:
            module.qconfig = qconfig_parent
    print('prefix:', prefix, 'qconfig: ', module.qconfig)

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        propagate_qconfig_helper(child, qconfig_dict, module.qconfig, module_prefix)

def propagate_qconfig(module, qconfig_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign a `qconfig`
    attribute to each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from the name of a submodule to its
            quantization configuration; a qconfig applies to all submodules of
            a given module unless a qconfig is specified for the submodule
            (i.e. the submodule already has a qconfig attribute)

    Return:
        None, module is modified in place with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    propagate_qconfig_helper(module, qconfig_dict)

def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    self.observer(output)

# TODO(jerryzh): remove_observer?
def add_observer(module):
    r"""Add an observer for each leaf child of the module.

    This function inserts an observer module into every leaf child module
    that has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        None, module is modified in place with observer modules and
        forward_hooks added
    """
    for child in module.children():
        add_observer(child)

    # Insert observers only for leaf nodes. Note that this observer is for
    # the output of the module; inputs are observed by QuantStub
    if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
        # observer and hook will be gone after we swap the module
        module.add_module('observer', module.qconfig.activation())
        module.register_forward_hook(_observer_forward_hook)

class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub, and surrounds the call to the module with calls to the
    quant and dequant modules.

    This is used by the `quantization` utility functions to add the quant
    and dequant modules. Before `convert`, `QuantStub` acts only as an
    observer: it observes the input tensor. After `convert`, `QuantStub` is
    swapped to `nnq.Quantize`, which does the actual quantization. Similarly
    for `DeQuantStub`.
    """
    def __init__(self, module):
        super(QuantWrapper, self).__init__()
        qconfig = module.qconfig if hasattr(module, 'qconfig') else None
        self.quant = QuantStub(qconfig)
        self.dequant = DeQuantStub()
        self.module = module

    def forward(self, X):
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)

def add_quant_dequant(module):
    r"""Wrap each leaf child module in QuantWrapper if it has a valid qconfig.
    Note that this function modifies the children of the module in place, and
    it can also return a new module that wraps the input module.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either the in-place modified module with submodules wrapped in
        `QuantWrapper` based on qconfig, or a new `QuantWrapper` module that
        wraps the input module; the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module

def prepare(module, qconfig_dict=None):
    r"""Prepares the module for calibration or training given a qconfig_dict.
    Note that the module is modified in place, but if the input module is a
    leaf module, a wrapped module is returned.

    Args:
        module: input module
        qconfig_dict: dictionary that maps from the name of a submodule to its
            quantization configuration
    Return:
        A module with qconfig propagated and observer and quant/dequant or
        fake-quant modules attached; a module that is ready for calibration
        or training
    """
    propagate_qconfig(module, qconfig_dict)
    if qconfig_dict:
        module = add_quant_dequant(module)
    add_observer(module)
    return module

class QuantStub(nn.Module):
    r"""Quantize stub module. Before calibration this is the same as an
    observer; it will be swapped to `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor;
            if qconfig is not provided, we get the qconfig from parent modules
    """
    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        return x

class DeQuantStub(nn.Module):
    r"""Dequantize stub module. Before calibration this is the same as
    identity; it will be swapped to `nnq.DeQuantize` in `convert`.
    """
    def __init__(self):
        super(DeQuantStub, self).__init__()

    def forward(self, x):
        return x

def quantize(module, eval_fn, eval_args, qconfig_dict=None):
    r"""Converts a float module to a quantized module.

    First it prepares the module for calibration or training, then it calls
    `eval_fn`, which runs the calibration step or training step; after that
    we call `convert`, which converts the module to a quantized module.

    When `qconfig_dict` is None or an empty dictionary, we assume the user
    will insert quant/dequant stubs and add qconfig in the appropriate
    places. When `qconfig_dict` is neither None nor an empty dictionary, we
    add quant/dequant stubs using QuantWrapper for all the leaf modules.

    Args:
        module: input module
        eval_fn: a function for evaluating the prepared module; can be a
            function that simply runs the prepared module or a training loop
        eval_args: positional arguments for `eval_fn`
        qconfig_dict: dictionary that maps from the name of a submodule to its
            quantization configuration; a qconfig applies to all submodules of
            a given module unless a qconfig is specified for the submodule
            (i.e. the submodule already has a qconfig attribute)

    Return:
        A quantized module
    """
    module = prepare(module, qconfig_dict)
    eval_fn(module, eval_args)
    convert(module)
    return module

# Map for swapping float modules to quantized ones
DEFAULT_MODULE_MAPPING = {
    torch.nn.Linear: nnq.Linear,
    torch.nn.ReLU: nnq.ReLU,
    QuantStub: nnq.Quantize,
}

def convert(module, mapping=DEFAULT_MODULE_MAPPING):
    r"""Converts the float module with observers (where we can get the
    quantization parameters) to a quantized module.
    Args:
        module: calibrated module with observers
        mapping: a dictionary that maps from float module type to quantized
            module type; can be overwritten to allow swapping of user-defined
            Modules
    Return:
        A quantized module
    """
    module_swapped = swap_module(module, mapping)

    reassign = {}
    for name, mod in module.named_children():
        new_mod = convert(mod, mapping)
        if new_mod is not mod:
            reassign[name] = new_mod

    for name, mod in reassign.items():
        setattr(module_swapped, name, mod)

    return module_swapped

def swap_module(mod, mapping):
    r"""Swaps the module if it has a quantized counterpart and an `observer`
    attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from an nn module to an nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    print('swapping:', mod)
    if hasattr(mod, 'observer'):
        if type(mod) in mapping:
            new_mod = mapping[type(mod)].from_float(mod)

    if type(mod) == DeQuantStub:
        new_mod = nnq.DeQuantize.from_float(mod)

    return new_mod
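The `mapping` parameter of `convert` is the extension point for user-defined modules: any float type mapped to a class with a `from_float` static method will be swapped once an observer is attached. A hedged sketch, where `MyFloatOp`, `MyQuantizedOp`, and `calibrated_model` are hypothetical names:

from torch.quantization import DEFAULT_MODULE_MAPPING, convert

# MyQuantizedOp.from_float(mod) must build the quantized module from the
# calibrated float module's observer statistics (hypothetical classes).
custom_mapping = dict(DEFAULT_MODULE_MAPPING)
custom_mapping[MyFloatOp] = MyQuantizedOp

convert(calibrated_model, mapping=custom_mapping)  # calibrated_model: a prepared and calibrated model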