torch.quantization conversion utilities, observers for eager mode quantization (#22010)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22010

torch.quantization module with observers and conversion routines
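
At a glance, the eager-mode flow these utilities enable (a minimal sketch adapted from test/test_quantization.py below; SingleLayerLinearModel and calib_data are the test's own fixtures):

import torch
from torch.quantization import prepare, convert, quantize, default_qconfig, default_eval_fn

model = SingleLayerLinearModel()            # float model with a single nn.Linear
calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]
qconfig_dict = {'': default_qconfig}        # '' applies the default qconfig to all submodules

model = prepare(model, qconfig_dict)        # attach observers and quant/dequant stubs
default_eval_fn(model, calib_data)          # calibration: run the data through the observers
convert(model)                              # swap observed modules for their quantized versions

# or, equivalently, the one-line API:
model = quantize(SingleLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)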

Reviewed By: zafartahirov

Differential Revision: D15554183

fbshipit-source-id: 05a3fabe28dd701978b8ecebf5bfc3a4c044ba5c
Jerry Zhang 2019-07-09 10:42:35 -07:00 committed by Facebook Github Bot
parent 073fa6f411
commit 5040d52a5a
7 changed files with 778 additions and 3 deletions

test/test_quantization.py
@@ -0,0 +1,381 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized as nnq
import torch.quantization as tq
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_eval_fn, QConfig, default_qconfig, default_observer, quantize, \
prepare, convert
from common_utils import TestCase, run_tests
class SingleLayerLinearModel(torch.nn.Module):
def __init__(self):
super(SingleLayerLinearModel, self).__init__()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = self.fc1(x)
return x
class TwoLayerLinearModel(torch.nn.Module):
def __init__(self):
super(TwoLayerLinearModel, self).__init__()
self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
class LinearReluModel(torch.nn.Module):
def __init__(self):
super(LinearReluModel, self).__init__()
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.relu(self.fc(x))
return x
class NestedModel(torch.nn.Module):
def __init__(self):
super(NestedModel, self).__init__()
self.sub1 = LinearReluModel()
self.sub2 = TwoLayerLinearModel()
self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = self.sub1(x)
x = self.sub2(x)
x = self.fc3(x)
return x
class InnerModule(torch.nn.Module):
def __init__(self):
super(InnerModule, self).__init__()
self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
def forward(self, x):
return self.relu(self.fc2(self.relu(self.fc1(x))))
class WrappedModel(torch.nn.Module):
def __init__(self):
super(WrappedModel, self).__init__()
self.qconfig = default_qconfig
self.sub = QuantWrapper(InnerModule())
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
# don't quantize this fc
self.fc.qconfig = None
def forward(self, x):
return self.fc(self.sub(x))
class ManualQuantModel(torch.nn.Module):
r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
"""
def __init__(self):
super(ManualQuantModel, self).__init__()
self.qconfig = default_qconfig
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = self.quant(x)
x = self.fc(x)
return self.dequant(x)
calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]
class ModelQuantizeAPITest(TestCase):
def checkNoPrepModules(self, module):
r"""Checks the module does not contain child
modules for quantization prepration, e.g.
quant, dequant and observer
"""
self.assertFalse(hasattr(module, 'quant'))
self.assertFalse(hasattr(module, 'dequant'))
def checkHasPrepModules(self, module):
r"""Checks the module contains child
modules for quantization prepration, e.g.
quant, dequant and observer
"""
self.assertTrue(hasattr(module, 'module'))
self.assertTrue(hasattr(module, 'quant'))
self.assertTrue(hasattr(module, 'dequant'))
def checkObservers(self, module):
if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
self.assertTrue(hasattr(module, 'observer'))
for child in module.children():
self.checkObservers(child)
def checkQuantDequant(self, mod):
self.assertEqual(type(mod.quant), nnq.Quantize)
self.assertEqual(type(mod.dequant), nnq.DeQuantize)
def checkQuantizedLinear(self, mod):
self.assertEqual(type(mod.module), nnq.Linear)
self.assertEqual(mod.module.bias.dtype, torch.qint32)
self.checkQuantDequant(mod)
def checkLinear(self, mod):
self.assertEqual(type(mod), torch.nn.Linear)
def test_single_layer(self):
r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
to nnq.Linear which is the quantized version of the module
"""
model = SingleLayerLinearModel()
qconfig_dict = {
'': default_qconfig
}
model = prepare(model, qconfig_dict)
# Check if observers and quant/dequant nodes are inserted
self.checkNoPrepModules(model)
self.checkHasPrepModules(model.fc1)
self.checkObservers(model)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.checkHasPrepModules(model.fc1)
self.checkQuantizedLinear(model.fc1)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(SingleLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_two_layers(self):
r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
`fc2`, and `fc1`is not quantized
"""
model = TwoLayerLinearModel()
qconfig_dict = {
'fc2': default_qconfig
}
model = prepare(model, qconfig_dict)
self.checkNoPrepModules(model)
self.checkObservers(model)
self.checkNoPrepModules(model.fc1)
self.checkHasPrepModules(model.fc2)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.fc1)
self.checkHasPrepModules(model.fc2)
self.assertEqual(type(model.fc1), torch.nn.Linear)
self.checkQuantizedLinear(model.fc2)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(TwoLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_nested1(self):
r"""Test quantization for nested model, top level 'fc3' and
'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
"""
model = NestedModel()
qconfig_dict = {
'fc3': default_qconfig,
'sub2.fc1': default_qconfig
}
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkNoPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
model = prepare(model, qconfig_dict)
checkPrepModules(model, True)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkLinear(model.sub1.fc)
self.checkQuantizedLinear(model.fc3)
self.checkQuantizedLinear(model.sub2.fc1)
self.checkLinear(model.sub2.fc2)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_nested2(self):
r"""Another test case for quantized, we will quantize all submodules
of submodule sub2, this will include redundant quant/dequant, to
remove them we need to manually call QuantWrapper or insert
QuantStub/DeQuantStub, see `test_quant_dequant_wrapper` and
`test_manual`
"""
model = NestedModel()
qconfig_dict = {
'fc3': default_qconfig,
'sub2': default_qconfig
}
model = prepare(model, qconfig_dict)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkHasPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
checkPrepModules(model, True)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkLinear(model.sub1.fc)
self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
self.checkQuantizedLinear(model.sub2.fc1)
self.checkQuantizedLinear(model.sub2.fc2)
self.checkQuantizedLinear(model.fc3)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_nested3(self):
r"""More complicated nested test case with child qconfig overrides
parent qconfig
"""
model = NestedModel()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_observer(),
                                 activation=default_observer(**custom_options))
qconfig_dict = {
'fc3': default_qconfig,
'sub2': default_qconfig,
'sub2.fc1': custom_qconfig
}
model = prepare(model, qconfig_dict)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkHasPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
checkPrepModules(model, True)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkQuantizedLinear(model.sub2.fc1)
self.checkQuantizedLinear(model.sub2.fc2)
self.checkQuantizedLinear(model.fc3)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_quant_wrapper(self):
r"""User need to modify the original code with QuantWrapper,
and call the quantization utility functions.
"""
model = WrappedModel()
        # since we didn't provide a qconfig_dict, the model is modified in place,
        # but we could do `model = prepare(model)` as well
prepare(model)
self.checkObservers(model)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.checkLinear(model.fc)
self.checkQuantDequant(model.sub)
self.assertEqual(type(model.sub.module.fc1), nnq.Linear)
self.assertEqual(type(model.sub.module.fc2), nnq.Linear)
self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(WrappedModel(), default_eval_fn, calib_data, {})
checkQuantized(model)
def test_manual(self):
r"""User inserts QuantStub and DeQuantStub in model code
and call the quantization utility functions.
"""
model = ManualQuantModel()
# propagate the qconfig of parents to children, model is changed
# inplace
prepare(model)
self.checkObservers(model)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc), nnq.Linear)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(ManualQuantModel(), default_eval_fn, calib_data)
checkQuantized(model)
if __name__ == '__main__':
run_tests()

torch/nn/quantized/modules/activation.py

@@ -4,9 +4,9 @@ from __future__ import print_function
from __future__ import unicode_literals
from .. import functional as F
from ...modules.module import Module
from ...modules.activation import ReLU as NNReLU
-class ReLU(Module):
+class ReLU(NNReLU):
r"""Applies quantized rectified linear unit function element-wise:
:math:`\text{ReLU}(x)= \max(x_0, x)`, where :math:`x_0` is the zero point.

torch/nn/quantized/modules/linear.py

@@ -35,7 +35,9 @@ class Quantize(Module):
@staticmethod
def from_float(mod):
-        return Quantize(mod.qparams[0].item(), mod.qparams[1].item(), torch.quint8)
+        assert hasattr(mod, 'observer')
+        qparams = mod.observer.calculate_qparams()
+        return Quantize(qparams[0].item(), qparams[1].item(), mod.observer.dtype)
class DeQuantize(Module):
r"""Dequantizes an incoming tensor
@@ -136,3 +138,30 @@ class Linear(NNLinear):
super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
return
# TODO: support initializing from quantization parameters when Quantizer is
# exposed in python
@staticmethod
def from_float(mod):
r"""Create a quantized module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.quantization utilities
        or provided directly by the user
"""
assert type(mod) == NNLinear, 'nnq.Linear.from_float only works for nn.Linear'
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert hasattr(mod, 'observer'), 'Input float module must have observer attached'
activation_observer = mod.observer
act_qparams = activation_observer.calculate_qparams()
weight_observer = mod.qconfig.weight()
weight_observer(mod.weight)
wt_qparams = weight_observer.calculate_qparams()
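        # the bias is added to the int32 accumulator of the int8 weight x int8 activation
        # product, so its scale must be weight_scale * activation_scale (with zero_point 0)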
bias_scale = (wt_qparams[0] * act_qparams[0]).float()
qweight = torch.quantize_linear(mod.weight.float(), wt_qparams[0], wt_qparams[1].long().item(), torch.qint8)
qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32)
qlinear = Linear(mod.in_features, mod.out_features)
qlinear._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(qweight)
qlinear.bias = qbias
qlinear.out_scale = torch.tensor([act_qparams[0]])
qlinear.out_zero_point = torch.tensor([act_qparams[1]])
return qlinear

torch/quantization/QConfig.py

@@ -0,0 +1,9 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
from .observer import *
QConfig = namedtuple('QConfig',
['weight', 'activation'])
default_qconfig = QConfig(default_weight_observer(),
default_observer())
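
A custom QConfig can be assembled from the observer factories in observer.py, as test_nested3 above does (a minimal sketch, assuming the imports above; the keyword arguments shown are just the Observer defaults):

custom_qconfig = QConfig(weight=default_weight_observer(),
                         activation=default_observer(dtype=torch.quint8,
                                                     qscheme=torch.per_tensor_affine))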

torch/quantization/__init__.py

@@ -0,0 +1,28 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import * # noqa: F401
from .observer import * # noqa: F401
from .QConfig import * # noqa: F401
def default_eval_fn(model, calib_data):
r"""
    Default evaluation function: takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
"""
for data in calib_data:
model(data)
__all__ = [
'QuantWrapper', 'QuantStub', 'DeQuantStub', 'DEFAULT_MODULE_MAPPING',
# Top level API for quantizing a float model
'quantize',
# Sub functions called by quantize
'prepare', 'convert',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig', 'add_quant_dequant', 'add_observer', 'swap_module',
'default_eval_fn',
# Observers
'Observer', 'WeightObserver', 'observer', 'default_observer',
'default_weight_observer',
# QConfig
'QConfig', 'default_qconfig'
]

torch/quantization/observer.py

@@ -0,0 +1,73 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch
from functools import partial
class Observer(nn.Module):
r"""Default Observer Module
    A default implementation of the observer module; it only works for the
    `per_tensor_affine` and `per_tensor_symmetric` quantization schemes.
    The module records the running min and max values of the observed Tensor,
    and `calculate_qparams` computes the scale and zero_point from them.
    Other types of Observers should follow the same API: they can take an arbitrary
    number of keyword arguments, `forward` updates the statistics of
    the observed Tensor, and `calculate_qparams` computes the quantization
    parameters from the collected statistics.
    TODO: Maybe add an abstract Observer class that enforces these rules?
"""
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine):
super(Observer, self).__init__()
self.dtype = dtype
self.qscheme = qscheme
assert self.qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric), \
'Default Observer only works for per_tensor_affine and \
per_tensor_symmetric quantization scheme'
assert self.dtype in (torch.qint8, torch.quint8), \
            'Default Observer only works for qint8 and quint8 data types'
self.min_val = None
self.max_val = None
def forward(self, x):
if self.min_val is None or self.max_val is None:
self.min_val = torch.min(x)
self.max_val = torch.max(x)
else:
self.min_val = torch.min(torch.min(x), self.min_val)
self.max_val = torch.max(torch.max(x), self.max_val)
def calculate_qparams(self):
if self.dtype == torch.qint8:
qmin, qmax = -128, 127
else:
qmin, qmax = 0, 255
n_levels = 255.0
if self.max_val is None or self.min_val is None:
raise Exception('must run observer before calling calculate_qparams!')
max_val, min_val = self.max_val.item(), self.min_val.item()
if max_val == min_val:
scale = 1.0
zero_point = 0
else:
if self.qscheme == torch.per_tensor_symmetric:
max_val = max(-min_val, max_val)
scale = max_val / 127.0
zero_point = 0 if self.dtype == torch.qint8 else 128
else:
scale = (max_val - min_val) / n_levels
zero_point = qmin - round(min_val / scale)
zero_point = max(qmin, zero_point)
zero_point = min(qmax, zero_point)
return torch.tensor([scale, zero_point])
def observer(observer_cls, **kwargs):
return partial(observer_cls, **kwargs)
def default_observer(**kwargs):
return observer(Observer, **kwargs)
def default_weight_observer(**kwargs):
kwargs.setdefault('dtype', torch.qint8)
kwargs.setdefault('qscheme', torch.per_tensor_symmetric)
return observer(Observer, **kwargs)
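
A quick illustration of what the default observer computes, following the formulas in calculate_qparams above (the numbers are worked values, not part of the diff):

obs = default_observer()()              # default_observer() returns a factory (partial); call it to build the module
obs(torch.tensor([0.0, 2.0, 4.0]))      # records min_val = 0.0, max_val = 4.0
qparams = obs.calculate_qparams()       # tensor([scale, zero_point])
# per_tensor_affine, quint8: scale = (4.0 - 0.0) / 255 ~= 0.0157, zero_point = 0 - round(0 / scale) = 0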

torch/quantization/quantize.py

@@ -0,0 +1,255 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch.nn.quantized as nnq
import torch
def propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix=''):
r"""This is a helper function for `propagate_qconfig`
Args:
module: input module
qconfig_dict: dictionary that maps from name of submodule to quantization
configuration
        qconfig_parent: quantization config of the parent module; we fall back to
                        this config when there is no qconfig specified for the current
                        module
prefix: corresponding prefix of the current module, used as key in
qconfig_dict
Return:
None, module is modified inplace with qconfig attached
"""
if not hasattr(module, 'qconfig'):
module.qconfig = None
if qconfig_dict and prefix in qconfig_dict:
module.qconfig = qconfig_dict[prefix]
else:
module.qconfig = qconfig_parent
print('prefix:', prefix, 'qconfig: ', module.qconfig)
for name, child in module.named_children():
module_prefix = prefix + '.' + name if prefix else name
propagate_qconfig_helper(child, qconfig_dict, module.qconfig, module_prefix)
def propagate_qconfig(module, qconfig_dict=None):
r"""Propagate qconfig through the module hierarchy and assign `qconfig`
attribute on each leaf module
Args:
module: input module
qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration; a qconfig applies to all submodules of a given
                      module unless a qconfig is specified for the submodule (i.e. the
                      submodule already has a qconfig attribute)
Return:
None, module is modified inplace with qconfig attached
"""
if qconfig_dict is None:
qconfig_dict = {}
propagate_qconfig_helper(module, qconfig_dict)
def _observer_forward_hook(self, input, output):
r"""Forward hook that calls observer on the output
"""
self.observer(output)
# TODO(jerryzh): remove_observer?
def add_observer(module):
r"""Add observer for the leaf child of the module.
This function insert observer module to all leaf child module that
has a valid qconfig attribute.
Args:
module: input module with qconfig attributes for all the leaf modules
that we want to quantize
Return:
None, module is modified inplace with added observer modules and
forward_hooks
"""
for child in module.children():
add_observer(child)
# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
# observer and hook will be gone after we swap the module
module.add_module('observer', module.qconfig.activation())
module.register_forward_hook(_observer_forward_hook)
class QuantWrapper(nn.Module):
r"""A wrapper class that wraps the input module, adds QuantStub and
DeQuantStub and surround the call to module with call to quant and dequant
modules.
This is used by the `quantization` utility functions to add the quant and
dequant modules, before `convert` function `QuantStub` will just be observer,
it observes the input tensor, after `convert`, `QuantStub`
will be swapped to `nnq.Quantize` which does actual quantization. Similarly
for `DeQuantStub`.
"""
def __init__(self, module):
super(QuantWrapper, self).__init__()
qconfig = module.qconfig if hasattr(module, 'qconfig') else None
self.quant = QuantStub(qconfig)
self.dequant = DeQuantStub()
self.module = module
def forward(self, X):
X = self.quant(X)
X = self.module(X)
return self.dequant(X)
def add_quant_dequant(module):
r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
Note that this function will modify the children of module inplace and it
can return a new module which wraps the input module as well.
Args:
module: input module with qconfig attributes for all the leaf modules
that we want to quantize
Return:
        Either the in-place modified module, with submodules wrapped in
        `QuantWrapper` based on qconfig, or a new `QuantWrapper` module which
        wraps the input module; the latter only happens when the input
        module is a leaf module that we want to quantize.
"""
if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
return QuantWrapper(module)
for name, child in module.named_children():
module._modules[name] = add_quant_dequant(child)
return module
def prepare(module, qconfig_dict=None):
r"""Prepares the module for calibration or training given a qconfig_dict.
Note that the module will be modified inplace but in case the input module
is a leaf module, a wrapped module will be returned.
Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
    Return:
        A module with qconfig propagated, and observer and quant/dequant or fake
        quant modules attached; a module that is ready for calibration or
        training
"""
propagate_qconfig(module, qconfig_dict)
if qconfig_dict:
module = add_quant_dequant(module)
add_observer(module)
return module
class QuantStub(nn.Module):
r"""Quantize stub module, before calibration, this is same as an observer,
it will be swapped as `nnq.Quantize` in `convert`.
Args:
qconfig: quantization configuration for the tensor,
if qconfig is not provided, we will get qconfig from parent modules
"""
def __init__(self, qconfig=None):
super(QuantStub, self).__init__()
if qconfig:
self.qconfig = qconfig
def forward(self, x):
return x
class DeQuantStub(nn.Module):
r"""Dequantize stub module, before calibration, this is same as identity,
this will be swapped as `nnq.DeQuantize` in `convert`.
"""
def __init__(self):
super(DeQuantStub, self).__init__()
def forward(self, x):
return x
def quantize(module, eval_fn, eval_args, qconfig_dict=None):
r"""Converts a float module to quantized module.
First it will prepare the module for calibration or training, then it calls
`eval_fn` which will run the calibration step or training step,
after that we will call `convert` which will convert the module to a
quantized module.
When `qconfig_dict` is None or empty dictionary, we will assume user will
insert quant/dequant stubs and add qconfig in approporiate places.
When `qconfig_dict` is not None or empty dictionary, we will add quant/dequant
stubs using QuantWrapper for all the leaf modules.
Args:
module: input module
eval_fn: a function for evaluating the prepared module, can be a
function that simply runs the prepared module or a training loop
eval_args: positional arguments for `eval_fn`
qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration; a qconfig applies to all submodules of a given
                      module unless a qconfig is specified for the submodule (when the
                      submodule already has a qconfig attribute)
Return:
A quantized module
"""
module = prepare(module, qconfig_dict)
eval_fn(module, eval_args)
convert(module)
return module
# Map for swapping float module to quantized ones
DEFAULT_MODULE_MAPPING = {
torch.nn.Linear: nnq.Linear,
torch.nn.ReLU: nnq.ReLU,
QuantStub: nnq.Quantize,
}
def convert(module, mapping=DEFAULT_MODULE_MAPPING):
r"""Converts the float module with observers(where we can get quantization
parameters) to a quantized module.
Args:
module: calibrated module with observers
mapping: a dictionary that maps from float module type to quantized
module type, can be overwrritten to allow swapping user defined Modules
Return:
A quantized module
"""
module_swapped = swap_module(module, mapping)
reassign = {}
for name, mod in module.named_children():
new_mod = convert(mod, mapping)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module_swapped, name, mod)
return module_swapped
def swap_module(mod, mapping):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
Args:
mod: input module
mapping: a dictionary that maps from nn module to nnq module
Return:
The corresponding quantized module of `mod`
"""
new_mod = mod
print('swapping:', mod)
if hasattr(mod, 'observer'):
if type(mod) in mapping:
new_mod = mapping[type(mod)].from_float(mod)
if type(mod) == DeQuantStub:
new_mod = nnq.DeQuantize.from_float(mod)
return new_mod
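
Since `convert` takes a `mapping` argument, user-defined modules can be swapped in the same way; a minimal sketch (FloatCustomLinear and QuantizedCustomLinear are hypothetical user classes; the quantized class only needs a `from_float` method, as used in `swap_module` above):

custom_mapping = dict(DEFAULT_MODULE_MAPPING)
custom_mapping[FloatCustomLinear] = QuantizedCustomLinear   # hypothetical user-defined classes
quantized_model = convert(prepared_and_calibrated_model, mapping=custom_mapping)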