pytorch/test/quantization/ao_migration/test_quantization_fx.py
Jerry Zhang bf089840ac [quant][graphmode][fx] Enable fuse handler for sequence of 3 ops (#69658)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69658

This PR enables the fuse handler for sequences of three ops and merges all fuse handlers into a single handler.

TODO: we can also move this to the backend_config_dict folder.

Test Plan:
Run the fusion regression tests:
```
python test/test_quantization.py TestFuseFx
```

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D32974907

fbshipit-source-id: ba205e74b566814145f776257c5f5bb3b24547c1
2021-12-14 19:04:21 -08:00

208 lines
6.9 KiB
Python

# Owner(s): ["oncall: quantization"]
from .common import AOMigrationTestCase
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    """Import checks for the FX graph-mode quantization modules.

    Every test hands a module path (relative to the quantization package)
    to one of the helpers inherited from ``AOMigrationTestCase``:

    * ``_test_package_import(module)`` for the module itself, and
    * ``_test_function_import(module, names)`` for a list of names that
      module is expected to expose.

    NOTE(review): the exact comparison those helpers perform lives in
    ``.common`` and is not visible here. Presumably they check the legacy
    location against its ``torch.ao`` counterpart; confirm there.
    """

    # ----- quantize_fx ---------------------------------------------------

    def test_package_import_quantize_fx(self):
        """Package-level import check for ``quantize_fx``."""
        self._test_package_import('quantize_fx')

    def test_function_import_quantize_fx(self):
        """Name-level import check for the public and private ``quantize_fx`` API."""
        expected_names = [
            '_check_is_graph_module',
            '_swap_ff_with_fxff',
            '_fuse_fx',
            'Scope',
            'ScopeContextManager',
            'QuantizationTracer',
            '_prepare_fx',
            '_prepare_standalone_module_fx',
            'fuse_fx',
            'prepare_fx',
            'prepare_qat_fx',
            '_convert_fx',
            'convert_fx',
            '_convert_standalone_module_fx',
        ]
        self._test_function_import('quantize_fx', expected_names)

    # ----- fx (package root) ---------------------------------------------

    def test_package_import_fx(self):
        """Package-level import check for ``fx``."""
        self._test_package_import('fx')

    def test_function_import_fx(self):
        """Name-level import check for the entry points re-exported by ``fx``."""
        self._test_function_import('fx', ['prepare', 'convert', 'Fuser'])

    # ----- fx.graph_module -----------------------------------------------

    def test_package_import_fx_graph_module(self):
        """Package-level import check for ``fx.graph_module``."""
        self._test_package_import('fx.graph_module')

    def test_function_import_fx_graph_module(self):
        """Name-level import check for the graph-module classes and predicates."""
        expected_names = [
            'FusedGraphModule',
            'ObservedGraphModule',
            'is_observed_module',
            'ObservedStandaloneGraphModule',
            'is_observed_standalone_module',
            'QuantizedGraphModule',
        ]
        self._test_function_import('fx.graph_module', expected_names)

    # ----- fx.pattern_utils ----------------------------------------------

    def test_package_import_fx_pattern_utils(self):
        """Package-level import check for ``fx.pattern_utils``."""
        self._test_package_import('fx.pattern_utils')

    def test_function_import_fx_pattern_utils(self):
        """Name-level import check for the pattern registration helpers."""
        expected_names = [
            'QuantizeHandler',
            'MatchResult',
            'register_fusion_pattern',
            'get_default_fusion_patterns',
            'register_quant_pattern',
            'get_default_quant_patterns',
            'get_default_output_activation_post_process_map',
        ]
        self._test_function_import('fx.pattern_utils', expected_names)

    # ----- fx._equalize --------------------------------------------------

    def test_package_import_fx_equalize(self):
        """Package-level import check for ``fx._equalize``."""
        self._test_package_import('fx._equalize')

    def test_function_import_fx_equalize(self):
        """Name-level import check for the equalization observers and utilities."""
        expected_names = [
            'reshape_scale',
            '_InputEqualizationObserver',
            '_WeightEqualizationObserver',
            'calculate_equalization_scale',
            'EqualizationQConfig',
            'input_equalization_observer',
            'weight_equalization_observer',
            'default_equalization_qconfig',
            'fused_module_supports_equalization',
            'nn_module_supports_equalization',
            'node_supports_equalization',
            'is_equalization_observer',
            'get_op_node_and_weight_eq_obs',
            'maybe_get_weight_eq_obs_node',
            'maybe_get_next_input_eq_obs',
            'maybe_get_next_equalization_scale',
            'scale_input_observer',
            'scale_weight_node',
            'scale_weight_functional',
            'clear_weight_quant_obs_node',
            'remove_node',
            'update_obs_for_equalization',
            'convert_eq_obs',
            '_convert_equalization_ref',
            'get_layer_sqnr_dict',
            'get_equalization_qconfig_dict',
        ]
        self._test_function_import('fx._equalize', expected_names)

    # ----- fx.quantization_patterns --------------------------------------

    def test_package_import_fx_quantization_patterns(self):
        """Package-level import check for ``fx.quantization_patterns``."""
        self._test_package_import('fx.quantization_patterns')

    def test_function_import_fx_quantization_patterns(self):
        """Name-level import check for every quantize-handler class."""
        expected_names = [
            'QuantizeHandler',
            'BinaryOpQuantizeHandler',
            'CatQuantizeHandler',
            'ConvReluQuantizeHandler',
            'LinearReLUQuantizeHandler',
            'BatchNormQuantizeHandler',
            'EmbeddingQuantizeHandler',
            'RNNDynamicQuantizeHandler',
            'DefaultNodeQuantizeHandler',
            'FixedQParamsOpQuantizeHandler',
            'CopyNodeQuantizeHandler',
            'CustomModuleQuantizeHandler',
            'GeneralTensorShapeOpQuantizeHandler',
            'StandaloneModuleQuantizeHandler',
        ]
        self._test_function_import('fx.quantization_patterns', expected_names)

    # ----- fx.match_utils ------------------------------------------------

    def test_package_import_fx_match_utils(self):
        """Package-level import check for ``fx.match_utils``."""
        self._test_package_import('fx.match_utils')

    def test_function_import_fx_match_utils(self):
        """Name-level import check for the pattern-matching helpers."""
        expected_names = [
            'MatchResult',
            'MatchAllNode',
            'is_match',
            'find_matches',
        ]
        self._test_function_import('fx.match_utils', expected_names)

    # ----- fx.prepare / fx.convert / fx.fuse -----------------------------

    def test_package_import_fx_prepare(self):
        """Package-level import check for ``fx.prepare``."""
        self._test_package_import('fx.prepare')

    def test_function_import_fx_prepare(self):
        """Name-level import check for ``fx.prepare.prepare``."""
        self._test_function_import('fx.prepare', ['prepare'])

    def test_package_import_fx_convert(self):
        """Package-level import check for ``fx.convert``."""
        self._test_package_import('fx.convert')

    def test_function_import_fx_convert(self):
        """Name-level import check for ``fx.convert.convert``."""
        self._test_function_import('fx.convert', ['convert'])

    def test_package_import_fx_fuse(self):
        """Package-level import check for ``fx.fuse``."""
        self._test_package_import('fx.fuse')

    def test_function_import_fx_fuse(self):
        """Name-level import check for ``fx.fuse.Fuser``."""
        self._test_function_import('fx.fuse', ['Fuser'])

    # ----- fx.fusion_patterns --------------------------------------------

    def test_package_import_fx_fusion_patterns(self):
        """Package-level import check for ``fx.fusion_patterns``."""
        self._test_package_import('fx.fusion_patterns')

    def test_function_import_fx_fusion_patterns(self):
        """Name-level import check for the fuse-handler classes.

        Only the base ``FuseHandler`` and the single ``DefaultFuseHandler``
        are listed here.
        """
        self._test_function_import(
            'fx.fusion_patterns', ['FuseHandler', 'DefaultFuseHandler'])

    # ----- fx.quantization_types -----------------------------------------

    def test_package_import_fx_quantization_types(self):
        """Package-level import check for ``fx.quantization_types``."""
        self._test_package_import('fx.quantization_types')

    def test_function_import_fx_quantization_types(self):
        """Name-level import check for the shared type aliases."""
        self._test_function_import(
            'fx.quantization_types', ['Pattern', 'QuantizerCls'])

    # ----- fx.utils ------------------------------------------------------

    def test_package_import_fx_utils(self):
        """Package-level import check for ``fx.utils``."""
        self._test_package_import('fx.utils')

    def test_function_import_fx_utils(self):
        """Name-level import check for the graph and node utility helpers."""
        expected_names = [
            '_parent_name',
            'graph_pretty_str',
            'get_per_tensor_qparams',
            'quantize_node',
            'get_custom_module_class_keys',
            'get_linear_prepack_op_for_dtype',
            'get_qconv_prepack_op',
            'get_qconv_op',
            'get_new_attr_name_with_prefix',
            'graph_module_from_producer_nodes',
            'assert_and_get_unique_device',
            'create_getattr_from_value',
            'create_qparam_nodes',
            'all_node_args_have_no_tensors',
            'node_return_type_is_int',
            'node_bool_tensor_arg_indexes',
            'is_get_tensor_info_node',
            'maybe_get_next_module',
        ]
        self._test_function_import('fx.utils', expected_names)