pytorch/torch/ao/quantization/_quantize_dbr.py
Vasiliy Kuznetsov 5787a36e30 dbr quant: insert activation obs explicitly, instead of relying on hooks (#73492)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73492

Before this PR, DBR quant reused the Eager mode quantization machinery
to insert activation observers; this was done to speed up development of
the prototype. A drawback is that the activation observers are not
tracked in DBR's data structures and instead live on the modules.

This PR refactors DBR quant to stop using the Eager mode quantization
observer insertion for activations, and to instead create and track the
activation observers in DBR's data structures. This has a few benefits:
1. activation observers are now created the same way in DBR for modules and functions
2. fixing (1) lets us remove some technical debt
3. this will make it easier to support reference modules in a future PR

Reason (3) holds because the current design of reference modules assumes
that the activation observer is owned by the framework (as in FX graph
mode quantization). This PR starts to adhere to that assumption.
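
For context, a minimal sketch of the end-to-end DBR prepare/convert flow
touched by this PR (the toy model, qconfig, and input shape below are
illustrative and not taken from this PR):
```
import torch
import torch.ao.quantization._quantize_dbr as quantize_dbr

# toy model and settings, chosen only to illustrate the API surface
model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
qconfig_dict = {'': torch.quantization.default_qconfig}
example_inputs = (torch.randn(1, 1, 4, 4),)

model = quantize_dbr.prepare(model, qconfig_dict, example_inputs)
model(*example_inputs)  # calibration
model = quantize_dbr.convert(model)
```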

Test Plan:
```
python test/test_quantization.py -k DBR
```

Reviewed By: jerryzh168

Differential Revision: D34520758

Pulled By: vkuzo

fbshipit-source-id: 2f6448dce021024cb2fa112d8691c94128c43123
(cherry picked from commit cfc1a0eaf6579cea2c710c1c2b4c86d28ee799eb)
2022-03-04 17:35:31 +00:00


import torch
from ._dbr.auto_trace import add_auto_observation, add_auto_convert
from ._dbr.fusion import get_module_fusion_fqns
from ._dbr.qconfig_dict_utils import normalize_object_types
from .qconfig_dict_utils import (
    get_flattened_qconfig_dict,
    convert_dict_to_ordered_dict,
)
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
    get_default_dynamic_quant_module_mappings,
)
from ._dbr.module_swap_utils import _swap_child_modules


def prepare(model, qconfig_dict, example_inputs, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None,
            fuse_modules=True):
    r"""A wrapper around `torch.quantization.prepare` which prepares the
    model for quantization using dynamic tracing.

    Requires `qconfig_dict` (same format as in `prepare_fx`) to specify the
    quantization settings. Not all functionality is supported yet.

    Requires `example_inputs` to build the graph before calibration or
    quantization aware training can proceed.

    Supported `prepare_custom_config_dict` keys:

      * `non_traceable_module_class` - same meaning as in `prepare_fx`
      * `output_dtypes` - expected dtypes of model outputs, must match the
        actual output structure.

    TODO(future PR): better docblock
    """
    assert example_inputs is not None, 'example_inputs must be specified'

    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    for qconfig_dict_option in ('module_name_regex', 'module_name_object_type_order'):
        if qconfig_dict_option in qconfig_dict:
            assert len(qconfig_dict[qconfig_dict_option]) == 0, \
                f'{qconfig_dict_option} option of qconfig_dict is not ' + \
                'implemented yet in define-by-run quantization'

    normalize_object_types(qconfig_dict)
    convert_dict_to_ordered_dict(qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    torch.quantization.propagate_qconfig_(model, flattened_qconfig_dict)
    # if parts of the model are non traceable, delete qconfig from
    # them so they do not get swapped
    non_traceable_module_class = \
        prepare_custom_config_dict.get('non_traceable_module_class', [])
    for name, child in model.named_modules():
        for target_cls in non_traceable_module_class:
            if isinstance(child, target_cls):
                for _, child_child in child.named_modules():
                    child_child.qconfig = None

    # TODO(future PR): QAT support

    if fuse_modules:
        # automatically fuse modules
        old_class = model.__class__
        model = add_auto_observation(
            model, qconfig_dict, example_inputs,
            prepare_custom_config_dict=prepare_custom_config_dict)
        module_fusion_fqns = get_module_fusion_fqns(model)
        if len(module_fusion_fqns):
            model = torch.quantization.fuse_modules(model, module_fusion_fqns)

            # Since we are reusing the auto_trace machinery to find fusion
            # FQNs, we need to do some surgery to get qconfigs on modules
            # after module fusion to be correct.
            for _, child in model.named_modules():
                if isinstance(child, torch.nn.intrinsic._FusedModule):
                    if hasattr(child[0], 'qconfig'):
                        child.qconfig = child[0].qconfig

        # delete all the DBR state from the model, so add_auto_observation
        # can start from a clean slate
        parents_to_delete_auto_quant_state = []
        for k, v in model.named_modules():
            if hasattr(v, '_auto_quant_state'):
                parents_to_delete_auto_quant_state.append(v)
        for v in parents_to_delete_auto_quant_state:
            del v._auto_quant_state

        del model._fqn_to_auto_quant_state_map

        for p in model.parameters():
            if hasattr(p, '_qtensor_info'):
                del p._qtensor_info
        for b in model.buffers():
            if hasattr(b, '_qtensor_info'):
                del b._qtensor_info

        # the model hierarchy might have changed during fusion, so we
        # have to delete the cached module hook types
        for k, v in model.named_modules():
            if hasattr(v, '_auto_quant_module_hook_type'):
                del v._auto_quant_module_hook_type

        model.__class__ = old_class

    # Automatically assign qconfigs for modules where the defaults do not
    # work.
    # TODO(future PR): clean this up and align with other APIs
    for name, child in model.named_modules():
        if isinstance(child, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
            # Disable quantization for Embedding and EmbeddingBag for now;
            # assigning torch.quantization.float_qparams_weight_only_qconfig
            # here breaks models such as attention_is_all_you_need.
            # TODO: write up an issue, maybe fix
            child.qconfig = None  # type: ignore[assignment]
        elif isinstance(child, torch.nn.LSTM):
            # TODO: fix LSTM handling in eager mode static quant and remove this
            qconfig_dict['object_type'][torch.nn.LSTM] = None

    # TODO(future PR): do the QAT module swap

    assert not inplace
    model = add_auto_observation(
        model, qconfig_dict, example_inputs,
        prepare_custom_config_dict=prepare_custom_config_dict)

    return model


def convert(model: torch.nn.Module) -> torch.nn.Module:
    r"""Converts a prepared DBR quantization model to a quantized form.

    TODO(future PR): better docblock
    """
    static_mappings = get_default_static_quant_module_mappings()
    dynamic_mappings = get_default_dynamic_quant_module_mappings()
    # swap the modules
    _swap_child_modules(model, static_mappings, dynamic_mappings)
    # add dynamic handling for quants/dequants, functions and methods
    model = add_auto_convert(model)
    return model