pytorch/test/quantization/ao_migration/test_quantization_fx.py
Anthony Barbier 954ce94950 Add __main__ guards to quantization tests (#154728)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In quantization tests:

- Add and use a common `raise_on_run_directly` helper, called when a user directly runs a test file that is not meant to be run on its own; it prints the file the user should run instead.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
2025-06-10 19:46:07 +00:00

159 lines
5.5 KiB
Python

# Owner(s): ["oncall: quantization"]
from torch.testing._internal.common_utils import raise_on_run_directly
from .common import AOMigrationTestCase
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    """AO-migration import checks for the quantization FX namespace.

    Every test hands a submodule path (relative to the quantization package)
    and the list of names expected there to ``_test_function_import``, which
    is inherited from ``AOMigrationTestCase`` in ``.common``.

    NOTE(review): judging by the old/new path comment further down, the helper
    presumably compares the legacy ``torch.quantization.*`` module against its
    ``torch.ao.quantization.*`` counterpart. Confirm against ``.common``.
    """

    def test_function_import_quantize_fx(self):
        """Check the expected names of ``quantize_fx``."""
        self._test_function_import(
            "quantize_fx",
            [
                "_check_is_graph_module",
                "_swap_ff_with_fxff",
                "_fuse_fx",
                "QuantizationTracer",
                "_prepare_fx",
                "_prepare_standalone_module_fx",
                "fuse_fx",
                "Scope",
                "ScopeContextManager",
                "prepare_fx",
                "prepare_qat_fx",
                "_convert_fx",
                "convert_fx",
                "_convert_standalone_module_fx",
            ],
        )

    def test_function_import_fx(self):
        """Check the top-level entry points of the ``fx`` package."""
        self._test_function_import("fx", ["prepare", "convert", "fuse"])

    def test_function_import_fx_graph_module(self):
        """Check the graph-module classes and predicates of ``fx.graph_module``."""
        self._test_function_import(
            "fx.graph_module",
            [
                "FusedGraphModule",
                "ObservedGraphModule",
                "_is_observed_module",
                "ObservedStandaloneGraphModule",
                "_is_observed_standalone_module",
                "QuantizedGraphModule",
            ],
        )

    def test_function_import_fx_pattern_utils(self):
        """Check the pattern registration helpers of ``fx.pattern_utils``."""
        self._test_function_import(
            "fx.pattern_utils",
            [
                "QuantizeHandler",
                "_register_fusion_pattern",
                "get_default_fusion_patterns",
                "_register_quant_pattern",
                "get_default_quant_patterns",
                "get_default_output_activation_post_process_map",
            ],
        )

    def test_function_import_fx_equalize(self):
        """Check the equalization observers and utilities of ``fx._equalize``."""
        self._test_function_import(
            "fx._equalize",
            [
                "reshape_scale",
                "_InputEqualizationObserver",
                "_WeightEqualizationObserver",
                "calculate_equalization_scale",
                "EqualizationQConfig",
                "input_equalization_observer",
                "weight_equalization_observer",
                "default_equalization_qconfig",
                "fused_module_supports_equalization",
                "nn_module_supports_equalization",
                "node_supports_equalization",
                "is_equalization_observer",
                "get_op_node_and_weight_eq_obs",
                "maybe_get_weight_eq_obs_node",
                "maybe_get_next_input_eq_obs",
                "maybe_get_next_equalization_scale",
                "scale_input_observer",
                "scale_weight_node",
                "scale_weight_functional",
                "clear_weight_quant_obs_node",
                "remove_node",
                "update_obs_for_equalization",
                "convert_eq_obs",
                "_convert_equalization_ref",
                "get_layer_sqnr_dict",
                "get_equalization_qconfig_dict",
            ],
        )

    def test_function_import_fx_quantization_patterns(self):
        """Check the quantize-handler classes of ``fx.quantization_patterns``.

        ``new_package_name`` is supplied because the counterpart module is
        named ``fx.quantize_handler``. It is presumably used on the new-path
        side only; see ``_test_function_import``.
        """
        self._test_function_import(
            "fx.quantization_patterns",
            [
                "QuantizeHandler",
                "BinaryOpQuantizeHandler",
                "CatQuantizeHandler",
                "ConvReluQuantizeHandler",
                "LinearReLUQuantizeHandler",
                "BatchNormQuantizeHandler",
                "EmbeddingQuantizeHandler",
                "RNNDynamicQuantizeHandler",
                "DefaultNodeQuantizeHandler",
                "FixedQParamsOpQuantizeHandler",
                "CopyNodeQuantizeHandler",
                "CustomModuleQuantizeHandler",
                "GeneralTensorShapeOpQuantizeHandler",
                "StandaloneModuleQuantizeHandler",
            ],
            new_package_name="fx.quantize_handler",
        )

    def test_function_import_fx_match_utils(self):
        """Check the matching helpers of ``fx.match_utils``."""
        self._test_function_import(
            "fx.match_utils",
            ["_MatchResult", "MatchAllNode", "_is_match", "_find_matches"],
        )

    def test_function_import_fx_prepare(self):
        """Check ``prepare`` in ``fx.prepare``."""
        self._test_function_import("fx.prepare", ["prepare"])

    def test_function_import_fx_convert(self):
        """Check ``convert`` in ``fx.convert``."""
        self._test_function_import("fx.convert", ["convert"])

    def test_function_import_fx_fuse(self):
        """Check ``fuse`` in ``fx.fuse``."""
        self._test_function_import("fx.fuse", ["fuse"])

    def test_function_import_fx_fusion_patterns(self):
        """Check the fuse-handler classes of ``fx.fusion_patterns``.

        The counterpart module is named ``fx.fuse_handler``, hence
        ``new_package_name``.
        """
        self._test_function_import(
            "fx.fusion_patterns",
            ["FuseHandler", "DefaultFuseHandler"],
            new_package_name="fx.fuse_handler",
        )

    # we removed matching test for torch.quantization.fx.quantization_types
    # old: torch.quantization.fx.quantization_types
    # new: torch.ao.quantization.utils
    # both are valid, but we'll deprecate the old path in the future

    def test_function_import_fx_utils(self):
        """Check the assorted graph/packing helpers of ``fx.utils``."""
        self._test_function_import(
            "fx.utils",
            [
                "get_custom_module_class_keys",
                "get_linear_prepack_op_for_dtype",
                "get_qconv_prepack_op",
                "get_new_attr_name_with_prefix",
                "graph_module_from_producer_nodes",
                "assert_and_get_unique_device",
                "create_getattr_from_value",
                "all_node_args_have_no_tensors",
                "get_non_observable_arg_indexes_and_types",
                "maybe_get_next_module",
            ],
        )
if __name__ == "__main__":
    # Not meant to be executed on its own: raise_on_run_directly tells the
    # user which test file should be run instead.
    raise_on_run_directly(
        "test/test_quantization.py"
    )