[quant][eagermode] Move custom_module registration to prepare/convert_custom_config_dict (#46293)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46293

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D24290811

fbshipit-source-id: 7d2aee98e1946c2a4268efb94443f1e5daaa793e
parent 2ffb768607
commit 3ad797c937
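In short: eager-mode custom-module support previously went through a process-global registry (register_observed_custom_module_mapping / register_quantized_custom_module_mapping); after this change the mappings are passed explicitly to prepare and convert. A minimal before/after sketch, using the class names from the test diff below (their definitions are sketched after that diff); float_model and data stand in for a user's model and calibration batch:

    from torch.quantization import prepare, convert

    # Before this PR: classes were registered once, globally, and
    # prepare()/convert() consulted the registry implicitly:
    #   register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
    #   register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
    #   m = prepare(float_model)

    # After this PR: each call carries its own mapping.
    m = prepare(float_model, prepare_custom_config_dict={
        "float_to_observed_custom_module_class": {
            CustomModule: ObservedCustomModule
        }
    })
    m(data)  # calibration
    m = convert(m, convert_custom_config_dict={
        "observed_to_quantized_custom_module_class": {
            ObservedCustomModule: QuantizedCustomModule
        }
    })

Passing the mapping per call removes hidden global state, so two models in the same process can quantize the same custom module class differently.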
@@ -23,8 +23,6 @@ from torch.quantization import (
     per_channel_dynamic_qconfig,
     float16_dynamic_qconfig,
     float_qparams_dynamic_qconfig,
-    register_observed_custom_module_mapping,
-    register_quantized_custom_module_mapping,
     PerChannelMinMaxObserver,
     QConfigDynamic,
     default_dynamic_quant_observer
@@ -627,9 +625,6 @@ class TestPostTrainingStatic(QuantizationTestCase):
                 quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
                 return quantized
 
-        register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
-        register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
-
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -670,14 +665,28 @@ class TestPostTrainingStatic(QuantizationTestCase):
         original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
 
         original_m.qconfig = default_qconfig
-        m = prepare(original_m)
-        self.checkObservers(m)
+        prepare_custom_config_dict = {
+            "float_to_observed_custom_module_class": {
+                CustomModule: ObservedCustomModule
+            }
+        }
+        convert_custom_config_dict = {
+            "observed_to_quantized_custom_module_class": {
+                ObservedCustomModule: QuantizedCustomModule
+            }
+        }
+        m = prepare(
+            original_m,
+            prepare_custom_config_dict=prepare_custom_config_dict)
+        self.checkObservers(m, None, prepare_custom_config_dict)
         # calibration
         m(data)
         # all activation observers are inserted in the top level module
 
         # check converted/quantized model
-        m = convert(m)
+        m = convert(
+            m,
+            convert_custom_config_dict=convert_custom_config_dict)
         # check if the module is properly quantized
         self.assertEqual(type(m.quant), nnq.Quantize)
         self.assertEqual(type(m.conv), nnq.Conv2d)
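The test diff above references CustomModule, ObservedCustomModule, and QuantizedCustomModule without showing their definitions. A sketch of such a trio, assuming a custom module that wraps a single nn.Conv2d: the from_observed body matches the fragment visible in the previous hunk, and the rest follows the contract documented in the deleted custom_module_class_mappings.py below (the observed class must provide a from_float classmethod, the quantized class a from_observed classmethod); the module internals are illustrative, not copied from the test:

    import torch.nn as nn
    import torch.nn.quantized as nnq

    class CustomModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv(x)

    class ObservedCustomModule(nn.Module):
        def __init__(self, conv):
            super().__init__()
            self.conv = conv

        def forward(self, x):
            return self.conv(x)

        @classmethod
        def from_float(cls, float_module):
            # called by prepare(); the qconfig has already been
            # propagated onto the float module at this point
            assert hasattr(float_module, 'qconfig')
            observed = cls(float_module.conv)
            observed.qconfig = float_module.qconfig
            return observed

    class QuantizedCustomModule(nn.Module):
        def __init__(self, conv):
            super().__init__()
            self.conv = conv

        def forward(self, x):
            return self.conv(x)

        @classmethod
        def from_observed(cls, observed_module):
            # called by convert(); reuse the output observer that prepare()
            # attached to the observed module so nnq.Conv2d.from_float can
            # compute quantization parameters
            assert hasattr(observed_module, 'qconfig')
            assert hasattr(observed_module, 'activation_post_process')
            observed_module.conv.activation_post_process = \
                observed_module.activation_post_process
            return cls(nnq.Conv2d.from_float(observed_module.conv))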
@@ -9,7 +9,6 @@ from .quantize_jit import *
 from .quantize_fx import *
 from .quantization_mappings import *
 from .fuser_method_mappings import *
-from .custom_module_class_mappings import *
 
 def default_eval_fn(model, calib_data):
     r"""
@@ -41,12 +40,6 @@ _all__ = [
     'get_compare_output_module_list',
     'register_quantized_operator_mapping', 'get_quantized_operator',
     'register_fuser_method', 'get_fuser_method',
-    'register_observed_custom_module_mapping',
-    'get_observed_custom_module_class',
-    'register_quantized_custom_mdoule_mapping',
-    'get_quantized_custom_module_class',
-    'is_custom_module_class',
-    'is_observed_custom_module',
     # Sub functions for `prepare` and `swap_module`
     'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
     'default_eval_fn', 'get_observer_dict',
@@ -1,75 +0,0 @@
-OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
-
-def register_observed_custom_module_mapping(float_custom_module_class, observed_custom_module_class):
-    """ Register a mapping from `float_custom_module_class` to
-    `observed_custom_module_class`
-    `observed_custom_module_class` will have a `from_float` classmethod,
-    which will return an observed custom module instance given
-    a float custom module instance.
-    This will be used in prepare step of post training static quantization or
-    quantization aware training
-    """
-    assert hasattr(observed_custom_module_class, 'from_float'), 'from_float must be' + \
-        ' defined in observed custom module class'
-    OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
-        observed_custom_module_class
-
-def get_observed_custom_module_class(float_custom_module_class):
-    """ Get the corresponding observed module class for a given
-    float custom module.
-    """
-    observed_custom_module_class = \
-        OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
-    assert observed_custom_module_class is not None, \
-        'Float Custom module class {}'.format(float_custom_module_class) + \
-        ' does not have a corresponding observed module class'
-    return observed_custom_module_class
-
-QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
-
-def register_quantized_custom_module_mapping(float_custom_module_class, quantized_custom_module_class):
-    """ Register a mapping from `float_custom_module_class` to `quantized_custom_module_class`
-    A quantized custom module class should accept quantized input and
-    return quantized output. (we can relax this condition in the
-    future if there is a need)
-    `quantized_custom_module_class` will have a `from_observed` classmethod,
-    which will return an quantized custom module instance given
-    a observed custom module instance.
-    This will be used in prepare step of post training static quantization or
-    quantization aware training
-    """
-    assert hasattr(quantized_custom_module_class, 'from_observed'), 'from_observed' + \
-        ' must be defined in quantized custom module class'
-    QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
-        quantized_custom_module_class
-
-def get_quantized_custom_module_class(float_custom_module_class):
-    """ Get the corresponding quantized module class for a given
-    float custom module.
-    """
-    quantized_custom_module_class = \
-        QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
-    assert quantized_custom_module_class is not None, \
-        'Float Custom module class {}'.format(float_custom_module_class) + \
-        ' does not have a corresponding quantized module class'
-    return quantized_custom_module_class
-
-def is_custom_module_class(module_class):
-    """ Check if a given module class is a custom module class
-    """
-    return module_class in OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS and \
-        module_class in QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS
-
-def mark_observed_custom_module(module, custom_module_class):
-    """ Mark a module as observed custom module, so that
-    it can be identified during convert step
-    """
-    module._is_observed_custom_module = True
-    module._FLOAT_MODULE = custom_module_class
-
-def is_observed_custom_module(module):
-    """ Check if a module is marked as observed custom module
-    or not
-    """
-    return hasattr(module, '_is_observed_custom_module') and \
-        module._is_observed_custom_module
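Deleting this module also deletes its registration-time hasattr checks, so a mapping entry that lacks from_float or from_observed now fails only when prepare or convert actually reaches the module. A hypothetical up-front check a user could run (check_custom_config_dicts is an invented helper name, not part of this PR; the asserts mirror the ones removed above):

    def check_custom_config_dicts(prepare_custom_config_dict, convert_custom_config_dict):
        # mirror the asserts the deleted registry ran at registration time
        observed_map = prepare_custom_config_dict.get(
            "float_to_observed_custom_module_class", {})
        for float_class, observed_class in observed_map.items():
            assert hasattr(observed_class, 'from_float'), \
                'from_float must be defined in observed custom module class'
        quantized_map = convert_custom_config_dict.get(
            "observed_to_quantized_custom_module_class", {})
        for observed_class, quantized_class in quantized_map.items():
            assert hasattr(quantized_class, 'from_observed'), \
                'from_observed must be defined in quantized custom module class'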
@@ -14,14 +14,6 @@ from .quantization_mappings import (get_dynamic_quant_module_mappings,
                                     get_qat_module_mappings,
                                     get_qconfig_propagation_list)
 
-from .custom_module_class_mappings import (
-    is_custom_module_class,
-    get_observed_custom_module_class,
-    get_quantized_custom_module_class,
-    mark_observed_custom_module,
-    is_observed_custom_module,
-)
-
 from .stubs import DeQuantStub, QuantWrapper
 from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
 
@@ -86,7 +78,7 @@ def register_activation_post_process_hook(module):
         'Expect activation_post_process attribut already attached to the module'
     return module.register_forward_hook(_observer_forward_hook)
 
-def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None):
+def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
     r"""Add observer for the leaf child of the module.
 
     This function insert observer module to all leaf child module that
@@ -103,6 +95,9 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
     if qconfig_propagation_list is None:
         qconfig_propagation_list = get_qconfig_propagation_list()
 
+    if custom_module_class_mapping is None:
+        custom_module_class_mapping = {}
+
     # respect device affinity when adding observers
     if device is None:
         devices = get_unique_devices_(module)
@@ -139,13 +134,12 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
             child.activation_post_process = get_activation_post_process(child.qconfig, device)
         elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
             insert_activation_post_process(child)
-        elif needs_observation(child) and is_custom_module_class(type(child)):
-            observed_child = get_observed_custom_module_class(type(child)).from_float(child)
-            mark_observed_custom_module(observed_child, type(child))
+        elif needs_observation(child) and type(child) in custom_module_class_mapping:
+            observed_child = custom_module_class_mapping[type(child)].from_float(child)
             setattr(module, name, observed_child)
             insert_activation_post_process(observed_child)
         else:
-            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device)
+            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
 
     # Insert observers only for leaf nodes, note that this observer is for
     # the output of the module, for input QuantStub will observe them
@@ -180,7 +174,8 @@ def add_quant_dequant(module):
         return module
 
 def prepare(model, inplace=False, allow_list=None,
-            observer_non_leaf_module_list=None):
+            observer_non_leaf_module_list=None,
+            prepare_custom_config_dict=None):
     r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
 
     Quantization configuration should be assigned preemptively
@@ -194,8 +189,21 @@ def prepare(model, inplace=False, allow_list=None,
         inplace: carry out model transformations in-place, the original module is mutated
         allow_list: list of quantizable modules
         observer_non_leaf_module_list: list of non-leaf modules we want to add observer
+        `prepare_custom_config_dict`: customization configuration dictionary for prepare function:
+        # user will manually define the corresponding observed
+        # module class which has a from_float class method that converts
+        # float custom module to observed custom module
+        prepare_custom_config_dict = {
+            "float_to_observed_custom_module_class": {
+                CustomModule: ObservedCustomModule
+            }
+        }
     """
     torch._C._log_api_usage_once("quantization_api.quantize.prepare")
+    if prepare_custom_config_dict is None:
+        prepare_custom_config_dict = {}
+    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+
     if not inplace:
         model = copy.deepcopy(model)
 
@@ -210,7 +218,9 @@ def prepare(model, inplace=False, allow_list=None,
                       "passed correct configuration through `qconfig_dict` or "
                       "by assigning the `.qconfig` attribute directly on submodules")
 
-    add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list)
+    add_observer_(
+        model, qconfig_propagation_list, observer_non_leaf_module_list,
+        custom_module_class_mapping=custom_module_class_mapping)
     return model
 
 def _remove_qconfig(module):
@@ -380,7 +390,9 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
     convert(model, inplace=True)
     return model
 
-def convert(module, mapping=None, inplace=False, remove_qconfig=True):
+def convert(
+        module, mapping=None, inplace=False, remove_qconfig=True,
+        convert_custom_config_dict=None):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class. And remove qconfig at the
     end if remove_qconfig is set to True.
@@ -392,17 +404,29 @@ def convert(module, mapping=None, inplace=False, remove_qconfig=True):
                  Modules
         inplace: carry out model transformations in-place, the original module
                  is mutated
-
+        `convert_custom_config_dict`: custom configuration dictionary for convert function:
+        convert_custom_config_dict = {
+            # user will manually define the corresponding quantized
+            # module class which has a from_observed class method that converts
+            # observed custom module to quantized custom module
+            "observed_to_quantized_custom_module_class": {
+                ObservedCustomModule: QuantizedCustomModule
+            }
+        }
     """
     torch._C._log_api_usage_once("quantization_api.quantize.convert")
     if not inplace:
         module = copy.deepcopy(module)
-    _convert(module, mapping, inplace=True)
+    _convert(
+        module, mapping, inplace=True,
+        convert_custom_config_dict=convert_custom_config_dict)
     if remove_qconfig:
         _remove_qconfig(module)
     return module
 
-def _convert(module, mapping=None, inplace=False):
+def _convert(
+        module, mapping=None, inplace=False,
+        convert_custom_config_dict=None):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class
 
@@ -417,6 +441,10 @@ def _convert(module, mapping=None, inplace=False):
     """
     if mapping is None:
         mapping = get_static_quant_module_mappings()
+    if convert_custom_config_dict is None:
+        convert_custom_config_dict = {}
+    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
+
     if not inplace:
         module = copy.deepcopy(module)
     reassign = {}
@@ -440,16 +468,17 @@ def _convert(module, mapping=None, inplace=False):
         # both swappable modules and observed custom modules are
         # swapped as one unit
         if type(mod) not in SWAPPABLE_MODULES and \
-                not is_observed_custom_module(mod):
-            _convert(mod, mapping, inplace=True)
-        reassign[name] = swap_module(mod, mapping)
+                type(mod) not in custom_module_class_mapping:
+            _convert(mod, mapping, True,  # inplace
+                     custom_module_class_mapping)
+        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
 
     for key, value in reassign.items():
         module._modules[key] = value
 
     return module
 
-def swap_module(mod, mapping):
+def swap_module(mod, mapping, custom_module_class_mapping):
     r"""Swaps the module if it has a quantized counterpart and it has an
     `observer` attached.
 
@@ -464,8 +493,8 @@ def swap_module(mod, mapping):
     # Always replace dequantstub with dequantize
     if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
         swapped = False
-        if is_observed_custom_module(mod):
-            new_mod = get_quantized_custom_module_class(mod._FLOAT_MODULE).from_observed(mod)
+        if type(mod) in custom_module_class_mapping:
+            new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
             swapped = True
         elif type(mod) in mapping:
             new_mod = mapping[type(mod)].from_float(mod)
@@ -13,10 +13,6 @@ from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
     default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
     propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
     get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic
-from torch.quantization import (
-    is_custom_module_class,
-    is_observed_custom_module,
-)
 from torch.quantization.quantization_mappings import (
     get_dynamic_quant_module_mappings,
     get_qconfig_propagation_list,
@@ -341,12 +337,15 @@ class QuantizationTestCase(TestCase):
         self.assertTrue(hasattr(module, 'quant'))
         self.assertTrue(hasattr(module, 'dequant'))
 
-    def checkObservers(self, module, propagate_qconfig_list=None):
+    def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
         r"""Checks the module or module's leaf descendants
             have observers in preperation for quantization
         """
         if propagate_qconfig_list is None:
             propagate_qconfig_list = get_qconfig_propagation_list()
+        if prepare_custom_config_dict is None:
+            prepare_custom_config_dict = {}
+        float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
 
         # check if a module is a leaf module, ignoring activation_post_process attribute
         def is_leaf_module(module):
@@ -356,18 +355,18 @@ class QuantizationTestCase(TestCase):
                 submodule_name_count += 1
             return submodule_name_count == 0
 
-        if (hasattr(module, 'qconfig') and module.qconfig is not None and
-                is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
-                and type(module) in propagate_qconfig_list) or \
-           is_custom_module_class(type(module)):
+        if hasattr(module, 'qconfig') and module.qconfig is not None and \
+           ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
+             and type(module) in propagate_qconfig_list) or
+            type(module) in float_to_observed_module_class_mapping.keys()):
             self.assertTrue(hasattr(module, 'activation_post_process'),
                             'module: ' + str(type(module)) + ' do not have observer')
         # we don't need to check observers for child modules of the
         # qat modules
         if type(module) not in get_qat_module_mappings().values() and \
-           not is_observed_custom_module(module):
+           type(module) not in float_to_observed_module_class_mapping.values():
             for child in module.children():
-                self.checkObservers(child)
+                self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
 
     def checkQuantDequant(self, mod):
         r"""Checks that mod has nn.Quantize and