Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73492

Before this PR, DBR quant reused the Eager mode quantization machinery to insert activation observers. This was done for speed of developing the prototype. A drawback of this is that the activation observers are not present in DBR's data structures and live on the modules instead.

This PR refactors DBR quant to stop using Eager mode quantization observer insertion for activations, and instead to create and track the activation observers in DBR's data structures. This has a couple of benefits:

1. activation observers are now created the same way in DBR for modules and functions
2. we can remove some technical debt by fixing (1)
3. this will make it easier to support reference modules in a future PR

Reason (3) holds because the current design of reference modules assumes that the activation observer lives on the framework (as in FX graph mode quantization). This PR starts to adhere to that assumption.

Test Plan:
```
python test/test_quantization.py -k DBR
```

Reviewed By: jerryzh168

Differential Revision: D34520758

Pulled By: vkuzo

fbshipit-source-id: 2f6448dce021024cb2fa112d8691c94128c43123
(cherry picked from commit cfc1a0eaf6579cea2c710c1c2b4c86d28ee799eb)
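As a rough illustration of the idea (this is not DBR's actual implementation; the class and method names below are hypothetical), tracking activation observers in DBR's own per-module state, keyed by the op they observe, rather than attaching them to the user's modules, could look like this:

```python
import torch
from torch.ao.quantization import default_qconfig

class ToyAutoQuantState:
    """Hypothetical stand-in for DBR's per-module state object.

    Activation observers are created and stored here, keyed by the index
    of the op they observe, instead of being attached as children of the
    user's modules by the Eager mode machinery.
    """
    def __init__(self, qconfig):
        self.qconfig = qconfig
        self.idx_to_observer = {}  # op index -> activation observer

    def observe_output(self, op_idx, output):
        # lazily create the observer the first time this op is seen
        if op_idx not in self.idx_to_observer:
            self.idx_to_observer[op_idx] = self.qconfig.activation()
        self.idx_to_observer[op_idx](output)  # record activation statistics
        return output

# the same code path can serve both module outputs and function outputs
state = ToyAutoQuantState(default_qconfig)
state.observe_output(0, torch.randn(4, 4))            # e.g. output of a module
state.observe_output(1, torch.relu(torch.randn(4)))   # e.g. output of a function
```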
144 lines · 5.8 KiB · Python
import torch

from ._dbr.auto_trace import add_auto_observation, add_auto_convert
from ._dbr.fusion import get_module_fusion_fqns
from ._dbr.qconfig_dict_utils import normalize_object_types

from .qconfig_dict_utils import (
    get_flattened_qconfig_dict,
    convert_dict_to_ordered_dict,
)
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
    get_default_dynamic_quant_module_mappings,
)
from ._dbr.module_swap_utils import _swap_child_modules


def prepare(model, qconfig_dict, example_inputs, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None,
            fuse_modules=True):
    r"""A wrapper around `torch.quantization.prepare` which prepares the
    model for quantization using dynamic tracing.

    Requires `qconfig_dict` (same format as prepare_fx) to specify the
    quantization settings. Not all functionality is supported yet.

    Requires `example_inputs` to build the graph before calibration or
    quantization aware training can proceed.

    Supported `prepare_custom_config_dict` keys:
      * `non_traceable_module_class` - same meaning as in prepare_fx
      * `output_dtypes` - expected dtypes of model outputs, must match actual
        output structure.

    TODO(future PR): better docblock
    """
    assert example_inputs is not None, 'example_inputs must be specified'

    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    for qconfig_dict_option in ('module_name_regex', 'module_name_object_type_order'):
        if qconfig_dict_option in qconfig_dict:
            assert len(qconfig_dict[qconfig_dict_option]) == 0, \
                f'{qconfig_dict_option} option of qconfig_dict is not ' + \
                'implemented yet in define-by-run quantization'

    normalize_object_types(qconfig_dict)
    convert_dict_to_ordered_dict(qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    torch.quantization.propagate_qconfig_(model, flattened_qconfig_dict)

    # if parts of the model are non traceable, delete qconfig from
    # them so they do not get swapped
    non_traceable_module_class = \
        prepare_custom_config_dict.get('non_traceable_module_class', [])
    for name, child in model.named_modules():
        for target_cls in non_traceable_module_class:
            if isinstance(child, target_cls):
                for _, child_child in child.named_modules():
                    child_child.qconfig = None

    # TODO(future PR): QAT support

    if fuse_modules:
        # automatically fuse modules
        old_class = model.__class__
        model = add_auto_observation(
            model, qconfig_dict, example_inputs,
            prepare_custom_config_dict=prepare_custom_config_dict)
        module_fusion_fqns = get_module_fusion_fqns(model)
        if len(module_fusion_fqns):
            model = torch.quantization.fuse_modules(model, module_fusion_fqns)

        # Since we are reusing the auto_trace machinery to find fusion
        # FQNs, we need to do some surgery to get qconfigs on modules
        # after module fusion to be correct.
        for _, child in model.named_modules():
            if isinstance(child, torch.nn.intrinsic._FusedModule):
                if hasattr(child[0], 'qconfig'):
                    child.qconfig = child[0].qconfig

        # delete all the DBR state from the model, so add_auto_observation
        # can start from a clean slate
        parents_to_delete_auto_quant_state = []
        for k, v in model.named_modules():
            if hasattr(v, '_auto_quant_state'):
                parents_to_delete_auto_quant_state.append(v)
        for v in parents_to_delete_auto_quant_state:
            del v._auto_quant_state

        del model._fqn_to_auto_quant_state_map

        for p in model.parameters():
            if hasattr(p, '_qtensor_info'):
                del p._qtensor_info
        for b in model.buffers():
            if hasattr(b, '_qtensor_info'):
                del b._qtensor_info

        # the model hierarchy might have changed during fusion, so we
        # have to delete the cached module hook types
        for k, v in model.named_modules():
            if hasattr(v, '_auto_quant_module_hook_type'):
                del v._auto_quant_module_hook_type

        model.__class__ = old_class

    # Automatically assign qconfigs for modules where the defaults do not
    # work.
    # TODO(future PR): clean this up and align with other APIs
    for name, child in model.named_modules():
        if isinstance(child, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
            # pass
            # child.qconfig = torch.quantization.float_qparams_weight_only_qconfig
            # uncomment below to unbreak attention_is_all_you_need
            # TODO write up issue, maybe fix
            child.qconfig = None  # type: ignore[assignment]
        elif isinstance(child, torch.nn.LSTM):
            # TODO: fix LSTM handling in eager mode static quant and remove this
            qconfig_dict['object_type'][torch.nn.LSTM] = None

    # TODO(future PR): do the QAT module swap

    assert not inplace
    model = add_auto_observation(
        model, qconfig_dict, example_inputs,
        prepare_custom_config_dict=prepare_custom_config_dict)
    return model


def convert(model: torch.nn.Module) -> torch.nn.Module:
    r"""Converts a prepared DBR quantization model to a quantized form.

    TODO(future PR): better docblock
    """
    static_mappings = get_default_static_quant_module_mappings()
    dynamic_mappings = get_default_dynamic_quant_module_mappings()
    # swap the modules
    _swap_child_modules(model, static_mappings, dynamic_mappings)
    # add dynamic handling for quants/dequants, functions and methods
    model = add_auto_convert(model)
    return model
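A minimal usage sketch of the `prepare`/`convert` entry points defined above, assuming this file is importable as `torch.ao.quantization._quantize_dbr` (DBR is a prototype and the import path may differ or change):

```python
import torch
from torch.ao.quantization import default_qconfig
from torch.ao.quantization._quantize_dbr import prepare, convert  # assumed module path

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return torch.relu(self.conv(x))

example_inputs = (torch.randn(1, 1, 4, 4),)
qconfig_dict = {'': default_qconfig}  # same format as prepare_fx

m = M().eval()
m = prepare(m, qconfig_dict, example_inputs)  # trace with example_inputs, insert observers
m(*example_inputs)                            # calibration pass
m = convert(m)                                # swap modules, handle quants/dequants at runtime
```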