Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00).
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In the quantization tests it:
- adds and uses a common `raise_on_run_directly` method, called when a user directly runs a test file that is not meant to be run that way; it prints the file the user should have run instead;
- raises a `RuntimeError` in tests that have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
159 lines
5.5 KiB
Python
159 lines
5.5 KiB
Python
# Owner(s): ["oncall: quantization"]
|
|
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
|
|
from .common import AOMigrationTestCase
|
|
|
|
|
|
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    """AO-migration import checks for the quantization FX submodules.

    Each test hands a submodule path (relative to the quantization package)
    and a list of names to ``_test_function_import``, which is inherited from
    ``AOMigrationTestCase`` (defined in ``.common``, not visible here).

    NOTE(review): that helper presumably verifies each listed name is
    reachable from both the legacy ``torch.quantization`` location and the
    migrated ``torch.ao.quantization`` location -- confirm against ``.common``.
    """

    def test_function_import_quantize_fx(self):
        """Check the names exposed by ``quantize_fx``."""
        function_list = [
            "_check_is_graph_module",
            "_swap_ff_with_fxff",
            "_fuse_fx",
            "QuantizationTracer",
            "_prepare_fx",
            "_prepare_standalone_module_fx",
            "fuse_fx",
            "Scope",
            "ScopeContextManager",
            "prepare_fx",
            "prepare_qat_fx",
            "_convert_fx",
            "convert_fx",
            "_convert_standalone_module_fx",
        ]
        self._test_function_import("quantize_fx", function_list)

    def test_function_import_fx(self):
        """Check the top-level entry points of the ``fx`` package."""
        function_list = [
            "prepare",
            "convert",
            "fuse",
        ]
        self._test_function_import("fx", function_list)

    def test_function_import_fx_graph_module(self):
        """Check the graph-module classes and predicates in ``fx.graph_module``."""
        function_list = [
            "FusedGraphModule",
            "ObservedGraphModule",
            "_is_observed_module",
            "ObservedStandaloneGraphModule",
            "_is_observed_standalone_module",
            "QuantizedGraphModule",
        ]
        self._test_function_import("fx.graph_module", function_list)

    def test_function_import_fx_pattern_utils(self):
        """Check the pattern-registration helpers in ``fx.pattern_utils``."""
        function_list = [
            "QuantizeHandler",
            "_register_fusion_pattern",
            "get_default_fusion_patterns",
            "_register_quant_pattern",
            "get_default_quant_patterns",
            "get_default_output_activation_post_process_map",
        ]
        self._test_function_import("fx.pattern_utils", function_list)

    def test_function_import_fx_equalize(self):
        """Check the equalization names in the private ``fx._equalize`` module."""
        function_list = [
            "reshape_scale",
            "_InputEqualizationObserver",
            "_WeightEqualizationObserver",
            "calculate_equalization_scale",
            "EqualizationQConfig",
            "input_equalization_observer",
            "weight_equalization_observer",
            "default_equalization_qconfig",
            "fused_module_supports_equalization",
            "nn_module_supports_equalization",
            "node_supports_equalization",
            "is_equalization_observer",
            "get_op_node_and_weight_eq_obs",
            "maybe_get_weight_eq_obs_node",
            "maybe_get_next_input_eq_obs",
            "maybe_get_next_equalization_scale",
            "scale_input_observer",
            "scale_weight_node",
            "scale_weight_functional",
            "clear_weight_quant_obs_node",
            "remove_node",
            "update_obs_for_equalization",
            "convert_eq_obs",
            "_convert_equalization_ref",
            "get_layer_sqnr_dict",
            "get_equalization_qconfig_dict",
        ]
        self._test_function_import("fx._equalize", function_list)

    def test_function_import_fx_quantization_patterns(self):
        """Check the quantize-handler classes from ``fx.quantization_patterns``."""
        function_list = [
            "QuantizeHandler",
            "BinaryOpQuantizeHandler",
            "CatQuantizeHandler",
            "ConvReluQuantizeHandler",
            "LinearReLUQuantizeHandler",
            "BatchNormQuantizeHandler",
            "EmbeddingQuantizeHandler",
            "RNNDynamicQuantizeHandler",
            "DefaultNodeQuantizeHandler",
            "FixedQParamsOpQuantizeHandler",
            "CopyNodeQuantizeHandler",
            "CustomModuleQuantizeHandler",
            "GeneralTensorShapeOpQuantizeHandler",
            "StandaloneModuleQuantizeHandler",
        ]
        # The migrated counterpart lives under a different module name, so it
        # is passed explicitly via ``new_package_name``.
        self._test_function_import(
            "fx.quantization_patterns",
            function_list,
            new_package_name="fx.quantize_handler",
        )

    def test_function_import_fx_match_utils(self):
        """Check the pattern-matching helpers in ``fx.match_utils``."""
        function_list = ["_MatchResult", "MatchAllNode", "_is_match", "_find_matches"]
        self._test_function_import("fx.match_utils", function_list)

    def test_function_import_fx_prepare(self):
        """Check ``prepare`` in ``fx.prepare``."""
        function_list = ["prepare"]
        self._test_function_import("fx.prepare", function_list)

    def test_function_import_fx_convert(self):
        """Check ``convert`` in ``fx.convert``."""
        function_list = ["convert"]
        self._test_function_import("fx.convert", function_list)

    def test_function_import_fx_fuse(self):
        """Check ``fuse`` in ``fx.fuse``."""
        function_list = ["fuse"]
        self._test_function_import("fx.fuse", function_list)

    def test_function_import_fx_fusion_patterns(self):
        """Check the fuse-handler classes from ``fx.fusion_patterns``."""
        function_list = ["FuseHandler", "DefaultFuseHandler"]
        # As with quantization_patterns, the migrated module was renamed.
        self._test_function_import(
            "fx.fusion_patterns",
            function_list,
            new_package_name="fx.fuse_handler",
        )

    # we removed matching test for torch.quantization.fx.quantization_types
    # old: torch.quantization.fx.quantization_types
    # new: torch.ao.quantization.utils
    # both are valid, but we'll deprecate the old path in the future

    def test_function_import_fx_utils(self):
        """Check the graph/node utility helpers in ``fx.utils``."""
        function_list = [
            "get_custom_module_class_keys",
            "get_linear_prepack_op_for_dtype",
            "get_qconv_prepack_op",
            "get_new_attr_name_with_prefix",
            "graph_module_from_producer_nodes",
            "assert_and_get_unique_device",
            "create_getattr_from_value",
            "all_node_args_have_no_tensors",
            "get_non_observable_arg_indexes_and_types",
            "maybe_get_next_module",
        ]
        self._test_function_import("fx.utils", function_list)
|
|
|
|
|
if __name__ == "__main__":
    # This file is not meant to be executed on its own; the helper refuses a
    # direct run and points the user at the aggregate runner named here.
    raise_on_run_directly("test/test_quantization.py")
|