mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637
The previous names `default_affine_fixed_qparams_observer` and `default_symmetric_fixed_qparams_observer` were uninformative: users had to read the definitions to understand what these observers are. The new names reveal the range of each observer. Analogous changes were made for `default_symmetric_fixed_qparams_fake_quant` and `default_affine_fixed_qparams_fake_quant`.
Test Plan: ``` python test/test_quantization.py ```
Differential Revision: D36054169
Reviewed By: vkuzo
Pulled By: dzdang
fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9
(cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
256 lines
8.7 KiB
Python
256 lines
8.7 KiB
Python
# Owner(s): ["oncall: quantization"]
|
|
|
|
from .common import AOMigrationTestCase
|
|
|
|
|
|
class TestAOMigrationQuantization(AOMigrationTestCase):
    """Import checks for the `torch/quantization` -> `torch/ao/quantization` migration.

    Every legacy submodule gets a pair of tests:

    * ``test_package_import_<name>`` hands only the submodule name to
      ``_test_package_import``;
    * ``test_function_import_<name>`` hands the submodule name plus the
      expected names to ``_test_function_import`` (and, for the mapping
      modules, a second group of names to ``_test_dict_import``).

    The helpers are inherited from ``AOMigrationTestCase`` (``.common``);
    what exactly they compare is defined there -- presumably that each name
    resolves to the same object from the old and the new location; confirm
    in ``.common``.

    Expected names are kept in private class-level tuples next to the tests
    that use them. Each test passes a fresh ``list`` copy, which matches the
    argument type the helpers have always received and keeps the shared
    tuples immutable. Entry order matches the historical lists.
    """

    # -- quantize -------------------------------------------------------------
    _QUANTIZE_NAMES = (
        '_convert',
        '_observer_forward_hook',
        '_propagate_qconfig_helper',
        '_remove_activation_post_process',
        '_remove_qconfig',
        'add_observer_',
        'add_quant_dequant',
        'convert',
        'get_observer_dict',
        'get_unique_devices_',
        'is_activation_post_process',
        'prepare',
        'prepare_qat',
        'propagate_qconfig_',
        'quantize',
        'quantize_dynamic',
        'quantize_qat',
        'register_activation_post_process_hook',
        'swap_module',
    )

    def test_package_import_quantize(self):
        self._test_package_import('quantize')

    def test_function_import_quantize(self):
        self._test_function_import('quantize', list(self._QUANTIZE_NAMES))

    # -- stubs ----------------------------------------------------------------
    _STUBS_NAMES = (
        'QuantStub',
        'DeQuantStub',
        'QuantWrapper',
    )

    def test_package_import_stubs(self):
        self._test_package_import('stubs')

    def test_function_import_stubs(self):
        self._test_function_import('stubs', list(self._STUBS_NAMES))

    # -- quantize_jit ---------------------------------------------------------
    _QUANTIZE_JIT_NAMES = (
        '_check_is_script_module',
        '_check_forward_method',
        'script_qconfig',
        'script_qconfig_dict',
        'fuse_conv_bn_jit',
        '_prepare_jit',
        'prepare_jit',
        'prepare_dynamic_jit',
        '_convert_jit',
        'convert_jit',
        'convert_dynamic_jit',
        '_quantize_jit',
        'quantize_jit',
        'quantize_dynamic_jit',
    )

    def test_package_import_quantize_jit(self):
        self._test_package_import('quantize_jit')

    def test_function_import_quantize_jit(self):
        self._test_function_import('quantize_jit', list(self._QUANTIZE_JIT_NAMES))

    # -- fake_quantize --------------------------------------------------------
    _FAKE_QUANTIZE_NAMES = (
        '_is_per_channel',
        '_is_per_tensor',
        '_is_symmetric_quant',
        'FakeQuantizeBase',
        'FakeQuantize',
        'FixedQParamsFakeQuantize',
        'FusedMovingAvgObsFakeQuantize',
        'default_fake_quant',
        'default_weight_fake_quant',
        'default_fixed_qparams_range_neg1to1_fake_quant',
        'default_fixed_qparams_range_0to1_fake_quant',
        'default_per_channel_weight_fake_quant',
        'default_histogram_fake_quant',
        'default_fused_act_fake_quant',
        'default_fused_wt_fake_quant',
        'default_fused_per_channel_wt_fake_quant',
        '_is_fake_quant_script_module',
        'disable_fake_quant',
        'enable_fake_quant',
        'disable_observer',
        'enable_observer',
    )

    def test_package_import_fake_quantize(self):
        self._test_package_import('fake_quantize')

    def test_function_import_fake_quantize(self):
        self._test_function_import('fake_quantize', list(self._FAKE_QUANTIZE_NAMES))

    # -- fuse_modules ---------------------------------------------------------
    _FUSE_MODULES_NAMES = (
        '_fuse_modules',
        '_get_module',
        '_set_module',
        'fuse_conv_bn',
        'fuse_conv_bn_relu',
        'fuse_known_modules',
        'fuse_modules',
        'get_fuser_method',
    )

    def test_package_import_fuse_modules(self):
        self._test_package_import('fuse_modules')

    def test_function_import_fuse_modules(self):
        self._test_function_import('fuse_modules', list(self._FUSE_MODULES_NAMES))

    # -- quant_type -----------------------------------------------------------
    _QUANT_TYPE_NAMES = (
        'QuantType',
        'quant_type_to_str',
    )

    def test_package_import_quant_type(self):
        self._test_package_import('quant_type')

    def test_function_import_quant_type(self):
        self._test_function_import('quant_type', list(self._QUANT_TYPE_NAMES))

    # -- observer -------------------------------------------------------------
    _OBSERVER_NAMES = (
        '_PartialWrapper',
        '_with_args',
        '_with_callable_args',
        'ABC',
        'ObserverBase',
        '_ObserverBase',
        'MinMaxObserver',
        'MovingAverageMinMaxObserver',
        'PerChannelMinMaxObserver',
        'MovingAveragePerChannelMinMaxObserver',
        'HistogramObserver',
        'PlaceholderObserver',
        'RecordingObserver',
        'NoopObserver',
        '_is_activation_post_process',
        '_is_per_channel_script_obs_instance',
        'get_observer_state_dict',
        'load_observer_state_dict',
        'default_observer',
        'default_placeholder_observer',
        'default_debug_observer',
        'default_weight_observer',
        'default_histogram_observer',
        'default_per_channel_weight_observer',
        'default_dynamic_quant_observer',
        'default_float_qparams_observer',
    )

    def test_package_import_observer(self):
        self._test_package_import('observer')

    def test_function_import_observer(self):
        self._test_function_import('observer', list(self._OBSERVER_NAMES))

    # -- qconfig --------------------------------------------------------------
    _QCONFIG_NAMES = (
        'QConfig',
        'default_qconfig',
        'default_debug_qconfig',
        'default_per_channel_qconfig',
        'QConfigDynamic',
        'default_dynamic_qconfig',
        'float16_dynamic_qconfig',
        'float16_static_qconfig',
        'per_channel_dynamic_qconfig',
        'float_qparams_weight_only_qconfig',
        'default_qat_qconfig',
        'default_weight_only_qconfig',
        'default_activation_only_qconfig',
        'default_qat_qconfig_v2',
        'get_default_qconfig',
        'get_default_qat_qconfig',
        'assert_valid_qconfig',
        'QConfigAny',
        'add_module_to_qconfig_obs_ctr',
        'qconfig_equals',
    )

    def test_package_import_qconfig(self):
        self._test_package_import('qconfig')

    def test_function_import_qconfig(self):
        self._test_function_import('qconfig', list(self._QCONFIG_NAMES))

    # -- quantization_mappings ------------------------------------------------
    _QUANTIZATION_MAPPINGS_NAMES = (
        'no_observer_set',
        'get_default_static_quant_module_mappings',
        'get_static_quant_module_class',
        'get_dynamic_quant_module_class',
        'get_default_qat_module_mappings',
        'get_default_dynamic_quant_module_mappings',
        'get_default_qconfig_propagation_list',
        'get_default_compare_output_module_list',
        'get_default_float_to_quantized_operator_mappings',
        'get_quantized_operator',
        '_get_special_act_post_process',
        '_has_special_act_post_process',
    )
    _QUANTIZATION_MAPPINGS_DICTS = (
        'DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS',
        'DEFAULT_STATIC_QUANT_MODULE_MAPPINGS',
        'DEFAULT_QAT_MODULE_MAPPINGS',
        'DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS',
        # NOTE(review): '_INCLUDE_QCONFIG_PROPAGATE_LIST' was commented out of
        # this check; the reason is not visible in this file -- confirm whether
        # it should be re-enabled or dropped for good.
        'DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS',
        'DEFAULT_MODULE_TO_ACT_POST_PROCESS',
    )

    def test_package_import_quantization_mappings(self):
        self._test_package_import('quantization_mappings')

    def test_function_import_quantization_mappings(self):
        # Plain names first, then the mapping dicts (same order as before).
        self._test_function_import(
            'quantization_mappings', list(self._QUANTIZATION_MAPPINGS_NAMES))
        self._test_dict_import(
            'quantization_mappings', list(self._QUANTIZATION_MAPPINGS_DICTS))

    # -- fuser_method_mappings ------------------------------------------------
    _FUSER_METHOD_MAPPINGS_NAMES = (
        'fuse_conv_bn',
        'fuse_conv_bn_relu',
        'fuse_linear_bn',
        'get_fuser_method',
    )
    _FUSER_METHOD_MAPPINGS_DICTS = (
        'DEFAULT_OP_LIST_TO_FUSER_METHOD',
    )

    def test_package_import_fuser_method_mappings(self):
        self._test_package_import('fuser_method_mappings')

    def test_function_import_fuser_method_mappings(self):
        # Plain names first, then the mapping dicts (same order as before).
        self._test_function_import(
            'fuser_method_mappings', list(self._FUSER_METHOD_MAPPINGS_NAMES))
        self._test_dict_import(
            'fuser_method_mappings', list(self._FUSER_METHOD_MAPPINGS_DICTS))

    # -- utils ----------------------------------------------------------------
    _UTILS_NAMES = (
        'activation_dtype',
        'activation_is_int8_quantized',
        'activation_is_statically_quantized',
        'calculate_qmin_qmax',
        'check_min_max_valid',
        'get_combined_dict',
        'get_qconfig_dtypes',
        'get_qparam_dict',
        'get_quant_type',
        'get_swapped_custom_module_class',
        'getattr_from_fqn',
        'is_per_channel',
        'is_per_tensor',
        'weight_dtype',
        'weight_is_quantized',
        'weight_is_statically_quantized',
    )

    def test_package_import_utils(self):
        self._test_package_import('utils')

    def test_function_import_utils(self):
        self._test_function_import('utils', list(self._UTILS_NAMES))
|