[Quant][fx] Hide equalization_config from prepare APIs (#80164)

Summary: This PR hides the equalization_config argument of
prepare_fx by renaming it to _equalization_config. Equalization is a
private API that we do not wish to expose to users or to maintain
backward compatibility for.
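
For illustration, a minimal sketch (not part of this diff) of how a caller of the
still-private equalization path would spell the argument after this change. The toy
model, the qconfig dict, and the import paths are assumptions made for the example;
only the rename from equalization_config to _equalization_config reflects this PR.

# Hedged sketch, not from this PR: the toy model, qconfig dict, and import
# paths below are assumptions for illustration only.
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx._equalize import default_equalization_qconfig_dict

model = torch.nn.Sequential(torch.nn.Linear(5, 5)).eval()
example_inputs = (torch.randn(1, 5),)
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# Before: equalization_config=default_equalization_qconfig_dict
# After: the argument is underscore-prefixed to mark it as private,
# with no backward-compatibility guarantee.
prepared = prepare_fx(
    model,
    qconfig_dict,
    example_inputs=example_inputs,
    _equalization_config=default_equalization_qconfig_dict,
)
prepared(*example_inputs)  # calibrate
quantized = convert_fx(prepared)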

Test Plan:
python test/test_quantization.py TestEqualizeFx

Reviewers: jerryzh168

Subscribers: jerryzh168

Differential Revision: [D37394353](https://our.internmc.facebook.com/intern/diff/D37394353)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80164
Approved by: https://github.com/jerryzh168
Author: Andrew Or, 2022-06-23 15:24:54 -07:00 (committed by PyTorch MergeBot)
commit 8aedd8fb25, parent 8a45ef23f5
3 changed files with 29 additions and 29 deletions

Changed file 1 of 3: tests for TestEqualizeFx (the suite exercised by the Test Plan above)

@@ -279,7 +279,7 @@ class TestEqualizeFx(QuantizationTestCase):
 m,
 specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
 def test_input_weight_equalization_branching(self):
@@ -313,7 +313,7 @@ class TestEqualizeFx(QuantizationTestCase):
 example_inputs = (torch.rand(1, 5),)
 prepared = prepare_fx(
 m, specific_qconfig_dict, example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)
 # Tests that we will add an equalization observer because there is only
@@ -337,7 +337,7 @@ class TestEqualizeFx(QuantizationTestCase):
 example_inputs = (torch.randn(1, 5),)
 prepared = prepare_fx(
 m, specific_qconfig_dict, example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)
 @skipIfNoFBGEMM
@@ -369,7 +369,7 @@ class TestEqualizeFx(QuantizationTestCase):
 copy.deepcopy(m),
 specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict
+_equalization_config=default_equalization_qconfig_dict
 )
 output = prepared(x)
@@ -379,7 +379,7 @@ class TestEqualizeFx(QuantizationTestCase):
 prepared = prepare_fx(
 m, specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 prepared(x)
 convert_fx(prepared) # Check if compile
 self.assertEqual(output, convert_ref_output)
@@ -431,7 +431,7 @@ class TestEqualizeFx(QuantizationTestCase):
 prepared = prepare_fx(
 m, specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 prepared(*example_inputs)
 convert_ref = _convert_equalization_ref(prepared)
 convert_ref(x)
@@ -484,7 +484,7 @@ class TestEqualizeFx(QuantizationTestCase):
 prepared = prepare_fx(
 m, specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 prepared(x)
 convert_ref = _convert_equalization_ref(prepared)
 convert_ref(x)
@@ -544,7 +544,7 @@ class TestEqualizeFx(QuantizationTestCase):
 prepared = prepare_fx(
 m, specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 prepared(x)
 convert_ref = _convert_equalization_ref(prepared)
 convert_ref(x)
@@ -783,7 +783,7 @@ class TestEqualizeFx(QuantizationTestCase):
 prepared = prepare_fx(
 m, specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict)
+_equalization_config=default_equalization_qconfig_dict)
 equalized_quantized_model = convert_fx(prepared)
 # Check the order of nodes in the graph
@@ -808,7 +808,7 @@ class TestEqualizeFx(QuantizationTestCase):
 copy.deepcopy(m),
 specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config={})
+_equalization_config={})
 prepared(x)
 quantized = convert_fx(prepared) # Check if compile
 quantized_output = quantized(x)
@@ -818,7 +818,7 @@ class TestEqualizeFx(QuantizationTestCase):
 copy.deepcopy(m),
 specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=default_equalization_qconfig_dict
+_equalization_config=default_equalization_qconfig_dict
 )
 prepared(x)
 equalized_and_quantized = convert_fx(prepared) # Check if compile
@@ -876,7 +876,7 @@ class TestEqualizeFx(QuantizationTestCase):
 copy.deepcopy(float_model),
 specific_qconfig_dict,
 example_inputs=example_inputs,
-equalization_config=selective_equalization_qconfig_dict,
+_equalization_config=selective_equalization_qconfig_dict,
 )
 prepared_model(x)
 equalized_model = convert_fx(prepared_model)

Changed file 2 of 3: the internal FX prepare implementation

@@ -1387,7 +1387,7 @@ def prepare(
 node_name_to_scope: Dict[str, Tuple[str, type]],
 example_inputs: Tuple[Any, ...],
 prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
-equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
+_equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
 is_standalone_module: bool = False) -> ObservedGraphModule:
 """ standalone_module means it a submodule that is not inlined in
@@ -1412,8 +1412,8 @@ def prepare(
 """
 if prepare_custom_config is None:
 prepare_custom_config = PrepareCustomConfig()
-if equalization_config is None:
-equalization_config = QConfigMapping()
+if _equalization_config is None:
+_equalization_config = QConfigMapping()
 if isinstance(qconfig_mapping, Dict):
 warnings.warn(
@@ -1421,11 +1421,11 @@ def prepare(
 "in a future version. Please pass in a QConfigMapping instead.")
 qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)
-if isinstance(equalization_config, Dict):
+if isinstance(_equalization_config, Dict):
 warnings.warn(
 "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
 "be supported in a future version. Please pass in a QConfigMapping instead.")
-equalization_config = QConfigMapping.from_dict(equalization_config)
+_equalization_config = QConfigMapping.from_dict(_equalization_config)
 if isinstance(prepare_custom_config, Dict):
 warnings.warn(
@@ -1434,9 +1434,9 @@ def prepare(
 prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
 assert(isinstance(qconfig_mapping, QConfigMapping))
-assert(isinstance(equalization_config, QConfigMapping))
+assert(isinstance(_equalization_config, QConfigMapping))
 qconfig_mapping = copy.deepcopy(qconfig_mapping)
-equalization_config = copy.deepcopy(equalization_config)
+_equalization_config = copy.deepcopy(_equalization_config)
 # mapping from a tuple of nodes in reverse order to uninitialized
 # QuantizeHandler subclass. For example,
@@ -1477,7 +1477,7 @@ def prepare(
 get_fusion_pattern_to_root_node_getter(backend_config_dict)
 update_qconfig_for_fusion(model, qconfig_mapping)
-update_qconfig_for_fusion(model, equalization_config)
+update_qconfig_for_fusion(model, _equalization_config)
 flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping)
 # TODO: support regex as well
 propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@@ -1498,7 +1498,7 @@ def prepare(
 # fill qconfig_map, a map from node name to qconfig, used in find_matches
 equalization_qconfig_map = generate_qconfig_map(
-model, modules, model.graph, equalization_config, node_name_to_scope)
+model, modules, model.graph, _equalization_config, node_name_to_scope)
 qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_mapping, node_name_to_scope)
 # match the patterns that will get quantized

Changed file 3 of 3: the _prepare_fx helper and the public prepare_fx entry point

@@ -181,13 +181,13 @@ def _prepare_fx(
 is_qat: bool,
 example_inputs: Tuple[Any, ...],
 prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
-equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
+_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
 is_standalone_module: bool = False,
 ) -> ObservedGraphModule:
 r""" Internal helper function for prepare_fx
 Args:
-`model`, `qconfig_mapping`, `prepare_custom_config`, `equalization_config`:
+`model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
 see docs for :func:`~torch.ao.quantization.prepare_fx`
 `is_standalone_module`: a boolean flag indicates whether we are
 quantizing a standalone module or not, a standalone module
@@ -198,8 +198,8 @@ forward graph of the parent module,
 """
 if prepare_custom_config is None:
 prepare_custom_config = PrepareCustomConfig()
-if equalization_config is None:
-equalization_config = QConfigMapping()
+if _equalization_config is None:
+_equalization_config = QConfigMapping()
 if isinstance(prepare_custom_config, Dict):
 warnings.warn(
@@ -238,7 +238,7 @@ forward graph of the parent module,
 tracer.node_name_to_scope,
 example_inputs=example_inputs,
 prepare_custom_config=prepare_custom_config,
-equalization_config=equalization_config,
+_equalization_config=_equalization_config,
 backend_config_dict=backend_config_dict,
 is_standalone_module=is_standalone_module,
 ) # type: ignore[operator]
@@ -337,7 +337,7 @@ def prepare_fx(
 qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
 example_inputs: Tuple[Any, ...],
 prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
-equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
+_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> ObservedGraphModule:
 r""" Prepare a model for post training static quantization
@@ -379,7 +379,7 @@ def prepare_fx(
 .set_output_quantized_indexes([0]) \
 .set_preserved_attributes(["attr1", "attr2"])
-* `equalization_config`: config for specifying how to perform equalization on the model
+* `_equalization_config`: config for specifying how to perform equalization on the model
 * `backend_config_dict`: a dictionary that specifies how operators are quantized
 in a backend, this includes how the operaetors are observed,
@@ -417,7 +417,7 @@ def prepare_fx(
 False, # is_qat
 example_inputs,
 prepare_custom_config,
-equalization_config,
+_equalization_config,
 backend_config_dict,
 )