pytorch/torch/ao/quantization/qconfig_mapping_utils.py
Andrew Or c7b4eec233 [Quant][fx][bc-breaking] Replace qconfig_dict with a config object (#78452)
**Summary:** Previously, FX graph mode quantization configurations
were specified through a dictionary of qconfigs. However, this
API was not in line with other core APIs in PyTorch. This commit
replaces this dictionary with a config object that users will
create and pass to prepare and convert. This leads to better
type safety and better user experience in notebook settings
due to improved auto completion.

The new API is as follows:

```
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx

qconfig_mapping = (
    QConfigMapping()
    .set_global(qconfig)
    .set_object_type(torch.nn.Linear, qconfig)
    .set_module_name_regex("foo.*bar", qconfig)
    .set_module_name("mod", qconfig)
)

prepare_fx(model, qconfig_mapping)
```

For backwards compatibility, `prepare_fx`, `prepare_qat_fx`,
and `convert_fx` will continue to accept qconfig_dicts, which
will be converted to QConfigMapping objects internally.

Note that this commit does not modify existing tests to use the
new API; they will continue to pass in qconfig_dict as before,
which still works but triggers a deprecation warning. This will
be handled in a future commit.

**Test Plan:**
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

**Reviewers:** jerryzh168, vkuzo

**Subscribers:** jerryzh168, vkuzo

Differential Revision: D36747998

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78452
Approved by: https://github.com/jerryzh168
2022-05-30 18:30:07 +00:00

111 lines
3.8 KiB
Python

import re
from typing import Dict, Callable, Union
from .utils import (
get_combined_dict,
_parent_name,
)
from .quantization_mappings import (
get_default_qat_module_mappings,
)
from .qconfig import QConfigAny
from .qconfig_mapping import QConfigMapping
# Names exported by `from <this module> import *`; every entry is a function
# defined in this file.
# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
"get_flattened_qconfig_dict",
"get_object_type_qconfig",
"get_module_name_qconfig",
"get_module_name_regex_qconfig",
"maybe_adjust_qconfig_for_module_type_or_name",
"update_qconfig_for_qat",
]
def get_object_type_qconfig(
        qconfig_mapping: QConfigMapping,
        object_type: Union[Callable, str],
        fallback_qconfig: QConfigAny) -> QConfigAny:
    """Look up the qconfig registered for ``object_type``.

    Args:
        qconfig_mapping: mapping whose ``object_type_qconfigs`` is consulted.
        object_type: the key to look up (a callable or a string).
        fallback_qconfig: value returned when ``object_type`` has no entry.

    Returns:
        The stored qconfig for ``object_type`` (which may itself be ``None``),
        or ``fallback_qconfig`` when there is no entry.
    """
    registered = qconfig_mapping.object_type_qconfigs
    if object_type in registered:
        return registered[object_type]
    return fallback_qconfig
def get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
    """Return the qconfig of the first regex pattern matching ``module_name``.

    Patterns are tried in the iteration order of
    ``qconfig_mapping.module_name_regex_qconfigs`` and the first match wins.
    Matching uses ``re.match``, so a pattern only has to match a prefix of
    ``module_name`` (it is anchored at the start, not at the end).

    Returns ``fallback_qconfig`` when no pattern matches.
    """
    matching_qconfigs = (
        candidate
        for pattern, candidate in qconfig_mapping.module_name_regex_qconfigs.items()
        if re.match(pattern, module_name)
    )
    # The generator is lazy, so patterns after the first hit are never tried.
    return next(matching_qconfigs, fallback_qconfig)
def get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
    """Find the qconfig for ``module_name`` or its closest configured ancestor.

    Starting at ``module_name``, each name is looked up in
    ``qconfig_mapping.module_name_qconfigs``; on a miss the search moves to
    the parent name given by ``_parent_name``. The empty name is never looked
    up: reaching it means nothing was found and ``fallback_qconfig`` is
    returned.
    """
    configured = qconfig_mapping.module_name_qconfigs
    current_name = module_name
    while current_name != '':
        if current_name in configured:
            return configured[current_name]
        # No entry at this level; retry with the enclosing module's name.
        current_name, _ = _parent_name(current_name)
    # module name qconfig not found
    return fallback_qconfig
def maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
    """Resolve the qconfig for a module, preferring the most specific setting.

    Precedence, highest first: module name, module name regex, module type,
    then ``global_qconfig``. Each lookup receives the result of the next less
    specific lookup as its fallback.
    """
    # Build the fallback chain from least to most specific, so that every
    # step can override what the previous, broader step produced.
    resolved = get_object_type_qconfig(qconfig_mapping, module_type, global_qconfig)
    resolved = get_module_name_regex_qconfig(qconfig_mapping, module_name, resolved)
    return get_module_name_qconfig(qconfig_mapping, module_name, resolved)
def get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
    """Flatten the global, object_type and module_name qconfigs into one dict.

    The result is meant to be consumed by ``propagate_qconfig_``.
    ``module_name_regex`` entries are left out for now since they are not
    supported by ``propagate_qconfig_``; this can be fixed later.

    Keys are merged in this order, later ones overriding earlier ones on a
    clash: the global qconfig under ``""``, then every object_type entry,
    then every module_name entry.

    For example, a mapping with
        global qconfig:  qconfig
        object_type:     torch.add -> qconfig
        module_name:     "conv"    -> qconfig
    produces
        {
            "": qconfig,
            torch.add: qconfig,
            "conv": qconfig,
        }
    """
    return {
        "": qconfig_mapping.global_qconfig,
        **qconfig_mapping.object_type_qconfigs,
        **qconfig_mapping.module_name_qconfigs,
    }
def update_qconfig_for_qat(
        qconfig_mapping: QConfigMapping,
        additional_qat_module_mapping: Dict[Callable, Callable]):
    """
    Update the qconfig mapping to account for module swaps during QAT.
    During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.

    For every key of ``qconfig_mapping.object_type_qconfigs`` that has a swap
    target in the combined QAT mappings (the defaults combined with
    ``additional_qat_module_mapping``), the same qconfig is also stored under
    that swap target. The dict is modified in place; nothing is returned.
    """
    swap_targets = get_combined_dict(
        get_default_qat_module_mappings(), additional_qat_module_mapping)
    type_qconfigs = qconfig_mapping.object_type_qconfigs
    # Iterate over a snapshot: entries are added to the dict inside the loop.
    for source_type, qconfig in list(type_qconfigs.items()):
        if source_type in swap_targets:
            type_qconfigs[swap_targets[source_type]] = qconfig