From 70ca18f8a03e8b5fbc48debdba498d8d9247d0ec Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 25 Aug 2023 20:54:20 +0800
Subject: [PATCH] [Quant][PT2E] Enable X86InductorQuantizer single quantizable
 op(maxpool2d) (#105639)

**Summary**
This PR mainly enables two things:

- Enable the skeleton of the quantization recipe for single quantizable operators in `X86InductorQuantizer`.
- Add the quantization recipe for `maxpool2d` and annotate it so that its input and output share an observer.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_maxpool2d_recipe
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105639
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456
---
 .../pt2e/test_x86inductor_quantizer.py  |  81 +++++-
 .../quantizer/x86_inductor_quantizer.py | 245 ++++++++++++++++--
 2 files changed, 310 insertions(+), 16 deletions(-)

diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index dad00f2998a..c6d98060442 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -20,7 +20,8 @@ from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-
+import operator
+from torch.ao.quantization import ObserverBase

 class Conv2DType(Enum):
     left = 1
@@ -127,6 +128,18 @@ class TestHelperModules:
             else:
                 return self.relu2(self.conv(x) + self.conv2(x))

+    class Conv2dMaxpoolPowModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(2, 2, 1)
+            self.pool = nn.MaxPool2d(1, 1)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.pool(x)
+            return torch.pow(x, 2)
+
+
     class SerialsConv2dAddReLUModule(torch.nn.Module):
         """ Serials of 2 Conv2d -> Add -> ReLU Pattern.
         """
@@ -171,10 +184,13 @@ class X86InductorQuantTestCase(QuantizationTestCase):
             *copy.deepcopy(example_inputs),
             aten_graph=True,
         )
+        export_model = copy.deepcopy(m)
         m = prepare_pt2e(m, quantizer)
         # Calibrate
         m(*example_inputs)
+        prepare_model = copy.deepcopy(m)
         m = convert_pt2e(m)
+        convert_model = copy.deepcopy(m)
         pt2_quant_output = m(*example_inputs)
         node_occurrence = {
             ns.call_function(k): v for k, v in expected_node_occurrence.items()
@@ -185,6 +201,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
         self.checkGraphModuleNodes(
             m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
         )
+        return export_model, prepare_model, convert_model

 @skipIfNoDynamoSupport
 class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@@ -400,3 +417,65 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             node_occurrence,
             node_list,
         )
+
+    @skipIfNoX86
+    def test_maxpool2d_recipe(self):
+        r"""
+        Test pattern: int8_in_int8_out op (maxpool2d) -> non-quantizable op (pow).
+        Since maxpool2d is an int8_in_int8_out op, an observer is inserted between maxpool2d and pow.
+ """ + m = TestHelperModules.Conv2dMaxpoolPowModule().eval() + x = torch.rand(1, 2, 14, 14) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + # one for input and weight of the conv, two for input/output for the maxpool2d + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.convolution.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + # Check Maxpool2d has share observer at input and output + for node in prepare_model.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.aten.max_pool2d_with_indices.default + ): + maxpool_node = node + input_obs_of_maxpool = getattr( + prepare_model, maxpool_node.args[0].target + ) + elif node.op == "call_function" and node.target is operator.getitem: + output_obs_of_maxpool = getattr( + prepare_model, list(node.users)[0].target + ) + elif ( + node.op == "call_function" + and node.target is torch.ops.aten.convolution.default + ): + conv_node = node + input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) + self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase)) + self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) + self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool) + self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 8195ce2ec13..9f03af9899d 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -2,7 +2,8 @@ import copy import functools import itertools import operator -from typing import Any, Dict, List, Optional, Set +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple import torch import torch.nn.functional as F @@ -13,6 +14,12 @@ from torch.ao.quantization.observer import ( ) from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _is_annotated, get_bias_qspec, @@ -28,7 +35,6 @@ from torch.fx.passes.utils.source_matcher_utils import ( get_source_partitions, SourcePartition, ) -from .quantizer import QuantizationAnnotation, QuantizationSpec, Quantizer __all__ = [ "X86InductorQuantizer", @@ -36,6 +42,70 @@ __all__ = [ ] +@dataclass +class _X86InductorQuantizationAnnotation(QuantizationAnnotation): 
+ # _is_output_of_quantized_pattern: + # * Node as output node of a fusion pattern. + # * The fusion pattern supports int8 data type. + # * The fusion pattern has inputs annotated to insert observer. + _is_output_of_quantized_pattern: bool = False + + +# Ops support int8 data type and excludes ops like conv, linear. +quantizable_ops_pt2e: Set = { + torch.ops.aten.max_pool2d_with_indices.default, +} + + +# Ops that: +# 1. Ops prefer to run with int8 when int8 input is given. +# 2. Ops don't support int8 in and fp32 out. +int8_in_int8_out_ops_pt2e: Set = { + torch.ops.aten.max_pool2d_with_indices.default, +} + + +def _is_node_annotated(_node): + """ + return True if the node is annotated, otherwise return False + """ + return ( + "quantization_annotation" in _node.meta + and _node.meta["quantization_annotation"]._annotated + ) + + +def _is_any_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False. + """ + return any(_is_node_annotated(node) for node in nodes) + + +def _is_all_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if all of the node is annotated, otherwise return False. + """ + return all(_is_node_annotated(node) for node in nodes) + + +def _is_quantized_op_pt2e(node: torch.fx.Node): + """ + Used for pt2e flow to check if the node is a quantized node: + Case1: the node has been annotated as output node of a fusion pattern. + Case2: the node has been annotated as single quantized node. + """ + if not _is_any_annotated([node]): + # The node has not been annotated, directly return False + return False + quantization_annotation = node.meta.get("quantization_annotation", None) + assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) + return quantization_annotation._is_output_of_quantized_pattern + + def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]: # TODO: Add more supported operators here. supported_operators: Dict[str, List[OperatorPatternType]] = { @@ -188,15 +258,21 @@ class X86InductorQuantizer(Quantizer): if isinstance(bias_node, Node): input_qspec_map[bias_node] = get_bias_qspec(quantization_config) if annotate_output: - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + conv_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( input_qspec_map=input_qspec_map, # TODO Remove the annotate of output when oneDNN qconv support fp32 out. 
output_qspec=get_output_act_qspec(quantization_config), _annotated=True, + _is_output_of_quantized_pattern=True, ) else: - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, _annotated=True + conv_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, ) def _get_output_nodes_of_partitions( @@ -249,16 +325,45 @@ class X86InductorQuantizer(Quantizer): def _annotate_for_static_quantization_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - # annotate the nodes from last to first since the matching is in the reversed order - # and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv) - # and we will mark the matched node with "_annoated" so fusion operator pattern - # can take precedence over single operator pattern in this way + r""" + High-level description of quantization recipe for X86 Inductor Backend: + Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. + Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model + from start to the end. If a pattern supports computation with int8 data type and inputs connected to + quantized patterns, annotate its inputs as quantized pattern. + Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns, + such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, + we need to annotate the output of this pattern. + """ + config = self.global_config + + # Step1: Recipe of fusion patterns like conv/linear. + self._annotate_conv2d_fusion_pattern(model, config) + + # Step2: Recipe to propagate annotation for patterns beside conv/linear. + # Go through all the nodes from start to end. + # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 + for node in model.graph.nodes: + self._annotation_propagation_quantizable_pattern(node, config) + + # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized + # in inputs. So, we can fuse dq-operator-q into a quantized op. 
+ # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node, config) + + return model + + def _annotate_conv2d_fusion_pattern( + self, model: torch.fx.GraphModule, config: QuantizationConfig + ): self._annotate_conv2d_binary_unary(model, config) self._annotate_conv2d_binary(model, config) self._annotate_conv2d_unary(model, config) self._annotate_conv2d(model, config) - return model def _annotate_conv2d_binary_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig @@ -293,13 +398,19 @@ class X86InductorQuantizer(Quantizer): binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, _annotated=True + binary_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, ) - unary_node.meta["quantization_annotation"] = QuantizationAnnotation( + unary_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( # TODO Remove the annotate of output when oneDNN qconv support fp32 out. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d_binary( @@ -336,11 +447,14 @@ class X86InductorQuantizer(Quantizer): binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta["quantization_annotation"] = QuantizationAnnotation( + binary_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( input_qspec_map=binary_node_input_qspec_map, # TODO Remove the annotate of output when oneDNN qconv support fp32 out. output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d_unary( @@ -362,10 +476,13 @@ class X86InductorQuantizer(Quantizer): if _is_annotated([unary_node, conv_node]): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - unary_node.meta["quantization_annotation"] = QuantizationAnnotation( + unary_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( # TODO Remove the annotate of output when oneDNN qconv support fp32 out. 
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] _annotated=True, + _is_output_of_quantized_pattern=True, ) def _annotate_conv2d( @@ -389,6 +506,104 @@ class X86InductorQuantizer(Quantizer): continue self._annotate_conv_node_helper(conv_node, True, quantization_config) + def _annotate_maxpool2d( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + if node.target is not torch.ops.aten.max_pool2d_with_indices.default or not ( + len(list(node.users)) == 1 + and (list(node.users)[0].target == operator.getitem) + ): + return + maxpool_node = node + getitem_node = list(node.users)[0] + if _is_any_annotated([getitem_node, maxpool_node]): + return + input_node = maxpool_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + maxpool_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + getitem_node.meta[ + "quantization_annotation" + ] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotation_propagation_quantizable_pattern( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + # Propagate annotation to quantizable patterns. + if ( + (node.target in quantizable_ops_pt2e) + and (not _is_any_annotated([node])) + and (node.op == "call_function") + ): + + def is_all_inputs_connected_to_quantized_op(input_nodes): + # Ensure all the inputs connect to fusion pattern or quantized node + for input_node in input_nodes: + if not _is_quantized_op_pt2e(input_node): + return False + return True + + if node.target is torch.ops.aten.max_pool2d_with_indices.default: + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + return + self._annotate_maxpool2d(node, quantization_config) + return + else: + # TODO : Enable recipes for more single quantizable op such as view and relu. + pass + return + + def _annotate_output_for_int8_in_int8_out_pattern( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + r""" + Check and insert observer at output of node in int8_in_int8_out_ops_pt2e if needed. 
+ Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ + 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 + """ + if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])): + if node.target == torch.ops.aten.max_pool2d_with_indices.default: + maxpool_node = node + assert len(list(maxpool_node.users)) == 1 and ( + list(maxpool_node.users)[0].target == operator.getitem + ) + getitem_node = list(node.users)[0] + if not _is_all_annotated([getitem_node, maxpool_node]): + return + # Get the quantization_annotation from getitem_node + getitem_quantization_annotation = ( + getitem_node.meta["quantization_annotation"] + if "quantization_annotation" in getitem_node.meta + else None + ) + if ( + getitem_quantization_annotation + and getitem_quantization_annotation._is_output_of_quantized_pattern + ): + # Annotate the output_qspec of getitem_node + input_act = maxpool_node.args[0] + assert isinstance(input_act, Node) + assert isinstance(maxpool_node, Node) + edge_or_node: Tuple[Node, Node] = (input_act, maxpool_node) + getitem_node.meta[ + "quantization_annotation" + ].output_qspec = SharedQuantizationSpec(edge_or_node) + else: + # TODO : Enable recipes for more int8_in_int8_out_ops + pass + return + def validate(self, model: torch.fx.GraphModule) -> None: pass
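
**Example usage**

The snippet below is a minimal, self-contained sketch of how the new recipe is exercised end to end, mirroring the flow of `X86InductorQuantTestCase._test_quantizer` and `test_maxpool2d_recipe` above. The `ConvMaxpool` module, the tensor shape, and the variable names are illustrative only and are not part of this patch.

```
import copy

import torch
import torch._dynamo as torchdynamo
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ConvMaxpool(torch.nn.Module):
    # Illustrative conv -> maxpool -> pow model, analogous to Conv2dMaxpoolPowModule in the test.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 2, 1)
        self.pool = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        return torch.pow(self.pool(self.conv(x)), 2)


example_inputs = (torch.rand(1, 2, 14, 14),)
m = ConvMaxpool().eval()

# Capture an ATen-level graph, as the test helper does.
m, _ = torchdynamo.export(m, *copy.deepcopy(example_inputs), aten_graph=True)

# The global config enables the conv recipe; the maxpool2d recipe then propagates the
# annotation and shares a single observer between the maxpool input and output.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate
m = convert_pt2e(m)
m(*example_inputs)
```

After `prepare_pt2e`, the observer recorded for the maxpool input is the same object attached to its `getitem` output (via `SharedQuantizationSpec`), which is what `test_maxpool2d_recipe` asserts; after `convert_pt2e`, the graph contains the three quantize/dequantize pairs counted in that test.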