mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Quant][fx] Hide equalization_config from prepare APIs (#80164)
Summary: This PR hides the equalization_config argument from prepare_fx. This is a private API that we do not wish to expose to users and have to maintain backward compatibility for. Test Plan: python test/test_quantization.py TestEqualizeFx Reviewers: jerryzh168 Subscribers: jerryzh168 Differential Revision: [D37394353](https://our.internmc.facebook.com/intern/diff/D37394353) Pull Request resolved: https://github.com/pytorch/pytorch/pull/80164 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
8a45ef23f5
commit
8aedd8fb25
|
|
@ -279,7 +279,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
m,
|
||||
specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
|
||||
|
||||
def test_input_weight_equalization_branching(self):
|
||||
|
|
@ -313,7 +313,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
example_inputs = (torch.rand(1, 5),)
|
||||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict, example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)
|
||||
|
||||
# Tests that we will add an equalization observer because there is only
|
||||
|
|
@ -337,7 +337,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
example_inputs = (torch.randn(1, 5),)
|
||||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict, example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
|
|
@ -369,7 +369,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
copy.deepcopy(m),
|
||||
specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict
|
||||
_equalization_config=default_equalization_qconfig_dict
|
||||
)
|
||||
output = prepared(x)
|
||||
|
||||
|
|
@ -379,7 +379,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
prepared(x)
|
||||
convert_fx(prepared) # Check if compile
|
||||
self.assertEqual(output, convert_ref_output)
|
||||
|
|
@ -431,7 +431,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
prepared(*example_inputs)
|
||||
convert_ref = _convert_equalization_ref(prepared)
|
||||
convert_ref(x)
|
||||
|
|
@ -484,7 +484,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
prepared(x)
|
||||
convert_ref = _convert_equalization_ref(prepared)
|
||||
convert_ref(x)
|
||||
|
|
@ -544,7 +544,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
prepared(x)
|
||||
convert_ref = _convert_equalization_ref(prepared)
|
||||
convert_ref(x)
|
||||
|
|
@ -783,7 +783,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
prepared = prepare_fx(
|
||||
m, specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict)
|
||||
_equalization_config=default_equalization_qconfig_dict)
|
||||
equalized_quantized_model = convert_fx(prepared)
|
||||
|
||||
# Check the order of nodes in the graph
|
||||
|
|
@ -808,7 +808,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
copy.deepcopy(m),
|
||||
specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config={})
|
||||
_equalization_config={})
|
||||
prepared(x)
|
||||
quantized = convert_fx(prepared) # Check if compile
|
||||
quantized_output = quantized(x)
|
||||
|
|
@ -818,7 +818,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
copy.deepcopy(m),
|
||||
specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=default_equalization_qconfig_dict
|
||||
_equalization_config=default_equalization_qconfig_dict
|
||||
)
|
||||
prepared(x)
|
||||
equalized_and_quantized = convert_fx(prepared) # Check if compile
|
||||
|
|
@ -876,7 +876,7 @@ class TestEqualizeFx(QuantizationTestCase):
|
|||
copy.deepcopy(float_model),
|
||||
specific_qconfig_dict,
|
||||
example_inputs=example_inputs,
|
||||
equalization_config=selective_equalization_qconfig_dict,
|
||||
_equalization_config=selective_equalization_qconfig_dict,
|
||||
)
|
||||
prepared_model(x)
|
||||
equalized_model = convert_fx(prepared_model)
|
||||
|
|
|
|||
|
|
@ -1387,7 +1387,7 @@ def prepare(
|
|||
node_name_to_scope: Dict[str, Tuple[str, type]],
|
||||
example_inputs: Tuple[Any, ...],
|
||||
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
|
||||
equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
|
||||
_equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
is_standalone_module: bool = False) -> ObservedGraphModule:
|
||||
""" standalone_module means it a submodule that is not inlined in
|
||||
|
|
@ -1412,8 +1412,8 @@ def prepare(
|
|||
"""
|
||||
if prepare_custom_config is None:
|
||||
prepare_custom_config = PrepareCustomConfig()
|
||||
if equalization_config is None:
|
||||
equalization_config = QConfigMapping()
|
||||
if _equalization_config is None:
|
||||
_equalization_config = QConfigMapping()
|
||||
|
||||
if isinstance(qconfig_mapping, Dict):
|
||||
warnings.warn(
|
||||
|
|
@ -1421,11 +1421,11 @@ def prepare(
|
|||
"in a future version. Please pass in a QConfigMapping instead.")
|
||||
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)
|
||||
|
||||
if isinstance(equalization_config, Dict):
|
||||
if isinstance(_equalization_config, Dict):
|
||||
warnings.warn(
|
||||
"Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
|
||||
"be supported in a future version. Please pass in a QConfigMapping instead.")
|
||||
equalization_config = QConfigMapping.from_dict(equalization_config)
|
||||
_equalization_config = QConfigMapping.from_dict(_equalization_config)
|
||||
|
||||
if isinstance(prepare_custom_config, Dict):
|
||||
warnings.warn(
|
||||
|
|
@ -1434,9 +1434,9 @@ def prepare(
|
|||
prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
|
||||
|
||||
assert(isinstance(qconfig_mapping, QConfigMapping))
|
||||
assert(isinstance(equalization_config, QConfigMapping))
|
||||
assert(isinstance(_equalization_config, QConfigMapping))
|
||||
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
||||
equalization_config = copy.deepcopy(equalization_config)
|
||||
_equalization_config = copy.deepcopy(_equalization_config)
|
||||
|
||||
# mapping from a tuple of nodes in reverse order to uninitialized
|
||||
# QuantizeHandler subclass. For example,
|
||||
|
|
@ -1477,7 +1477,7 @@ def prepare(
|
|||
get_fusion_pattern_to_root_node_getter(backend_config_dict)
|
||||
|
||||
update_qconfig_for_fusion(model, qconfig_mapping)
|
||||
update_qconfig_for_fusion(model, equalization_config)
|
||||
update_qconfig_for_fusion(model, _equalization_config)
|
||||
flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping)
|
||||
# TODO: support regex as well
|
||||
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
|
||||
|
|
@ -1498,7 +1498,7 @@ def prepare(
|
|||
|
||||
# fill qconfig_map, a map from node name to qconfig, used in find_matches
|
||||
equalization_qconfig_map = generate_qconfig_map(
|
||||
model, modules, model.graph, equalization_config, node_name_to_scope)
|
||||
model, modules, model.graph, _equalization_config, node_name_to_scope)
|
||||
qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_mapping, node_name_to_scope)
|
||||
|
||||
# match the patterns that will get quantized
|
||||
|
|
|
|||
|
|
@ -181,13 +181,13 @@ def _prepare_fx(
|
|||
is_qat: bool,
|
||||
example_inputs: Tuple[Any, ...],
|
||||
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
|
||||
equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
|
||||
_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
is_standalone_module: bool = False,
|
||||
) -> ObservedGraphModule:
|
||||
r""" Internal helper function for prepare_fx
|
||||
Args:
|
||||
`model`, `qconfig_mapping`, `prepare_custom_config`, `equalization_config`:
|
||||
`model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
|
||||
see docs for :func:`~torch.ao.quantization.prepare_fx`
|
||||
`is_standalone_module`: a boolean flag indicates whether we are
|
||||
quantizing a standalone module or not, a standalone module
|
||||
|
|
@ -198,8 +198,8 @@ forward graph of the parent module,
|
|||
"""
|
||||
if prepare_custom_config is None:
|
||||
prepare_custom_config = PrepareCustomConfig()
|
||||
if equalization_config is None:
|
||||
equalization_config = QConfigMapping()
|
||||
if _equalization_config is None:
|
||||
_equalization_config = QConfigMapping()
|
||||
|
||||
if isinstance(prepare_custom_config, Dict):
|
||||
warnings.warn(
|
||||
|
|
@ -238,7 +238,7 @@ forward graph of the parent module,
|
|||
tracer.node_name_to_scope,
|
||||
example_inputs=example_inputs,
|
||||
prepare_custom_config=prepare_custom_config,
|
||||
equalization_config=equalization_config,
|
||||
_equalization_config=_equalization_config,
|
||||
backend_config_dict=backend_config_dict,
|
||||
is_standalone_module=is_standalone_module,
|
||||
) # type: ignore[operator]
|
||||
|
|
@ -337,7 +337,7 @@ def prepare_fx(
|
|||
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
|
||||
example_inputs: Tuple[Any, ...],
|
||||
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
|
||||
equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
|
||||
_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> ObservedGraphModule:
|
||||
r""" Prepare a model for post training static quantization
|
||||
|
|
@ -379,7 +379,7 @@ def prepare_fx(
|
|||
.set_output_quantized_indexes([0]) \
|
||||
.set_preserved_attributes(["attr1", "attr2"])
|
||||
|
||||
* `equalization_config`: config for specifying how to perform equalization on the model
|
||||
* `_equalization_config`: config for specifying how to perform equalization on the model
|
||||
|
||||
* `backend_config_dict`: a dictionary that specifies how operators are quantized
|
||||
in a backend, this includes how the operaetors are observed,
|
||||
|
|
@ -417,7 +417,7 @@ def prepare_fx(
|
|||
False, # is_qat
|
||||
example_inputs,
|
||||
prepare_custom_config,
|
||||
equalization_config,
|
||||
_equalization_config,
|
||||
backend_config_dict,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user