Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
torch.quantization conversion utilities, observers for eager mode quantization (#22010)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22010

torch.quantization module with observers and conversion routines

Reviewed By: zafartahirov

Differential Revision: D15554183

fbshipit-source-id: 05a3fabe28dd701978b8ecebf5bfc3a4c044ba5c
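Concretely, the workflow these utilities enable (and which the new tests in test_quantization.py below exercise) is prepare → calibrate → convert, plus a one-line `quantize` helper. A minimal sketch, using a toy model that mirrors the models defined in the tests (the class name `M` is only for illustration):

import torch
from torch.quantization import default_qconfig, default_eval_fn, prepare, convert, quantize

# toy float model, mirroring SingleLayerLinearModel from the tests below
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.fc1(x)

calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]
qconfig_dict = {'': default_qconfig}              # quantize the whole model

model = prepare(M(), qconfig_dict)                # attach observers and quant/dequant stubs
default_eval_fn(model, calib_data)                # calibration: run data through the observers
convert(model)                                    # swap float modules for their nnq counterparts

# equivalent one-line API
model = quantize(M(), default_eval_fn, calib_data, qconfig_dict)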
This commit is contained in:
parent 073fa6f411
commit 5040d52a5a

381  test/test_quantization.py  Normal file

@@ -0,0 +1,381 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized as nnq
import torch.quantization as tq
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
    default_eval_fn, QConfig, default_qconfig, default_observer, quantize, \
    prepare, convert

from common_utils import TestCase, run_tests


class SingleLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super(SingleLayerLinearModel, self).__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super(TwoLayerLinearModel, self).__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class LinearReluModel(torch.nn.Module):
    def __init__(self):
        super(LinearReluModel, self).__init__()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

class NestedModel(torch.nn.Module):
    def __init__(self):
        super(NestedModel, self).__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class InnerModule(torch.nn.Module):
    def __init__(self):
        super(InnerModule, self).__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))

class WrappedModel(torch.nn.Module):
    def __init__(self):
        super(WrappedModel, self).__init__()
        self.qconfig = default_qconfig
        self.sub = QuantWrapper(InnerModule())
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        # don't quantize this fc
        self.fc.qconfig = None

    def forward(self, x):
        return self.fc(self.sub(x))

class ManualQuantModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self):
        super(ManualQuantModel, self).__init__()
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]

class ModelQuantizeAPITest(TestCase):

    def checkNoPrepModules(self, module):
        r"""Checks the module does not contain child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkHasPrepModules(self, module):
        r"""Checks the module contains child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module):
        if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
            self.assertTrue(hasattr(module, 'observer'))
        for child in module.children():
            self.checkObservers(child)

    def checkQuantDequant(self, mod):
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

    def checkQuantizedLinear(self, mod):
        self.assertEqual(type(mod.module), nnq.Linear)
        self.assertEqual(mod.module.bias.dtype, torch.qint32)
        self.checkQuantDequant(mod)

    def checkLinear(self, mod):
        self.assertEqual(type(mod), torch.nn.Linear)

    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
            to nnq.Linear which is the quantized version of the module
        """
        model = SingleLayerLinearModel()
        qconfig_dict = {
            '': default_qconfig
        }
        model = prepare(model, qconfig_dict)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkHasPrepModules(model.fc1)
            self.checkQuantizedLinear(model.fc1)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(SingleLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one,
            `fc2`; `fc1` is not quantized
        """
        model = TwoLayerLinearModel()
        qconfig_dict = {
            'fc2': default_qconfig
        }
        model = prepare(model, qconfig_dict)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkQuantizedLinear(model.fc2)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(TwoLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model: top level 'fc3' and
            'fc1' of submodule 'sub2' are quantized, 'sub2.fc2' is not
        """
        model = NestedModel()
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2.fc1': default_qconfig
        }

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkNoPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        model = prepare(model, qconfig_dict)
        checkPrepModules(model, True)
        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.checkQuantizedLinear(model.fc3)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)


    def test_nested2(self):
        r"""Another nested test case: we quantize all submodules
            of submodule sub2. This introduces redundant quant/dequant; to
            remove them we need to manually use QuantWrapper or insert
            QuantStub/DeQuantStub, see `test_quant_wrapper` and
            `test_manual`
        """
        model = NestedModel()
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig
        }
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides
            the parent qconfig
        """
        model = NestedModel()
        custum_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_observer(),
                                 activation=default_observer(**custum_options))
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            'sub2.fc1': custom_qconfig
        }
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_quant_wrapper(self):
        r"""User needs to modify the original code with QuantWrapper,
            and call the quantization utility functions.
        """
        model = WrappedModel()

        # since we didn't provide qconfig_dict, the model is modified inplace
        # but we can do `model = prepare(model)` as well
        prepare(model)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkLinear(model.fc)
            self.checkQuantDequant(model.sub)
            self.assertEqual(type(model.sub.module.fc1), nnq.Linear)
            self.assertEqual(type(model.sub.module.fc2), nnq.Linear)
            self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(WrappedModel(), default_eval_fn, calib_data, {})
        checkQuantized(model)


    def test_manual(self):
        r"""User inserts QuantStub and DeQuantStub in model code
            and calls the quantization utility functions.
        """
        model = ManualQuantModel()
        # propagate the qconfig of parents to children, model is changed
        # inplace
        prepare(model)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(ManualQuantModel(), default_eval_fn, calib_data)
        checkQuantized(model)


if __name__ == '__main__':
    run_tests()
@@ -4,9 +4,9 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from .. import functional as F
-from ...modules.module import Module
+from ...modules.activation import ReLU as NNReLU
 
-class ReLU(Module):
+class ReLU(NNReLU):
     r"""Applies quantized rectified linear unit function element-wise:
 
     :math:`\text{ReLU}(x)= \max(x_0, x)`, where :math:`x_0` is the zero point.

@@ -35,7 +35,9 @@ class Quantize(Module):
 
     @staticmethod
     def from_float(mod):
-        return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
+        assert hasattr(mod, 'observer')
+        qparams = mod.observer.calculate_qparams()
+        return Quantize(qparams[0].item(), qparams[1].item(), mod.observer.dtype)
 
 class DeQuantize(Module):
     r"""Dequantizes an incoming tensor

@@ -136,3 +138,30 @@ class Linear(NNLinear):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                       missing_keys, unexpected_keys, error_msgs)
         return
+
+    # TODO: support initializing from quantization parameters when Quantizer is
+    # exposed in python
+    @staticmethod
+    def from_float(mod):
+        r"""Create a quantized module from a float module or qparams_dict
+
+        Args: `mod` a float module, either produced by torch.quantization utilities
+        or directly from user
+        """
+        assert type(mod) == NNLinear, 'nnq.Linear.from_float only works for nn.Linear'
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert hasattr(mod, 'observer'), 'Input float module must have observer attached'
+        activation_observer = mod.observer
+        act_qparams = activation_observer.calculate_qparams()
+        weight_observer = mod.qconfig.weight()
+        weight_observer(mod.weight)
+        wt_qparams = weight_observer.calculate_qparams()
+        bias_scale = (wt_qparams[0] * act_qparams[0]).float()
+        qweight = torch.quantize_linear(mod.weight.float(), wt_qparams[0], wt_qparams[1].long().item(), torch.qint8)
+        qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
+        qlinear = Linear(mod.in_features, mod.out_features)
+        qlinear._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(qweight)
+        qlinear.bias = qbias
+        qlinear.out_scale = torch.tensor([act_qparams[0]])
+        qlinear.out_zero_point = torch.tensor([act_qparams[1]])
+        return qlinear
9  torch/quantization/QConfig.py  Normal file

@@ -0,0 +1,9 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
from .observer import *

QConfig = namedtuple('QConfig',
                     ['weight', 'activation'])

default_qconfig = QConfig(default_weight_observer(),
                          default_observer())
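A custom QConfig is built the same way, by partially applying the observer with non-default options, as `test_nested3` above does. A minimal sketch:

import torch
from torch.quantization import QConfig, default_observer

# activation observer with explicit options, weight observer left at its defaults
custom_qconfig = QConfig(weight=default_observer(),
                         activation=default_observer(dtype=torch.quint8,
                                                     qscheme=torch.per_tensor_affine))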
28  torch/quantization/__init__.py  Normal file

@@ -0,0 +1,28 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import *  # noqa: F401
from .observer import *  # noqa: F401
from .QConfig import *  # noqa: F401

def default_eval_fn(model, calib_data):
    r"""
    Default evaluation function takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
    """
    for data in calib_data:
        model(data)

__all__ = [
    'QuantWrapper', 'QuantStub', 'DeQuantStub', 'DEFAULT_MODULE_MAPPING',
    # Top level API for quantizing a float model
    'quantize',
    # Sub functions called by quantize
    'prepare', 'convert',
    # Sub functions for `prepare` and `swap_module`
    'propagate_qconfig', 'add_quant_dequant', 'add_observer', 'swap_module',
    'default_eval_fn',
    # Observers
    'Observer', 'WeightObserver', 'observer', 'default_observer',
    'default_weight_observer',
    # QConfig
    'QConfig', 'default_qconfig'
]
73  torch/quantization/observer.py  Normal file

@@ -0,0 +1,73 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch
from functools import partial

class Observer(nn.Module):
    r"""Default Observer Module
    A default implementation of the observer module, only works for
    `per_tensor_affine` quantization scheme.
    The module records the running min and max values of the observed Tensor,
    and `calculate_qparams` computes the scale and zero_point.

    Other types of Observers should follow the same API: they can take an
    arbitrary number of keyword arguments, update the statistics of the
    observed Tensor in `forward`, and provide a `calculate_qparams` function
    that computes the quantization parameters given the collected statistics.
    TODO: Maybe add an abstract Observer class that enforces these rules?
    """
    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine):
        super(Observer, self).__init__()
        self.dtype = dtype
        self.qscheme = qscheme
        assert self.qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric), \
            'Default Observer only works for per_tensor_affine and \
                per_tensor_symmetric quantization scheme'
        assert self.dtype in (torch.qint8, torch.quint8), \
            'Default Observer only works for qint8 and quint8 data types'
        self.min_val = None
        self.max_val = None

    def forward(self, x):
        if self.min_val is None or self.max_val is None:
            self.min_val = torch.min(x)
            self.max_val = torch.max(x)
        else:
            self.min_val = torch.min(torch.min(x), self.min_val)
            self.max_val = torch.max(torch.max(x), self.max_val)

    def calculate_qparams(self):
        if self.dtype == torch.qint8:
            qmin, qmax = -128, 127
        else:
            qmin, qmax = 0, 255
        n_levels = 255.0
        if self.max_val is None or self.min_val is None:
            raise Exception('must run observer before calling calculate_qparams!')
        max_val, min_val = self.max_val.item(), self.min_val.item()
        if max_val == min_val:
            scale = 1.0
            zero_point = 0
        else:
            if self.qscheme == torch.per_tensor_symmetric:
                max_val = max(-min_val, max_val)
                scale = max_val / 127.0
                zero_point = 0 if self.dtype == torch.qint8 else 128
            else:
                scale = (max_val - min_val) / n_levels
                zero_point = qmin - round(min_val / scale)
                zero_point = max(qmin, zero_point)
                zero_point = min(qmax, zero_point)

        return torch.tensor([scale, zero_point])

def observer(observer_cls, **kwargs):
    return partial(observer_cls, **kwargs)

def default_observer(**kwargs):
    return observer(Observer, **kwargs)

def default_weight_observer(**kwargs):
    kwargs.setdefault('dtype', torch.qint8)
    kwargs.setdefault('qscheme', torch.per_tensor_symmetric)
    return observer(Observer, **kwargs)
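To make the affine branch of `calculate_qparams` concrete, here is a small worked sketch of the arithmetic with hypothetical observed statistics (the min/max values are illustrative only):

# hypothetical observed statistics for a quint8, per_tensor_affine observer
min_val, max_val = -1.0, 3.0
qmin, qmax, n_levels = 0, 255, 255.0

scale = (max_val - min_val) / n_levels          # 4 / 255 ≈ 0.0157
zero_point = qmin - round(min_val / scale)      # 0 - round(-63.75) = 64
zero_point = min(qmax, max(qmin, zero_point))   # clamp to [0, 255] -> 64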
255  torch/quantization/quantize.py  Normal file

@@ -0,0 +1,255 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch.nn.quantized as nnq
import torch

def propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix=''):
    r"""This is a helper function for `propagate_qconfig`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
        qconfig_parent: quantization config of parent module, we will fall back to
                        this config when there is no specified config for the current
                        module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict

    Return:
        None, module is modified inplace with qconfig attached
    """
    if not hasattr(module, 'qconfig'):
        module.qconfig = None
        if qconfig_dict and prefix in qconfig_dict:
            module.qconfig = qconfig_dict[prefix]
        else:
            module.qconfig = qconfig_parent
    print('prefix:', prefix, 'qconfig: ', module.qconfig)

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        propagate_qconfig_helper(child, qconfig_dict, module.qconfig, module_prefix)

def propagate_qconfig(module, qconfig_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign the `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration, qconfig applies to all submodules of a given
                      module unless qconfig for the submodules is specified (when the
                      submodule already has a qconfig attribute)

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    propagate_qconfig_helper(module, qconfig_dict)

def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    self.observer(output)

# TODO(jerryzh): remove_observer?
def add_observer(module):
    r"""Add observer for the leaf child of the module.

    This function inserts an observer module on every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        None, module is modified inplace with added observer modules and
        forward_hooks
    """
    for child in module.children():
        add_observer(child)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
        # observer and hook will be gone after we swap the module
        module.add_module('observer', module.qconfig.activation())
        module.register_forward_hook(_observer_forward_hook)

class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surrounds the call to the module with calls to the quant and
    dequant modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules. Before `convert`, `QuantStub` is just an observer,
    it observes the input tensor; after `convert`, `QuantStub`
    is swapped to `nnq.Quantize` which does the actual quantization. Similarly
    for `DeQuantStub`.
    """
    def __init__(self, module):
        super(QuantWrapper, self).__init__()
        qconfig = module.qconfig if hasattr(module, 'qconfig') else None
        self.quant = QuantStub(qconfig)
        self.dequant = DeQuantStub()
        self.module = module

    def forward(self, X):
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)

def add_quant_dequant(module):
    r"""Wrap the leaf child modules in QuantWrapper if they have a valid qconfig.
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module

def prepare(module, qconfig_dict=None):
    r"""Prepares the module for calibration or training given a qconfig_dict.
    Note that the module will be modified inplace, but in case the input module
    is a leaf module, a wrapped module will be returned.

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
    Return:
        A module with qconfig propagated, observer and quant/dequant or fake
        quant modules attached, a module that is ready for calibration or
        training
    """
    propagate_qconfig(module, qconfig_dict)
    if qconfig_dict:
        module = add_quant_dequant(module)
    add_observer(module)
    return module

class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,
    it will be swapped to `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """
    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        return x

class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,
    it will be swapped to `nnq.DeQuantize` in `convert`.
    """
    def __init__(self):
        super(DeQuantStub, self).__init__()

    def forward(self, x):
        return x

def quantize(module, eval_fn, eval_args, qconfig_dict=None):
    r"""Converts a float module to a quantized module.

    First it will prepare the module for calibration or training, then it calls
    `eval_fn` which will run the calibration step or training step,
    after that we will call `convert` which will convert the module to a
    quantized module.

    When `qconfig_dict` is None or an empty dictionary, we assume the user has
    inserted quant/dequant stubs and added qconfig in the appropriate places.
    When `qconfig_dict` is provided (non-empty), we will add quant/dequant
    stubs using QuantWrapper for all the leaf modules.

    Args:
        module: input module
        eval_fn: a function for evaluating the prepared module, can be a
                 function that simply runs the prepared module or a training loop
        eval_args: positional arguments for `eval_fn`
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration, qconfig applies to all submodules of a given
                      module unless qconfig for the submodules is specified (when the
                      submodule already has a qconfig attribute)

    Return:
        A quantized module
    """
    module = prepare(module, qconfig_dict)
    eval_fn(module, eval_args)
    convert(module)
    return module

# Map for swapping float modules to quantized ones
DEFAULT_MODULE_MAPPING = {
    torch.nn.Linear: nnq.Linear,
    torch.nn.ReLU: nnq.ReLU,
    QuantStub: nnq.Quantize,
}

def convert(module, mapping=DEFAULT_MODULE_MAPPING):
    r"""Converts the float module with observers (where we can get quantization
    parameters) to a quantized module.
    Args:
        module: calibrated module with observers
        mapping: a dictionary that maps from float module type to quantized
        module type, can be overridden to allow swapping user defined Modules
    Return:
        A quantized module
    """
    module_swapped = swap_module(module, mapping)

    reassign = {}
    for name, mod in module.named_children():
        new_mod = convert(mod, mapping)
        if new_mod is not mod:
            reassign[name] = new_mod

    for name, mod in reassign.items():
        setattr(module_swapped, name, mod)

    return module_swapped

def swap_module(mod, mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    print('swapping:', mod)
    if hasattr(mod, 'observer'):
        if type(mod) in mapping:
            new_mod = mapping[type(mod)].from_float(mod)

    if type(mod) == DeQuantStub:
        new_mod = nnq.DeQuantize.from_float(mod)

    return new_mod
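The `mapping` argument of `convert` controls which float module types get swapped. A minimal sketch against the code in this commit, passing the default entries explicitly (equivalent to `convert(model)`); a user-defined float module and its quantized counterpart could be added to the dictionary the same way:

import torch
import torch.nn.quantized as nnq
from torch.quantization import QuantStub, default_qconfig, default_eval_fn, prepare, convert

# prepare wraps the leaf Linear in a QuantWrapper because a qconfig_dict is given
model = prepare(torch.nn.Linear(5, 5).to(dtype=torch.float), {'': default_qconfig})
default_eval_fn(model, [torch.rand(20, 5, dtype=torch.float) for _ in range(20)])

# explicit mapping: same entries as DEFAULT_MODULE_MAPPING
convert(model, mapping={torch.nn.Linear: nnq.Linear,
                        torch.nn.ReLU: nnq.ReLU,
                        QuantStub: nnq.Quantize})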