pytorch/test/quantization/ao_migration/test_quantization_fx.py
Vasiliy Kuznetsov d549c8de78 fx quant: enable linear-bn1d fusion for PTQ (#66484)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66484

https://github.com/pytorch/pytorch/pull/50748 added linear-bn1d fusion
in eager mode, for PTQ only. This PR enables the same fusion in FX graph mode.

We reuse the existing conv-bn-relu fusion handler, renaming `conv` to
`conv_or_linear` for readability.

The QAT version is left for a future PR, for both eager and FX graph modes.

Test Plan:
```
python test/test_quantization.py TestFuseFx.test_fuse_linear_bn_eval
```
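
For illustration, a minimal sketch of the eval-mode fusion this enables, via the public `fuse_fx` entry point (the toy module and shapes below are illustrative, not from this PR):

```python
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.linear(x))

# PTQ-only for now: the linear-bn1d pattern fuses in eval mode,
# folding the batch norm statistics into the linear's weight and bias.
m = M().eval()
fused = fuse_fx(m)
out = fused(torch.randn(2, 4))
```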

Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D31575392

fbshipit-source-id: f69d80ef37c98cbc070099170e335e250bcdf913
2021-10-18 10:14:28 -07:00


from .common import AOMigrationTestCase


class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    def test_package_import_quantize_fx(self):
        self._test_package_import('quantize_fx')

    def test_function_import_quantize_fx(self):
        function_list = [
            '_check_is_graph_module',
            '_swap_ff_with_fxff',
            '_fuse_fx',
            'Scope',
            'ScopeContextManager',
            'QuantizationTracer',
            '_prepare_fx',
            '_prepare_standalone_module_fx',
            'fuse_fx',
            'prepare_fx',
            'prepare_qat_fx',
            '_convert_fx',
            'convert_fx',
            '_convert_standalone_module_fx',
        ]
        self._test_function_import('quantize_fx', function_list)

    def test_package_import_fx(self):
        self._test_package_import('fx')

    def test_function_import_fx(self):
        function_list = [
            'prepare',
            'convert',
            'Fuser',
        ]
        self._test_function_import('fx', function_list)

    def test_package_import_fx_graph_module(self):
        self._test_package_import('fx.graph_module')

    def test_function_import_fx_graph_module(self):
        function_list = [
            'FusedGraphModule',
            'ObservedGraphModule',
            'is_observed_module',
            'ObservedStandaloneGraphModule',
            'is_observed_standalone_module',
            'QuantizedGraphModule'
        ]
        self._test_function_import('fx.graph_module', function_list)

    def test_package_import_fx_pattern_utils(self):
        self._test_package_import('fx.pattern_utils')

    def test_function_import_fx_pattern_utils(self):
        function_list = [
            'QuantizeHandler',
            'MatchResult',
            'register_fusion_pattern',
            'get_default_fusion_patterns',
            'register_quant_pattern',
            'get_default_quant_patterns',
            'get_default_output_activation_post_process_map'
        ]
        self._test_function_import('fx.pattern_utils', function_list)

    def test_package_import_fx_equalize(self):
        self._test_package_import('fx._equalize')

    def test_function_import_fx_equalize(self):
        function_list = [
            'reshape_scale',
            '_InputEqualizationObserver',
            '_WeightEqualizationObserver',
            'calculate_equalization_scale',
            'EqualizationQConfig',
            'input_equalization_observer',
            'weight_equalization_observer',
            'default_equalization_qconfig',
            'fused_module_supports_equalization',
            'nn_module_supports_equalization',
            'node_supports_equalization',
            'is_equalization_observer',
            'get_op_node_and_weight_eq_obs',
            'maybe_get_weight_eq_obs_node',
            'maybe_get_next_input_eq_obs',
            'maybe_get_next_equalization_scale',
            'scale_input_observer',
            'scale_weight_node',
            'scale_weight_functional',
            'clear_weight_quant_obs_node',
            'remove_node',
            'update_obs_for_equalization',
            'convert_eq_obs',
            '_convert_equalization_ref',
            'get_layer_sqnr_dict',
            'get_equalization_qconfig_dict'
        ]
        self._test_function_import('fx._equalize', function_list)

    def test_package_import_fx_quantization_patterns(self):
        self._test_package_import('fx.quantization_patterns')

    def test_function_import_fx_quantization_patterns(self):
        function_list = [
            'QuantizeHandler',
            'BinaryOpQuantizeHandler',
            'CatQuantizeHandler',
            'ConvReluQuantizeHandler',
            'LinearReLUQuantizeHandler',
            'BatchNormQuantizeHandler',
            'EmbeddingQuantizeHandler',
            'RNNDynamicQuantizeHandler',
            'DefaultNodeQuantizeHandler',
            'FixedQParamsOpQuantizeHandler',
            'CopyNodeQuantizeHandler',
            'CustomModuleQuantizeHandler',
            'GeneralTensorShapeOpQuantizeHandler',
            'StandaloneModuleQuantizeHandler'
        ]
        self._test_function_import('fx.quantization_patterns', function_list)

    def test_package_import_fx_match_utils(self):
        self._test_package_import('fx.match_utils')

    def test_function_import_fx_match_utils(self):
        function_list = [
            'MatchResult',
            'MatchAllNode',
            'is_match',
            'find_matches'
        ]
        self._test_function_import('fx.match_utils', function_list)

    def test_package_import_fx_prepare(self):
        self._test_package_import('fx.prepare')

    def test_function_import_fx_prepare(self):
        function_list = [
            'prepare'
        ]
        self._test_function_import('fx.prepare', function_list)

    def test_package_import_fx_convert(self):
        self._test_package_import('fx.convert')

    def test_function_import_fx_convert(self):
        function_list = [
            'convert'
        ]
        self._test_function_import('fx.convert', function_list)

    def test_package_import_fx_fuse(self):
        self._test_package_import('fx.fuse')

    def test_function_import_fx_fuse(self):
        function_list = [
            'Fuser'
        ]
        self._test_function_import('fx.fuse', function_list)

    def test_package_import_fx_fusion_patterns(self):
        self._test_package_import('fx.fusion_patterns')

    def test_function_import_fx_fusion_patterns(self):
        function_list = [
            'FuseHandler',
            'ConvOrLinearBNReLUFusion',
            'ModuleReLUFusion'
        ]
        self._test_function_import('fx.fusion_patterns', function_list)

    def test_package_import_fx_quantization_types(self):
        self._test_package_import('fx.quantization_types')

    def test_function_import_fx_quantization_types(self):
        function_list = [
            'Pattern',
            'QuantizerCls'
        ]
        self._test_function_import('fx.quantization_types', function_list)

    def test_package_import_fx_utils(self):
        self._test_package_import('fx.utils')

    def test_function_import_fx_utils(self):
        function_list = [
            '_parent_name',
            'graph_pretty_str',
            'get_per_tensor_qparams',
            'quantize_node',
            'get_custom_module_class_keys',
            'get_linear_prepack_op_for_dtype',
            'get_qconv_prepack_op',
            'get_qconv_op',
            'get_new_attr_name_with_prefix',
            'graph_module_from_producer_nodes',
            'assert_and_get_unique_device',
            'create_getattr_from_value',
            'create_qparam_nodes',
            'all_node_args_have_no_tensors',
            'node_return_type_is_int',
            'node_bool_tensor_arg_indexes',
            'is_get_tensor_info_node',
            'maybe_get_next_module'
        ]
        self._test_function_import('fx.utils', function_list)
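
The `AOMigrationTestCase` helpers exercised above live in `common.py`, which is not shown here. As a rough sketch of what `_test_function_import` presumably checks, assuming the migration keeps each legacy `torch.quantization` name aliased to its `torch.ao.quantization` counterpart:

```python
import importlib
import unittest

class AOMigrationTestCase(unittest.TestCase):
    # Hypothetical reconstruction; the real helper is defined in
    # test/quantization/ao_migration/common.py.
    def _test_function_import(self, package_name, function_list):
        old_location = importlib.import_module(f'torch.quantization.{package_name}')
        new_location = importlib.import_module(f'torch.ao.quantization.{package_name}')
        for name in function_list:
            # The legacy alias should resolve to the same object as the
            # migrated definition.
            self.assertEqual(getattr(old_location, name),
                             getattr(new_location, name))
```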