From 31f311a816c026bbfca622d6121d6a7fab44260d Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Fri, 14 Apr 2023 19:00:10 -0700
Subject: [PATCH] [PT2E][Quantization] Refactor Quantizer and QNNPACKQuantizer (#99063)

This diff renames the quantization spec/config and operator config and moves these data structures into the base quantizer. The base quantizer API now has get_supported_operators, which returns the list of patterns that a quantizer quantizes.

There are two choices being debated for how to convey to the user what a particular quantizer will quantize.

1. Modules. We only convey which nn.Modules will be quantized. This does not mean that equivalent functional variants won't be quantized; for simplicity we just list nn.Modules. If certain ops are quantized in a fused manner, that is treated as an internal detail.

   Pros:
   - Simple. Only nn.Modules are listed.
   - The user does not have to see fusion patterns.

   Cons:
   - Possibly confusing, because "supported = nn.Conv2d" does not make clear whether the quantizer also supports functional.conv2d.
   - Hiding fusion patterns means the user has no say in not fusing. If conv2d + relu is fused and the user configures quantization for conv only, the quantizer will still quantize the following relu as if conv2d + relu were fused.

2. Patterns. Be explicit about what is supported and enumerate all possible combinations.

   Pros:
   - It is very clear what the quantizer will do; no surprises.

   Cons:
   - It is not simple to parse.
   - It can be argued that fusion is an internal detail of the quantizer, so some quantizer implementations may choose to expose fusion patterns while others may not, and may not even provide any configurability.

One option is to move set_supported_operators/modules out of the base quantizer and let each quantizer define its own way of communicating what is supported. The issue with this is that when we want to "compose" multiple quantizers, there is no way for the user to define the order of composition without knowing what each quantizer supports. For example, quantizer A may quantize conv + relu while B quantizes only conv, but B's implementation is fast. In that case you may compose (B, A) so that B quantizes conv and A quantizes relu. Not knowing what A and B support makes such composition harder (a usage sketch of the pattern-based reporting follows below).

Differential Revision: [D44895547](https://our.internmc.facebook.com/intern/diff/D44895547/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D44895547/)!
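The sketch below is illustrative only and not part of the change itself: it shows how option 2 surfaces to a user through the `get_supported_operators` API added in this diff, which returns `OperatorConfig` entries pairing a `QuantizationConfig` with the operator patterns it covers. Only names defined in this patch are used; no composition API is assumed.

```python
# Sketch only (not part of this patch): inspecting a quantizer's supported
# operator patterns. All imports and functions below are defined in this diff.
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer

quantizer = QNNPackQuantizer()
quantizer.set_global(
    qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
)

# get_supported_operators() returns a list of OperatorConfig entries; each one
# pairs a QuantizationConfig with the operator patterns it applies to, e.g.
# [torch.nn.Conv2d, torch.nn.ReLU] for fused conv2d + relu. A user composing
# multiple quantizers could inspect these lists to choose the composition order.
for config, patterns in QNNPackQuantizer.get_supported_operators():
    print(config.activation.dtype, config.weight.qscheme, patterns)
```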
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99063 Approved by: https://github.com/jerryzh168 --- test/quantization/fx/test_quantize_pt2e.py | 288 +++++++--- .../quantization/_pt2e/quantizer/__init__.py | 4 +- .../_pt2e/quantizer/qnnpack_quantizer.py | 491 ++++++++++-------- .../quantization/_pt2e/quantizer/quantizer.py | 90 +++- 4 files changed, 570 insertions(+), 303 deletions(-) diff --git a/test/quantization/fx/test_quantize_pt2e.py b/test/quantization/fx/test_quantize_pt2e.py index bf350c38869..b99040c24fb 100644 --- a/test/quantization/fx/test_quantize_pt2e.py +++ b/test/quantization/fx/test_quantize_pt2e.py @@ -1,39 +1,46 @@ # Owner(s): ["oncall: quantization"] +import copy +import itertools +from typing import List + import torch -import torch.nn as nn import torch._dynamo as torchdynamo +import torch.nn as nn +from torch._inductor.compile_fx import compile_fx +from torch.ao.ns.fx.utils import compute_sqnr +from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping +from torch.ao.quantization._pt2e.quantizer import ( + OperatorConfig, + QNNPackQuantizer, + Quantizer, +) +from torch.ao.quantization._quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_pt2e_quantizer, +) +from torch.ao.quantization.backend_config import get_qnnpack_backend_config +from torch.ao.quantization.backend_config._qnnpack_pt2e import ( + get_qnnpack_pt2e_backend_config, +) +from torch.ao.quantization.backend_config._x86_inductor_pt2e import ( + get_x86_inductor_pt2e_backend_config, +) +from torch.ao.quantization.backend_config.x86 import get_x86_backend_config +from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig +from torch.ao.quantization.quantize_fx import ( + convert_fx, + convert_to_reference_fx, + prepare_fx, +) from torch.testing._internal.common_quantization import ( + NodeSpec as ns, QuantizationTestCase, skip_if_no_torchvision, skipIfNoQNNPACK, skipIfNoX86, ) -from torch.testing._internal.common_quantization import NodeSpec as ns -from torch.testing._internal.common_quantized import ( - override_quantized_engine, -) -from torch.ao.quantization import ( - get_default_qconfig, - QConfigMapping, - observer, -) -from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig -from torch.ao.quantization.backend_config import ( - get_qnnpack_backend_config, -) -from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config -from torch.ao.quantization.backend_config._x86_inductor_pt2e import get_x86_inductor_pt2e_backend_config -from torch.ao.quantization.backend_config.x86 import get_x86_backend_config -from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx, convert_fx -from torch.ao.quantization._pt2e.quantizer import Quantizer -from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer -from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e, prepare_pt2e_quantizer -from torch.ao.ns.fx.utils import ( - compute_sqnr, -) -import copy -import itertools -from torch._inductor.compile_fx import compile_fx +from torch.testing._internal.common_quantized import override_quantized_engine @skipIfNoQNNPACK @@ -62,8 +69,9 @@ class TestQuantizePT2E(QuantizationTestCase): ) qconfig = get_default_qconfig("qnnpack") - qconfig_mapping = QConfigMapping().set_global(qconfig) \ - .set_module_name("conv2", None) + qconfig_mapping = ( + QConfigMapping().set_global(qconfig).set_module_name("conv2", None) + ) backend_config = 
get_qnnpack_pt2e_backend_config() m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config) m(*example_inputs) @@ -74,7 +82,9 @@ class TestQuantizePT2E(QuantizationTestCase): node_occurrence = { # two for input of the first conv, one for output for the first conv ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3, - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor + ): 3, } node_list = [ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), @@ -84,7 +94,10 @@ class TestQuantizePT2E(QuantizationTestCase): ns.call_function(torch.ops.aten.convolution.default), ] self.checkGraphModuleNodes( - m, expected_node_list=node_list, expected_node_occurrence=node_occurrence) + m, + expected_node_list=node_list, + expected_node_occurrence=node_occurrence, + ) def test_qconfig_module_type(self): class M(torch.nn.Module): @@ -122,7 +135,9 @@ class TestQuantizePT2E(QuantizationTestCase): node_occurrence = { # two for input and weight of the conv, one for output for the conv ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3, - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor + ): 3, } node_list = [ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), @@ -145,17 +160,28 @@ class TestQuantizePT2E(QuantizationTestCase): class BackendAQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: _DEFAULT_TARGET_DTYPE_INFO = { - "input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float), - "output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float), + "input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args( + dtype=torch.float + ), + "output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args( + dtype=torch.float + ), } for node in model.graph.nodes: - node.meta["target_dtype_info"] = copy.deepcopy(_DEFAULT_TARGET_DTYPE_INFO) + node.meta["target_dtype_info"] = copy.deepcopy( + _DEFAULT_TARGET_DTYPE_INFO + ) for node in model.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.convolution.default: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.convolution.default + ): node.meta["target_dtype_info"] = { "input_act_obs_or_fq_ctr": observer.default_observer, "weight_obs_or_fq_ctr": observer.default_weight_observer, - "bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float), + "bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args( + dtype=torch.float + ), "output_act_obs_or_fq_ctr": observer.default_observer, "weight_index": 1, "bias_index": 2, @@ -164,6 +190,10 @@ class TestQuantizePT2E(QuantizationTestCase): def validate(self, model: torch.fx.GraphModule) -> None: pass + @classmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + pass + m = M().eval() example_inputs = (torch.randn(1, 3, 5, 5),) @@ -189,7 +219,8 @@ class TestQuantizePT2E(QuantizationTestCase): ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), ] self.checkGraphModuleNodes( - m, expected_node_list=node_list, expected_node_occurrence=node_occurrence) + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) def test_qnnpack_quantizer_conv(self): class M(torch.nn.Module): @@ -201,9 +232,12 @@ class 
TestQuantizePT2E(QuantizationTestCase): return self.conv(x) import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq + quantizer = QNNPackQuantizer() - operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec() - quantizer.set_global(operator_spec) + operator_config = ( + qq.get_default_per_channel_symmetric_qnnpack_quantization_config() + ) + quantizer.set_global(operator_config) m = M().eval() example_inputs = (torch.randn(1, 3, 5, 5),) @@ -232,7 +266,8 @@ class TestQuantizePT2E(QuantizationTestCase): ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), ] self.checkGraphModuleNodes( - m, expected_node_list=node_list, expected_node_occurrence=node_occurrence) + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) def test_qnnpack_quantizer_obs_sharing_ops(self): class M(torch.nn.Module): @@ -250,9 +285,12 @@ class TestQuantizePT2E(QuantizationTestCase): return x import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq + quantizer = QNNPackQuantizer() - operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec() - quantizer.set_global(operator_spec) + operator_config = ( + qq.get_default_per_channel_symmetric_qnnpack_quantization_config() + ) + quantizer.set_global(operator_config) m = M().eval() example_inputs = (torch.randn(1, 3, 5, 5),) @@ -279,11 +317,9 @@ class TestQuantizePT2E(QuantizationTestCase): ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel), ns.call_function(torch.ops.aten.convolution.default), ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), ns.call_function(torch.ops.aten.mean.dim), ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), ns.call_function(torch.ops.aten.hardtanh.default), ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), @@ -293,7 +329,8 @@ class TestQuantizePT2E(QuantizationTestCase): ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), ] self.checkGraphModuleNodes( - m, expected_node_list=node_list, expected_node_occurrence=node_occurrence) + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) def test_rearrange_weight_observer_for_decomposed_linear(self): """ @@ -305,6 +342,7 @@ class TestQuantizePT2E(QuantizationTestCase): weight - observer - t \ input - observer - addmm/mm """ + class M(torch.nn.Module): def __init__(self, with_bias, use_relu): super().__init__() @@ -331,7 +369,7 @@ class TestQuantizePT2E(QuantizationTestCase): tracing_mode="real", ) - qconfig = get_default_qconfig('qnnpack') + qconfig = get_default_qconfig("qnnpack") qconfig_mapping = QConfigMapping().set_global(qconfig) backend_config = get_qnnpack_pt2e_backend_config() m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config) @@ -339,12 +377,19 @@ class TestQuantizePT2E(QuantizationTestCase): # 1. 
Check graph nodes: # - args[0] of t should be the weight observer # - args[-1] of addmm/mm should be t - error_msg = 'Weight observer is not correctly rearranged for decomposed linear' + error_msg = ( + "Weight observer is not correctly rearranged for decomposed linear" + ) for node in m.graph.nodes: if node.target == torch.ops.aten.t.default: target = node.args[0].target - self.assertTrue(isinstance(getattr(m, target), observer.ObserverBase), error_msg) - elif node.target in (torch.ops.aten.addmm.default, torch.ops.aten.mm.default): + self.assertTrue( + isinstance(getattr(m, target), observer.ObserverBase), error_msg + ) + elif node.target in ( + torch.ops.aten.addmm.default, + torch.ops.aten.mm.default, + ): target = node.args[-1].target self.assertTrue(target == torch.ops.aten.t.default, error_msg) @@ -380,7 +425,9 @@ class TestQuantizePT2E(QuantizationTestCase): node_occurrence = { ns.call_function(torch.ops.aten.convolution.default): 1, - ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 1, + ns.call_function( + torch.ops.aten._native_batch_norm_legit_no_training.default + ): 1, } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) @@ -394,10 +441,13 @@ class TestQuantizePT2E(QuantizationTestCase): # make sure bn is fused into conv node_occurrence = { ns.call_function(torch.ops.aten.convolution.default): 1, - ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 0, + ns.call_function( + torch.ops.aten._native_batch_norm_legit_no_training.default + ): 0, } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + @skipIfNoQNNPACK class TestQuantizePT2EX86Inductor(QuantizationTestCase): @skipIfNoX86 @@ -417,7 +467,9 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase): inplace_relu_list = [True, False] with override_quantized_engine("x86"): with torch.no_grad(): - for use_relu, inplace_relu in itertools.product(use_relu_list, inplace_relu_list): + for use_relu, inplace_relu in itertools.product( + use_relu_list, inplace_relu_list + ): m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval() example_inputs = (torch.randn(2, 3, 4, 4),) # program capture @@ -433,7 +485,9 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase): qconfig = get_default_qconfig("x86") qconfig_mapping = QConfigMapping().set_global(qconfig) backend_config = get_x86_inductor_pt2e_backend_config() - prepare_module = prepare_pt2e(export_module, qconfig_mapping, example_inputs, backend_config) + prepare_module = prepare_pt2e( + export_module, qconfig_mapping, example_inputs, backend_config + ) prepare_module(*example_inputs) convert_module = convert_pt2e(prepare_module) convert_module(*example_inputs) @@ -441,38 +495,75 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase): # Fake quant should only be inserted at start and end node_occurrence = { # one for input and weight of the conv, one for output for the conv - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2, - ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1, - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1, - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2, + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel + ): 1, + ns.call_function( + 
torch.ops.quantized_decomposed.dequantize_per_tensor + ): 2, } if use_relu: node_list = [ - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor + ), ns.call_function(torch.ops.aten.convolution.default), - ns.call_function(torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default), - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ns.call_function( + torch.ops.aten.relu_.default + if inplace_relu + else torch.ops.aten.relu.default + ), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor + ), ] else: node_list = [ - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor + ), ns.call_function(torch.ops.aten.convolution.default), - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor), - ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor + ), ] - self.checkGraphModuleNodes(convert_module, - expected_node_occurrence=node_occurrence, - expected_node_list=node_list) + self.checkGraphModuleNodes( + convert_module, + expected_node_occurrence=node_occurrence, + expected_node_list=node_list, + ) # Step1: Ref result in 1.X fx path backend_config_1_x = get_x86_backend_config() m_copy = copy.deepcopy(m) - m_prepare_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config_1_x) + m_prepare_fx = prepare_fx( + m_copy, + qconfig_mapping, + example_inputs, + backend_config=backend_config_1_x, + ) after_prepare_result_fx = m_prepare_fx(*example_inputs) - m_convert_fx = convert_fx(m_prepare_fx, backend_config=backend_config_1_x) + m_convert_fx = convert_fx( + m_prepare_fx, backend_config=backend_config_1_x + ) ref_result = m_convert_fx(*example_inputs) # Step2: Start to lowering into Inductor @@ -483,11 +574,13 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase): inductor_res = run(*example_inputs) self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2) + class TestQuantizePT2EModels(QuantizationTestCase): @skip_if_no_torchvision @skipIfNoQNNPACK def test_resnet18(self): import torchvision + with override_quantized_engine("qnnpack"): example_inputs = (torch.randn(1, 3, 224, 224),) m = torchvision.models.resnet18().eval() @@ -510,7 +603,9 @@ class TestQuantizePT2EModels(QuantizationTestCase): # checking that we inserted observers correctly for maxpool operator (input and # output share observer instance) - self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2)) + self.assertEqual( + id(m.activation_post_process_3), id(m.activation_post_process_2) + ) after_prepare_result = m(*example_inputs) m = convert_pt2e(m) @@ -518,7 +613,9 @@ class TestQuantizePT2EModels(QuantizationTestCase): # comparing with existing fx graph mode quantization reference flow backend_config 
= get_qnnpack_backend_config() - m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config) + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) after_prepare_result_fx = m_fx(*example_inputs) m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config) @@ -526,17 +623,24 @@ class TestQuantizePT2EModels(QuantizationTestCase): # the result matches exactly after prepare self.assertEqual(after_prepare_result, after_prepare_result_fx) - self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf"))) + self.assertEqual( + compute_sqnr(after_prepare_result, after_prepare_result_fx), + torch.tensor(float("inf")), + ) # there are slight differences after convert due to different implementations # of quant/dequant - self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1) - self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35) - + self.assertTrue( + torch.max(after_quant_result - after_quant_result_fx) < 1e-1 + ) + self.assertTrue( + compute_sqnr(after_quant_result, after_quant_result_fx) > 35 + ) @skip_if_no_torchvision @skipIfNoQNNPACK def test_resnet18_with_quantizer_api(self): import torchvision + with override_quantized_engine("qnnpack"): example_inputs = (torch.randn(1, 3, 224, 224),) m = torchvision.models.resnet18().eval() @@ -551,13 +655,18 @@ class TestQuantizePT2EModels(QuantizationTestCase): before_fusion_result = m(*example_inputs) import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq + quantizer = QNNPackQuantizer() - operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec() - quantizer.set_global(operator_spec) + operator_config = ( + qq.get_default_per_channel_symmetric_qnnpack_quantization_config() + ) + quantizer.set_global(operator_config) m = prepare_pt2e_quantizer(m, quantizer) # checking that we inserted observers correctly for maxpool operator (input and # output share observer instance) - self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2)) + self.assertEqual( + id(m.activation_post_process_3), id(m.activation_post_process_2) + ) after_prepare_result = m(*example_inputs) m = convert_pt2e(m) @@ -567,7 +676,9 @@ class TestQuantizePT2EModels(QuantizationTestCase): qconfig = default_per_channel_symmetric_qnnpack_qconfig qconfig_mapping = QConfigMapping().set_global(qconfig) backend_config = get_qnnpack_backend_config() - m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config) + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) after_prepare_result_fx = m_fx(*example_inputs) m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config) @@ -579,8 +690,15 @@ class TestQuantizePT2EModels(QuantizationTestCase): # but we can still manully inspect the printed observers to make sure # it matches self.assertEqual(after_prepare_result, after_prepare_result_fx) - self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf"))) + self.assertEqual( + compute_sqnr(after_prepare_result, after_prepare_result_fx), + torch.tensor(float("inf")), + ) # there are slight differences after convert due to different implementations # of quant/dequant - self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1) - self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35) + self.assertTrue( + 
torch.max(after_quant_result - after_quant_result_fx) < 1e-1 + ) + self.assertTrue( + compute_sqnr(after_quant_result, after_quant_result_fx) > 35 + ) diff --git a/torch/ao/quantization/_pt2e/quantizer/__init__.py b/torch/ao/quantization/_pt2e/quantizer/__init__.py index e9e3ea62416..ccd31219019 100644 --- a/torch/ao/quantization/_pt2e/quantizer/__init__.py +++ b/torch/ao/quantization/_pt2e/quantizer/__init__.py @@ -1,7 +1,7 @@ -from .quantizer import Quantizer from .qnnpack_quantizer import QNNPackQuantizer +from .quantizer import OperatorConfig, Quantizer __all__ = [ - "Quantizer" + "Quantizer", "QNNPackQuantizer", ] diff --git a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py index 2f5352ba9cd..2e521c08d23 100644 --- a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py +++ b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py @@ -1,95 +1,80 @@ from __future__ import annotations -from .quantizer import Quantizer import copy -from dataclasses import dataclass -from typing import List, NamedTuple, Optional, Set, Dict, Callable +import operator +from typing import Callable, Dict, List, Optional, Set + +import torch +import torch.nn.functional as F from torch.ao.quantization.observer import ( - PlaceholderObserver, HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, + PlaceholderObserver, ) -import torch -import operator from torch.fx import Node +from .quantizer import ( + OperatorConfig, + OperatorPatternType, + QuantizationConfig, + QuantizationSpec, + Quantizer, +) + __all__ = [ "QNNPackQuantizer", - "get_default_symmetric_qnnpack_operator_spec", - "get_default_per_channel_symmetric_qnnpack_operator_spec", + "get_default_symmetric_qnnpack_quantization_config", + "get_default_per_channel_symmetric_qnnpack_quantization_config", ] -# TODO: maybe remove torch.float32 -SUPPORTED_DTYPES = [torch.uint8, torch.int8, torch.int32, torch.float16, torch.float32] -SUPPORTED_QSCHEMES = [ - torch.per_tensor_affine, - torch.per_tensor_symmetric, - torch.per_channel_affine, - torch.per_channel_symmetric, - torch.per_channel_affine_float_qparams, -] -@dataclass(eq=True, frozen=True) -class TensorSpec: - dtype: torch.dtype - is_dynamic: bool = False - quant_min: Optional[int] = None - quant_max: Optional[int] = None - qscheme: Optional[torch.qscheme] = None - ch_axis: Optional[int] = None - - def __post_init__(self): - # check dtype is one of the supported types - if self.dtype not in SUPPORTED_DTYPES: - raise TypeError(f"Unsupported dtype {self.dtype}.") - - # quant_min must be less than quant_max - if self.quant_min is not None and self.quant_max is not None and self.quant_min > self.quant_max: - raise ValueError( - f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." - ) - - # check qscheme is on of the supported ones - if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES: - raise ValueError(f"Unsupported qscheme {self.qscheme}.") - - # ch_axis must be less than the number of channels - # but no way to check here. Just check that it is not < 0. 
- if self.ch_axis is not None and self.ch_axis < 0: - raise ValueError("Ch_axis is < 0.") - - -OperatorSpec = NamedTuple( - "OperatorSpec", [("activation", TensorSpec), ("weight", TensorSpec), ("bias", TensorSpec)] -) - -SpecAndOperators = NamedTuple( - "SpecAndOperators", - [("operator_spec", OperatorSpec), ("operators", List[str])] -) - - -def supported_symmetric_quantized_operators() -> List[str]: - supported_operators = ["conv2d", "linear", "add", "maxpool2d", "hardtanh", "mean", "adaptive_avgpool2d"] +def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]: + supported_operators: Dict[str, List[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "maxpool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]], + "hardtanh": [[torch.nn.Hardtanh], [F.hardtanh]], + "mean": [[torch.mean]], + "adaptive_avgpool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } return copy.deepcopy(supported_operators) -def get_supported_symmetric_quantized_spec_and_operators() -> List[SpecAndOperators]: - supported_spec_and_operators: List[SpecAndOperators] = [] - for operator_spec in [get_default_symmetric_qnnpack_operator_spec(), get_default_per_channel_symmetric_qnnpack_operator_spec()]: - ops = supported_symmetric_quantized_operators() - supported_spec_and_operators.append(SpecAndOperators(operator_spec, ops)) - return copy.deepcopy(supported_spec_and_operators) -def get_default_symmetric_qnnpack_operator_spec(): - act_tensor_spec = TensorSpec( +def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]: + supported_config_and_operators: List[OperatorConfig] = [] + for quantization_config in [ + get_default_symmetric_qnnpack_quantization_config(), + get_default_per_channel_symmetric_qnnpack_quantization_config(), + ]: + ops = supported_symmetric_quantized_operators() + for op_string, pattern_list in ops.items(): + supported_config_and_operators.append( + OperatorConfig(quantization_config, pattern_list) + ) + return copy.deepcopy(supported_config_and_operators) + + +def get_default_symmetric_qnnpack_quantization_config(): + act_quantization_spec = QuantizationSpec( dtype=torch.int8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_affine, is_dynamic=False, ) - weight_tensor_spec = TensorSpec( + weight_quantization_spec = QuantizationSpec( dtype=torch.int8, quant_min=-127, quant_max=127, @@ -97,19 +82,22 @@ def get_default_symmetric_qnnpack_operator_spec(): ch_axis=1, is_dynamic=False, ) - bias_tensor_spec = TensorSpec(dtype=torch.float) - operator_spec = OperatorSpec(act_tensor_spec, weight_tensor_spec, bias_tensor_spec) - return operator_spec + bias_quantization_spec = QuantizationSpec(dtype=torch.float) + quantization_config = QuantizationConfig( + act_quantization_spec, weight_quantization_spec, bias_quantization_spec + ) + return quantization_config -def get_default_per_channel_symmetric_qnnpack_operator_spec(): - act_tensor_spec = TensorSpec( + +def get_default_per_channel_symmetric_qnnpack_quantization_config(): + act_quantization_spec = QuantizationSpec( dtype=torch.int8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_affine, is_dynamic=False, ) - weight_tensor_spec = TensorSpec( + weight_quantization_spec = QuantizationSpec( dtype=torch.int8, 
quant_min=-127, quant_max=127, @@ -117,27 +105,16 @@ def get_default_per_channel_symmetric_qnnpack_operator_spec(): ch_axis=1, is_dynamic=False, ) - bias_tensor_spec = TensorSpec(dtype=torch.float) - operator_spec = OperatorSpec(act_tensor_spec, weight_tensor_spec, bias_tensor_spec) - return operator_spec + bias_quantization_spec = QuantizationSpec(dtype=torch.float) + quantization_config = QuantizationConfig( + act_quantization_spec, weight_quantization_spec, bias_quantization_spec + ) + return quantization_config -def get_supported_spec_and_operators() -> List[SpecAndOperators]: - return get_supported_symmetric_quantized_spec_and_operators() -class OperatorSpecConfig: +def get_supported_config_and_operators() -> List[OperatorConfig]: + return get_supported_symmetric_config_and_operators() - def __init__(self): - super().__init__() - self.global_spec: Optional[OperatorSpec] = None - self.operator_type_specs: Dict[str, Optional[OperatorSpec]] = {} - - def set_global(self, operator_spec: Optional[OperatorSpec]) -> OperatorSpecConfig: - self.global_spec = operator_spec - return self - - def set_operator_type(self, operator_type: str, operator_spec: Optional[OperatorSpec]) -> OperatorSpecConfig: - self.operator_type_specs[operator_type] = operator_spec - return self # TODO: add support for torch dtype in quant code base # this includes observers and prepare/convert code @@ -147,61 +124,76 @@ _TORCH_DTYPE_TO_QDTYPE = { torch.int32: torch.qint32, torch.float16: torch.float16, } -def _get_act_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]): - if operator_spec is None: + + +def _get_act_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: return None - assert operator_spec is not None - tensor_spec: TensorSpec = operator_spec.activation - qdtype = _TORCH_DTYPE_TO_QDTYPE[tensor_spec.dtype] - assert tensor_spec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric] - if not tensor_spec.is_dynamic: + assert quantization_config is not None + quantization_spec: QuantizationSpec = quantization_config.activation + qdtype = _TORCH_DTYPE_TO_QDTYPE[quantization_spec.dtype] + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + if not quantization_spec.is_dynamic: return HistogramObserver.with_args( dtype=qdtype, - quant_min=tensor_spec.quant_min, - quant_max=tensor_spec.quant_max, + quant_min=quantization_spec.quant_min, + quant_max=quantization_spec.quant_max, reduce_range=False, - eps=2**-12 + eps=2**-12, ) else: # TODO: extend this helper function to support dynamic quantization - raise Exception("Unsupported tensor_spec for activation: {}".format(tensor_spec)) - -def _get_weight_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]): - if operator_spec is None: - return None - assert operator_spec is not None - tensor_spec: TensorSpec = operator_spec.weight - qdtype = _TORCH_DTYPE_TO_QDTYPE[tensor_spec.dtype] - if tensor_spec.qscheme == torch.per_tensor_symmetric: - return MinMaxObserver.with_args( - qscheme=tensor_spec.qscheme, - dtype=qdtype, - quant_min=tensor_spec.quant_min, - quant_max=tensor_spec.quant_max, - eps=2**-12 + raise Exception( + "Unsupported quantization_spec for activation: {}".format(quantization_spec) ) - elif tensor_spec.qscheme == torch.per_channel_symmetric: - return PerChannelMinMaxObserver.with_args( - qscheme=tensor_spec.qscheme, + + +def _get_weight_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + 
assert quantization_config is not None + quantization_spec: QuantizationSpec = quantization_config.weight + qdtype = _TORCH_DTYPE_TO_QDTYPE[quantization_spec.dtype] + if quantization_spec.qscheme == torch.per_tensor_symmetric: + return MinMaxObserver.with_args( + qscheme=quantization_spec.qscheme, dtype=qdtype, - quant_min=tensor_spec.quant_min, - quant_max=tensor_spec.quant_max, - eps=2**-12 + quant_min=quantization_spec.quant_min, + quant_max=quantization_spec.quant_max, + eps=2**-12, + ) + elif quantization_spec.qscheme == torch.per_channel_symmetric: + return PerChannelMinMaxObserver.with_args( + qscheme=quantization_spec.qscheme, + dtype=qdtype, + quant_min=quantization_spec.quant_min, + quant_max=quantization_spec.quant_max, + eps=2**-12, ) else: - raise Exception("Unsupported tensor_spec for weight: {}".format(tensor_spec)) + raise Exception( + "Unsupported quantization_spec for weight: {}".format(quantization_spec) + ) -def _get_bias_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]): - if operator_spec is None: + +def _get_bias_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: return None - assert operator_spec is not None - tensor_spec: TensorSpec = operator_spec.bias - assert tensor_spec.dtype == torch.float, "Only float dtype for bias is supported for bias right now" - return PlaceholderObserver.with_args(dtype=tensor_spec.dtype) + assert quantization_config is not None + quantization_spec: QuantizationSpec = quantization_config.bias + assert ( + quantization_spec.dtype == torch.float + ), "Only float dtype for bias is supported for bias right now" + return PlaceholderObserver.with_args(dtype=quantization_spec.dtype) + def _get_default_obs_or_fq_ctr(): return PlaceholderObserver.with_args(dtype=torch.float) + def _is_annotated(nodes: List[Node]): """ Given a list of nodes (that represents an operator pattern), @@ -210,96 +202,106 @@ def _is_annotated(nodes: List[Node]): """ annotated = False for node in nodes: - annotated = annotated or ("target_dtype_info" in node.meta and node.meta["target_dtype_info"].get("_annotated", False)) + annotated = annotated or ( + "target_dtype_info" in node.meta + and node.meta["target_dtype_info"].get("_annotated", False) + ) return annotated + class QNNPackQuantizer(Quantizer): - supported_spec_and_operators = get_supported_spec_and_operators() + supported_config_and_operators = get_supported_config_and_operators() def __init__(self): super().__init__() - self.operator_spec_config = OperatorSpecConfig() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {} @classmethod - def get_supported_operator_specs(cls) -> List[OperatorSpec]: - op_specs: Set[OperatorSpec] = set({}) - for spec, _ in cls.supported_spec_and_operators: - op_specs.add(spec) - return list(op_specs) + def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: + op_configs: Set[QuantizationConfig] = set({}) + for spec, _ in cls.supported_config_and_operators: + op_configs.add(spec) + return list(op_configs) @classmethod - def get_supported_operator_for_operator_spec(cls, operator_spec: Optional[OperatorSpec]) -> List[str]: - if operator_spec is None: + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> List[OperatorPatternType]: + if quantization_config is None: all_ops = [] - for _, ops in cls.supported_spec_and_operators: + for _, ops in cls.supported_config_and_operators: 
all_ops.extend(ops) return all_ops - for spec, ops in cls.supported_spec_and_operators: + for config, ops in cls.supported_config_and_operators: # note: this assumes each entry in cls.supported_spec_and_operators # corresponds to one spec, e.g. we don't have # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] # where the first and second entry have the same spec but did not # merge the op list - if spec == operator_spec: + if config == quantization_config: return ops return [] - def set_global(self, operator_spec: Optional[OperatorSpec]) -> QNNPackQuantizer: - self.operator_spec_config.set_global(operator_spec) + def set_global( + self, quantization_config: Optional[QuantizationConfig] + ) -> QNNPackQuantizer: + self.global_config = quantization_config return self - def set_spec_for_operator_type( - self, operator_type: str, operator_spec: Optional[OperatorSpec] + def set_config_for_operator_type( + self, operator_type: str, quantization_config: Optional[QuantizationConfig] ) -> QNNPackQuantizer: - self.operator_spec_config.set_operator_type(operator_type, operator_spec) + self.operator_type_config[operator_type] = quantization_config return self def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """ just handling global spec for now - """ - global_spec = self.operator_spec_config.global_spec - ops = self.get_supported_operator_for_operator_spec(global_spec) + """just handling global spec for now""" + global_config = self.global_config + ops = self.get_supported_operator_for_quantization_config(global_config) # annotate the nodes from last to first since the matching is in the reversed order # and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv) # and we will mark the matched node with "_annoated" so fusion operator pattern # can take precedence over single operator pattern in this way for node in reversed(model.graph.nodes): - for op in ops: - if op == "conv2d": - self._annotate_conv2d_relu(node, global_spec) - self._annotate_conv2d(node, global_spec) - elif op == "linear": - self._annotate_linear(node, global_spec) - elif op == "maxpool2d": - self._annotate_maxpool2d(node, global_spec) - elif op == "add": - self._annotate_add_relu(node, global_spec) - self._annotate_add(node, global_spec) - elif op == "hardtanh": - self._annotate_hardtanh(node, global_spec) - elif op == "mean": - self._annotate_mean(node, global_spec) - elif op == "adaptive_avgpool2d": - self._annotate_adaptive_avg_pool2d(node, global_spec) + # one improvement is to register node annotators for each + # supported op type. 
+ self._annotate_conv2d_relu(node, global_config) + self._annotate_conv2d(node, global_config) + self._annotate_linear(node, global_config) + self._annotate_maxpool2d(node, global_config) + self._annotate_add_relu(node, global_config) + self._annotate_add(node, global_config) + self._annotate_hardtanh(node, global_config) + self._annotate_mean(node, global_config) + self._annotate_adaptive_avg_pool2d(node, global_config) return model - def _annotate_conv2d_relu(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: - if node.op != "call_function" or node.target not in [torch.ops.aten.relu_.default, torch.ops.aten.relu.default]: + def _annotate_conv2d_relu( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu_.default, + torch.ops.aten.relu.default, + ]: return relu_node = node conv_node = relu_node.args[0] assert isinstance(conv_node, Node) - if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default: + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.convolution.default + ): return if _is_annotated([relu_node, conv_node]): return conv_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), - "weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec), - "bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), + "weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config), + "bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config), "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), # TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set "weight_index": 1, @@ -309,22 +311,27 @@ class QNNPackQuantizer(Quantizer): } relu_node.meta["target_dtype_info"] = { "input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "_annotated": True, } - def _annotate_conv2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: + def _annotate_conv2d( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: conv_node = node - if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default: + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.convolution.default + ): return # skip annotation if it is already annotated if _is_annotated([conv_node]): return conv_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), - "weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec), - "bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), + "weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config), + "bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), # TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set "weight_index": 1, # TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set @@ -332,13 +339,21 @@ class QNNPackQuantizer(Quantizer): "_annotated": True, } - def _annotate_linear(self, 
node: Node, operator_spec: Optional[OperatorSpec]) -> None: + def _annotate_linear( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: addmm_node = node - if addmm_node.op != "call_function" or addmm_node.target != torch.ops.aten.addmm.default: + if ( + addmm_node.op != "call_function" + or addmm_node.target != torch.ops.aten.addmm.default + ): return view_node = addmm_node.args[1] assert isinstance(view_node, Node) - if view_node.op != "call_function" or view_node.target != torch.ops.aten.view.default: + if ( + view_node.op != "call_function" + or view_node.target != torch.ops.aten.view.default + ): return t_node = addmm_node.args[2] assert isinstance(t_node, Node) @@ -349,106 +364,154 @@ class QNNPackQuantizer(Quantizer): # bias and output act addmm_node.meta["target_dtype_info"] = { - "bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec), + "bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config), "input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "bias_index": 0, "_annotated": True, } # input act view_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), "_annotated": True, } # weight t_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config), "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), "_annotated": True, } - def _annotate_maxpool2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: - if node.op != "call_function" or node.target != operator.getitem or node.args[1] != 0: + def _annotate_maxpool2d( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: + if ( + node.op != "call_function" + or node.target != operator.getitem + or node.args[1] != 0 + ): return getitem_node = node maxpool_node = getitem_node.args[0] assert isinstance(maxpool_node, Node) - if maxpool_node.op != "call_function" or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default: + if ( + maxpool_node.op != "call_function" + or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default + ): return if _is_annotated([getitem_node, maxpool_node]): return maxpool_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), "_annotated": True, } getitem_node.meta["target_dtype_info"] = { "input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "input_output_share_observers": True, "_annotated": True, } - def _annotate_input_out_obs_sharing_op(self, op: Callable, node: Node, operator_spec: Optional[OperatorSpec]) -> None: + def _annotate_input_out_obs_sharing_op( + self, + op: Callable, + node: Node, + quantization_config: Optional[QuantizationConfig], + ) -> None: io_obs_sharing_node = node - if io_obs_sharing_node.op != "call_function" or io_obs_sharing_node.target != op: + if ( + io_obs_sharing_node.op != "call_function" + or 
io_obs_sharing_node.target != op + ): return if _is_annotated([io_obs_sharing_node]): return io_obs_sharing_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "input_output_share_observers": True, "_annotated": True, } - def _annotate_hardtanh(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: - self._annotate_input_out_obs_sharing_op(torch.ops.aten.hardtanh.default, node, operator_spec) + def _annotate_hardtanh( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: + self._annotate_input_out_obs_sharing_op( + torch.ops.aten.hardtanh.default, node, quantization_config + ) - def _annotate_mean(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: - self._annotate_input_out_obs_sharing_op(torch.ops.aten.mean.default, node, operator_spec) - self._annotate_input_out_obs_sharing_op(torch.ops.aten.mean.dim, node, operator_spec) + def _annotate_mean( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: + self._annotate_input_out_obs_sharing_op( + torch.ops.aten.mean.default, node, quantization_config + ) + self._annotate_input_out_obs_sharing_op( + torch.ops.aten.mean.dim, node, quantization_config + ) - def _annotate_adaptive_avg_pool2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: - self._annotate_input_out_obs_sharing_op(torch.ops.aten.adaptive_avg_pool2d.default, node, operator_spec) + def _annotate_adaptive_avg_pool2d( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: + self._annotate_input_out_obs_sharing_op( + torch.ops.aten.adaptive_avg_pool2d.default, node, quantization_config + ) - def _annotate_add_relu(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: - if node.op != "call_function" or node.target not in [torch.ops.aten.relu_.default, torch.ops.aten.relu.default]: + def _annotate_add_relu( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu_.default, + torch.ops.aten.relu.default, + ]: return relu_node = node add_node = relu_node.args[0] assert isinstance(add_node, Node) - if add_node.op != "call_function" or add_node.target not in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]: + if add_node.op != "call_function" or add_node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: return if _is_annotated([relu_node, add_node]): return add_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), "_annotated": True, } relu_node.meta["target_dtype_info"] = { "input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "_annotated": True, } - def _annotate_add(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None: + def _annotate_add( + self, node: Node, quantization_config: Optional[QuantizationConfig] + ) -> None: add_node = node - if add_node.op != "call_function" or add_node.target not in 
[torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]: + if add_node.op != "call_function" or add_node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: return if _is_annotated([add_node]): return add_node.meta["target_dtype_info"] = { - "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), - "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec), + "input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), + "output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config), "_annotated": True, } def validate(self, model: torch.fx.GraphModule) -> None: pass + + @classmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/torch/ao/quantization/_pt2e/quantizer/quantizer.py b/torch/ao/quantization/_pt2e/quantizer/quantizer.py index edc193fd287..cffa335574f 100644 --- a/torch/ao/quantization/_pt2e/quantizer/quantizer.py +++ b/torch/ao/quantization/_pt2e/quantizer/quantizer.py @@ -1,11 +1,90 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, List, NamedTuple, Optional + import torch __all__ = [ "Quantizer", ] +# TODO: maybe remove torch.float32 +SUPPORTED_DTYPES = [torch.uint8, torch.int8, torch.int32, torch.float16, torch.float32] +SUPPORTED_QSCHEMES = [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, +] + + +@dataclass(eq=True, frozen=True) +class QuantizationSpec: + dtype: torch.dtype + is_dynamic: bool = False + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + + def __post_init__(self): + # check dtype is one of the supported types + if self.dtype not in SUPPORTED_DTYPES: + raise TypeError(f"Unsupported dtype {self.dtype}.") + + # quant_min must be less than quant_max + if ( + self.quant_min is not None + and self.quant_max is not None + and self.quant_min > self.quant_max + ): + raise ValueError( + f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." + ) + + # check qscheme is on of the supported ones + if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES: + raise ValueError(f"Unsupported qscheme {self.qscheme}.") + + # ch_axis must be less than the number of channels + # but no way to check here. Just check that it is not < 0. + if self.ch_axis is not None and self.ch_axis < 0: + raise ValueError("Ch_axis is < 0.") + + +# In the absence of better name, just winging it with QuantizationConfig +QuantizationConfig = NamedTuple( + "QuantizationConfig", + [ + ("activation", QuantizationSpec), + ("weight", QuantizationSpec), + ("bias", QuantizationSpec), + ], +) + +OperatorPatternType = List[Callable] + +OperatorConfig = NamedTuple( + "OperatorConfig", + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. [nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. 
+ [ + ("config", QuantizationConfig), + ( + "operators", + List[OperatorPatternType], + ), + ], +) + + class Quantizer(ABC): # annotate nodes in the graph with observer or fake quant constructors @@ -18,3 +97,10 @@ class Quantizer(ABC): @abstractmethod def validate(self, model: torch.fx.GraphModule) -> None: pass + + # annotate nodes in the graph with observer or fake quant constructors + # to convey the desired way of quantization + @classmethod + @abstractmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + pass