[quant][eagermode] Move custom_module registration to prepare/convert_custom_config_dict (#46293)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46293

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D24290811

fbshipit-source-id: 7d2aee98e1946c2a4268efb94443f1e5daaa793e
This commit is contained in:
Jerry Zhang 2020-10-14 12:00:34 -07:00 committed by Facebook GitHub Bot
parent 2ffb768607
commit 3ad797c937
5 changed files with 81 additions and 126 deletions

View File

@ -23,8 +23,6 @@ from torch.quantization import (
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
@ -627,9 +625,6 @@ class TestPostTrainingStatic(QuantizationTestCase):
quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
return quantized
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@ -670,14 +665,28 @@ class TestPostTrainingStatic(QuantizationTestCase):
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
original_m.qconfig = default_qconfig
m = prepare(original_m)
self.checkObservers(m)
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
m = prepare(
original_m,
prepare_custom_config_dict=prepare_custom_config_dict)
self.checkObservers(m, None, prepare_custom_config_dict)
# calibration
m(data)
# all activation observers are inserted in the top level module
# check converted/quantized model
m = convert(m)
m = convert(
m,
convert_custom_config_dict=convert_custom_config_dict)
# check if the module is properly quantized
self.assertEqual(type(m.quant), nnq.Quantize)
self.assertEqual(type(m.conv), nnq.Conv2d)

View File

@ -9,7 +9,6 @@ from .quantize_jit import *
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *
from .custom_module_class_mappings import *
def default_eval_fn(model, calib_data):
r"""
@ -41,12 +40,6 @@ _all__ = [
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
'register_observed_custom_module_mapping',
'get_observed_custom_module_class',
'register_quantized_custom_mdoule_mapping',
'get_quantized_custom_module_class',
'is_custom_module_class',
'is_observed_custom_module',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',

View File

@ -1,75 +0,0 @@
OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
def register_observed_custom_module_mapping(float_custom_module_class, observed_custom_module_class):
""" Register a mapping from `float_custom_module_class` to
`observed_custom_module_class`
`observed_custom_module_class` will have a `from_float` classmethod,
which will return an observed custom module instance given
a float custom module instance.
This will be used in prepare step of post training static quantization or
quantization aware training
"""
assert hasattr(observed_custom_module_class, 'from_float'), 'from_float must be' + \
' defined in observed custom module class'
OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
observed_custom_module_class
def get_observed_custom_module_class(float_custom_module_class):
""" Get the corresponding observed module class for a given
float custom module.
"""
observed_custom_module_class = \
OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
assert observed_custom_module_class is not None, \
'Float Custom module class {}'.format(float_custom_module_class) + \
' does not have a corresponding observed module class'
return observed_custom_module_class
QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
def register_quantized_custom_module_mapping(float_custom_module_class, quantized_custom_module_class):
""" Register a mapping from `float_custom_module_class` to `quantized_custom_module_class`
A quantized custom module class should accept quantized input and
return quantized output. (we can relax this condition in the
future if there is a need)
`quantized_custom_module_class` will have a `from_observed` classmethod,
which will return an quantized custom module instance given
a observed custom module instance.
This will be used in prepare step of post training static quantization or
quantization aware training
"""
assert hasattr(quantized_custom_module_class, 'from_observed'), 'from_observed' + \
' must be defined in quantized custom module class'
QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
quantized_custom_module_class
def get_quantized_custom_module_class(float_custom_module_class):
""" Get the corresponding quantized module class for a given
float custom module.
"""
quantized_custom_module_class = \
QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
assert quantized_custom_module_class is not None, \
'Float Custom module class {}'.format(float_custom_module_class) + \
' does not have a corresponding quantized module class'
return quantized_custom_module_class
def is_custom_module_class(module_class):
""" Check if a given module class is a custom module class
"""
return module_class in OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS and \
module_class in QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS
def mark_observed_custom_module(module, custom_module_class):
""" Mark a module as observed custom module, so that
it can be identified during convert step
"""
module._is_observed_custom_module = True
module._FLOAT_MODULE = custom_module_class
def is_observed_custom_module(module):
""" Check if a module is marked as observed custom module
or not
"""
return hasattr(module, '_is_observed_custom_module') and \
module._is_observed_custom_module

View File

@ -14,14 +14,6 @@ from .quantization_mappings import (get_dynamic_quant_module_mappings,
get_qat_module_mappings,
get_qconfig_propagation_list)
from .custom_module_class_mappings import (
is_custom_module_class,
get_observed_custom_module_class,
get_quantized_custom_module_class,
mark_observed_custom_module,
is_observed_custom_module,
)
from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
@ -86,7 +78,7 @@ def register_activation_post_process_hook(module):
'Expect activation_post_process attribut already attached to the module'
return module.register_forward_hook(_observer_forward_hook)
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None):
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
r"""Add observer for the leaf child of the module.
This function insert observer module to all leaf child module that
@ -103,6 +95,9 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
if qconfig_propagation_list is None:
qconfig_propagation_list = get_qconfig_propagation_list()
if custom_module_class_mapping is None:
custom_module_class_mapping = {}
# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
@ -139,13 +134,12 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
child.activation_post_process = get_activation_post_process(child.qconfig, device)
elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
insert_activation_post_process(child)
elif needs_observation(child) and is_custom_module_class(type(child)):
observed_child = get_observed_custom_module_class(type(child)).from_float(child)
mark_observed_custom_module(observed_child, type(child))
elif needs_observation(child) and type(child) in custom_module_class_mapping:
observed_child = custom_module_class_mapping[type(child)].from_float(child)
setattr(module, name, observed_child)
insert_activation_post_process(observed_child)
else:
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device)
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
@ -180,7 +174,8 @@ def add_quant_dequant(module):
return module
def prepare(model, inplace=False, allow_list=None,
observer_non_leaf_module_list=None):
observer_non_leaf_module_list=None,
prepare_custom_config_dict=None):
r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
Quantization configuration should be assigned preemptively
@ -194,8 +189,21 @@ def prepare(model, inplace=False, allow_list=None,
inplace: carry out model transformations in-place, the original module is mutated
allow_list: list of quantizable modules
observer_non_leaf_module_list: list of non-leaf modules we want to add observer
`prepare_custom_config_dict`: customization configuration dictionary for prepare function:
# user will manually define the corresponding observed
# module class which has a from_float class method that converts
# float custom module to observed custom module
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare")
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
if not inplace:
model = copy.deepcopy(model)
@ -210,7 +218,9 @@ def prepare(model, inplace=False, allow_list=None,
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")
add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list)
add_observer_(
model, qconfig_propagation_list, observer_non_leaf_module_list,
custom_module_class_mapping=custom_module_class_mapping)
return model
def _remove_qconfig(module):
@ -380,7 +390,9 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
convert(model, inplace=True)
return model
def convert(module, mapping=None, inplace=False, remove_qconfig=True):
def convert(
module, mapping=None, inplace=False, remove_qconfig=True,
convert_custom_config_dict=None):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class. And remove qconfig at the
end if remove_qconfig is set to True.
@ -392,17 +404,29 @@ def convert(module, mapping=None, inplace=False, remove_qconfig=True):
Modules
inplace: carry out model transformations in-place, the original module
is mutated
`convert_custom_config_dict`: custom configuration dictionary for convert function:
convert_custom_config_dict = {
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
"""
torch._C._log_api_usage_once("quantization_api.quantize.convert")
if not inplace:
module = copy.deepcopy(module)
_convert(module, mapping, inplace=True)
_convert(
module, mapping, inplace=True,
convert_custom_config_dict=convert_custom_config_dict)
if remove_qconfig:
_remove_qconfig(module)
return module
def _convert(module, mapping=None, inplace=False):
def _convert(
module, mapping=None, inplace=False,
convert_custom_config_dict=None):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class
@ -417,6 +441,10 @@ def _convert(module, mapping=None, inplace=False):
"""
if mapping is None:
mapping = get_static_quant_module_mappings()
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
if not inplace:
module = copy.deepcopy(module)
reassign = {}
@ -440,16 +468,17 @@ def _convert(module, mapping=None, inplace=False):
# both swappable modules and observed custom modules are
# swapped as one unit
if type(mod) not in SWAPPABLE_MODULES and \
not is_observed_custom_module(mod):
_convert(mod, mapping, inplace=True)
reassign[name] = swap_module(mod, mapping)
type(mod) not in custom_module_class_mapping:
_convert(mod, mapping, True, # inplace
custom_module_class_mapping)
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
for key, value in reassign.items():
module._modules[key] = value
return module
def swap_module(mod, mapping):
def swap_module(mod, mapping, custom_module_class_mapping):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
@ -464,8 +493,8 @@ def swap_module(mod, mapping):
# Always replace dequantstub with dequantize
if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
swapped = False
if is_observed_custom_module(mod):
new_mod = get_quantized_custom_module_class(mod._FLOAT_MODULE).from_observed(mod)
if type(mod) in custom_module_class_mapping:
new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
swapped = True
elif type(mod) in mapping:
new_mod = mapping[type(mod)].from_float(mod)

View File

@ -13,10 +13,6 @@ from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic
from torch.quantization import (
is_custom_module_class,
is_observed_custom_module,
)
from torch.quantization.quantization_mappings import (
get_dynamic_quant_module_mappings,
get_qconfig_propagation_list,
@ -341,12 +337,15 @@ class QuantizationTestCase(TestCase):
self.assertTrue(hasattr(module, 'quant'))
self.assertTrue(hasattr(module, 'dequant'))
def checkObservers(self, module, propagate_qconfig_list=None):
def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
r"""Checks the module or module's leaf descendants
have observers in preperation for quantization
"""
if propagate_qconfig_list is None:
propagate_qconfig_list = get_qconfig_propagation_list()
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
# check if a module is a leaf module, ignoring activation_post_process attribute
def is_leaf_module(module):
@ -356,18 +355,18 @@ class QuantizationTestCase(TestCase):
submodule_name_count += 1
return submodule_name_count == 0
if (hasattr(module, 'qconfig') and module.qconfig is not None and
is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
and type(module) in propagate_qconfig_list) or \
is_custom_module_class(type(module)):
if hasattr(module, 'qconfig') and module.qconfig is not None and \
((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
and type(module) in propagate_qconfig_list) or
type(module) in float_to_observed_module_class_mapping.keys()):
self.assertTrue(hasattr(module, 'activation_post_process'),
'module: ' + str(type(module)) + ' do not have observer')
# we don't need to check observers for child modules of the
# qat modules
if type(module) not in get_qat_module_mappings().values() and \
not is_observed_custom_module(module):
type(module) not in float_to_observed_module_class_mapping.values():
for child in module.children():
self.checkObservers(child)
self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
def checkQuantDequant(self, mod):
r"""Checks that mod has nn.Quantize and