Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73274

As noticed in https://discuss.pytorch.org/t/calibration-of-model-in-post-training-static-quantization-using-fx-api/143661/6, and related to https://github.com/pytorch/pytorch/issues/72698, when using FX quantization, if an op like view was used in a model and its index parameters were passed in as variables rather than hard-coded values, FX would mistakenly insert observers for them, leading to an error when the observer tried to perform tensor-only operations on a non-tensor. To fix this, an API was added for specifying the non-tensor arguments of various ops, enabling better dtype propagation. NON_TENSOR_ARG_DICT is a nested dict whose outer key is a named tuple containing the matching parameters for ops with non-tensor args; the inner dict's keys are dtypes and its values are lists of the arg indices that take those dtypes. Alternatively, instead of a list, an inner dict value can be a function that takes the node as an argument and returns the list of arg indices. In principle this API can support arbitrary functions, but the current implementation is limited to simpler ones, since the particular issue this fixes appears to be rare.

Note: although torch.unsqueeze and torch.transpose are listed in quantization_patterns.py, those ops appear to be untraceable by FX. I've included tests for their cases, but fixing that issue is beyond the scope of this PR.

Test Plan:
python test/test_quantization.py test_non_reference_size
...
python test/test_quantization.py test_non_reference_<op>

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D34410122

fbshipit-source-id: fc09949ca8a2d6473876a4b6c214eb91e9a9dae2
(cherry picked from commit 3a1375d677b7c98d62b1f5c839645698c39b32b9)
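To make the summary concrete, here is a minimal sketch of the two ideas it describes: a module whose view() call receives a size computed at runtime (the case that used to get a spurious observer), and the rough shape of a NON_TENSOR_ARG_DICT entry. The class name, the named-tuple type and its field, and the dtype keys below are illustrative assumptions for this sketch, not the exact definitions added by the PR.

```python
from collections import namedtuple

import torch


class ReshapeModel(torch.nn.Module):
    """Module where view() receives a non-tensor argument computed at runtime."""

    def forward(self, x):
        # Under FX tracing, x.size(0) produces a node whose runtime value is a
        # plain int; this is the argument that used to receive an observer.
        batch = x.size(0)
        return x.view(batch, -1)


# Rough shape of a NON_TENSOR_ARG_DICT entry as described in the summary.
# `NonTensorArgSpec`, its `op` field, and the use of `int` as the dtype key are
# assumptions for illustration only.
NonTensorArgSpec = namedtuple("NonTensorArgSpec", ["op"])

NON_TENSOR_ARG_DICT = {
    NonTensorArgSpec(op="view"): {
        # Every positional arg after `self` is a size, never a tensor, so the
        # indices can be computed from the node itself.
        int: lambda node: list(range(1, len(node.args))),
    },
    NonTensorArgSpec(op="unsqueeze"): {
        # A plain list of arg indices also works, per the summary.
        int: [1],
    },
}
```

The intent is that, with an entry like the `view` one, dtype propagation can mark the trailing size arguments as non-tensor and skip observer insertion for them, avoiding the error path described in the forum thread above.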
205 lines · 6.9 KiB · Python
# Owner(s): ["oncall: quantization"]
from .common import AOMigrationTestCase
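# Each test below exercises the `torch.ao` migration compatibility layer for
# the FX graph mode quantization APIs: every package and name listed in this
# file should remain importable from its legacy `torch.quantization` location
# after the move to `torch.ao.quantization` (the actual import checks live in
# AOMigrationTestCase).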
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    def test_package_import_quantize_fx(self):
        self._test_package_import('quantize_fx')

    def test_function_import_quantize_fx(self):
        function_list = [
            '_check_is_graph_module',
            '_swap_ff_with_fxff',
            '_fuse_fx',
            'Scope',
            'ScopeContextManager',
            'QuantizationTracer',
            '_prepare_fx',
            '_prepare_standalone_module_fx',
            'fuse_fx',
            'prepare_fx',
            'prepare_qat_fx',
            '_convert_fx',
            'convert_fx',
            '_convert_standalone_module_fx',
        ]
        self._test_function_import('quantize_fx', function_list)

    def test_package_import_fx(self):
        self._test_package_import('fx')

    def test_function_import_fx(self):
        function_list = [
            'prepare',
            'convert',
            'fuse',
        ]
        self._test_function_import('fx', function_list)

    def test_package_import_fx_graph_module(self):
        self._test_package_import('fx.graph_module')

    def test_function_import_fx_graph_module(self):
        function_list = [
            'FusedGraphModule',
            'ObservedGraphModule',
            'is_observed_module',
            'ObservedStandaloneGraphModule',
            'is_observed_standalone_module',
            'QuantizedGraphModule'
        ]
        self._test_function_import('fx.graph_module', function_list)

    def test_package_import_fx_pattern_utils(self):
        self._test_package_import('fx.pattern_utils')

    def test_function_import_fx_pattern_utils(self):
        function_list = [
            'QuantizeHandler',
            'MatchResult',
            'register_fusion_pattern',
            'get_default_fusion_patterns',
            'register_quant_pattern',
            'get_default_quant_patterns',
            'get_default_output_activation_post_process_map'
        ]
        self._test_function_import('fx.pattern_utils', function_list)

    def test_package_import_fx_equalize(self):
        self._test_package_import('fx._equalize')

    def test_function_import_fx_equalize(self):
        function_list = [
            'reshape_scale',
            '_InputEqualizationObserver',
            '_WeightEqualizationObserver',
            'calculate_equalization_scale',
            'EqualizationQConfig',
            'input_equalization_observer',
            'weight_equalization_observer',
            'default_equalization_qconfig',
            'fused_module_supports_equalization',
            'nn_module_supports_equalization',
            'node_supports_equalization',
            'is_equalization_observer',
            'get_op_node_and_weight_eq_obs',
            'maybe_get_weight_eq_obs_node',
            'maybe_get_next_input_eq_obs',
            'maybe_get_next_equalization_scale',
            'scale_input_observer',
            'scale_weight_node',
            'scale_weight_functional',
            'clear_weight_quant_obs_node',
            'remove_node',
            'update_obs_for_equalization',
            'convert_eq_obs',
            '_convert_equalization_ref',
            'get_layer_sqnr_dict',
            'get_equalization_qconfig_dict'
        ]
        self._test_function_import('fx._equalize', function_list)

    def test_package_import_fx_quantization_patterns(self):
        self._test_package_import('fx.quantization_patterns')

    def test_function_import_fx_quantization_patterns(self):
        function_list = [
            'QuantizeHandler',
            'BinaryOpQuantizeHandler',
            'CatQuantizeHandler',
            'ConvReluQuantizeHandler',
            'LinearReLUQuantizeHandler',
            'BatchNormQuantizeHandler',
            'EmbeddingQuantizeHandler',
            'RNNDynamicQuantizeHandler',
            'DefaultNodeQuantizeHandler',
            'FixedQParamsOpQuantizeHandler',
            'CopyNodeQuantizeHandler',
            'CustomModuleQuantizeHandler',
            'GeneralTensorShapeOpQuantizeHandler',
            'StandaloneModuleQuantizeHandler'
        ]
        self._test_function_import('fx.quantization_patterns', function_list)

    def test_package_import_fx_match_utils(self):
        self._test_package_import('fx.match_utils')

    def test_function_import_fx_match_utils(self):
        function_list = [
            'MatchResult',
            'MatchAllNode',
            'is_match',
            'find_matches'
        ]
        self._test_function_import('fx.match_utils', function_list)

    def test_package_import_fx_prepare(self):
        self._test_package_import('fx.prepare')

    def test_function_import_fx_prepare(self):
        function_list = [
            'prepare'
        ]
        self._test_function_import('fx.prepare', function_list)

    def test_package_import_fx_convert(self):
        self._test_package_import('fx.convert')

    def test_function_import_fx_convert(self):
        function_list = [
            'convert'
        ]
        self._test_function_import('fx.convert', function_list)

    def test_package_import_fx_fuse(self):
        self._test_package_import('fx.fuse')

    def test_function_import_fx_fuse(self):
        function_list = ['fuse']
        self._test_function_import('fx.fuse', function_list)

    def test_package_import_fx_fusion_patterns(self):
        self._test_package_import('fx.fusion_patterns')

    def test_function_import_fx_fusion_patterns(self):
        function_list = [
            'FuseHandler',
            'DefaultFuseHandler'
        ]
        self._test_function_import('fx.fusion_patterns', function_list)

    def test_package_import_fx_quantization_types(self):
        self._test_package_import('fx.quantization_types')

    def test_function_import_fx_quantization_types(self):
        function_list = [
            'Pattern',
            'QuantizerCls'
        ]
        self._test_function_import('fx.quantization_types', function_list)

    def test_package_import_fx_utils(self):
        self._test_package_import('fx.utils')

    def test_function_import_fx_utils(self):
        function_list = [
            'graph_pretty_str',
            'get_per_tensor_qparams',
            'quantize_node',
            'get_custom_module_class_keys',
            'get_linear_prepack_op_for_dtype',
            'get_qconv_prepack_op',
            'get_qconv_op',
            'get_new_attr_name_with_prefix',
            'graph_module_from_producer_nodes',
            'assert_and_get_unique_device',
            'create_getattr_from_value',
            'create_qparam_nodes',
            'all_node_args_have_no_tensors',
            'node_return_type_is_int',
            'get_non_observable_arg_indexes_and_types',
            'is_get_tensor_info_node',
            'maybe_get_next_module'
        ]
        self._test_function_import('fx.utils', function_list)