diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py
new file mode 100644
index 00000000000..900e60733f4
--- /dev/null
+++ b/test/quantization/pt2e/test_metadata_porting.py
@@ -0,0 +1,367 @@
+# Owner(s): ["oncall: quantization"]
+import copy
+
+import unittest
+from typing import List
+
+import torch
+import torch._export as export
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import Quantizer
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
+
+from torch.fx import Node
+
+from torch.testing._internal.common_quantization import QuantizationTestCase
+
+
+class TestHelperModules:
+    class Conv2dWithObsSharingOps(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.hardtanh = torch.nn.Hardtanh()
+            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.adaptive_avg_pool2d(x)
+            x = self.hardtanh(x)
+            x = x.view(-1, 3)
+            x = self.linear(x)
+            return x
+
+
+def _tag_partitions(
+    backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
+):
+    for index, partition_nodes in enumerate(annotated_partitions):
+        tag_name = backend_name + "_" + op_name + "_" + str(index)
+        for node in partition_nodes:
+            assert "quantization_tag" not in node.meta, f"{node} is already tagged"
+            node.meta["quantization_tag"] = tag_name
+
+
+_QUANT_OPS = {
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+    torch.ops.quantized_decomposed.choose_qparams.tensor,
+}
+
+
+class TestMetaDataPorting(QuantizationTestCase):
+    def _test_metadata_porting(
+        self,
+        model,
+        example_inputs,
+        quantizer,
+        node_tags=None,
+    ):
+        m_eager = model.eval()
+
+        # program capture
+        m = copy.deepcopy(m_eager)
+        m = export.capture_pre_autograd_graph(
+            m,
+            example_inputs,
+        )
+
+        m = prepare_pt2e(m, quantizer)
+        # Calibrate
+        m(*example_inputs)
+        m = convert_pt2e(m)
+
+        pt2_quant_output = m(*example_inputs)
+        recorded_node_tags = {}
+        for n in m.graph.nodes:
+            if (
+                n.op == "call_function"
+                and n.target in _QUANT_OPS
+                and "quantization_tag" in n.meta
+            ):
+                if n.target not in recorded_node_tags:
+                    recorded_node_tags[n.target] = set()
+                if n.meta["quantization_tag"] in recorded_node_tags[n.target]:
+                    raise ValueError(
+                        f"{n} has tag {n.meta['quantization_tag']} that is associated with another node of the same type"
+                    )
+                recorded_node_tags[n.target].add(n.meta["quantization_tag"])
+        self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
+        for k, v in recorded_node_tags.items():
+            self.assertEqual(v, node_tags[k])
+
+    def test_simple_metadata_porting(self):
+        """
+        Model under test
+        conv2d -> avgpool -> hardtanh -> linear
+        Check that quantization tags on conv2d, avgpool and linear are correctly set
+        """
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                backend_string = "BackendA"
+                quantization_config = get_symmetric_quantization_config(
+                    is_per_channel=True
+                )
+                annotated_partitions = OP_TO_ANNOTATOR["linear"](
+                    gm, quantization_config
+                )
+                _tag_partitions(backend_string, "linear", annotated_partitions)
+                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                    gm, quantization_config
+                )
+                _tag_partitions(backend_string, "conv2d", annotated_partitions)
+                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
+                    gm, quantization_config
+                )
+                _tag_partitions(
+                    backend_string, "adaptive_avg_pool2d", annotated_partitions
+                )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        quantize_per_tensor_tags = {
+            "BackendA_conv2d_0",
+            "BackendA_adaptive_avg_pool2d_0",
+            "BackendA_linear_0",
+        }
+        dequantize_per_tensor_tags = {
+            "BackendA_adaptive_avg_pool2d_0",
+            "BackendA_conv2d_0",
+            "BackendA_linear_0",
+        }
+        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
+        node_tags = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
+        }
+        self._test_metadata_porting(
+            TestHelperModules.Conv2dWithObsSharingOps(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_tags,
+        )
+
+    def test_metadata_porting_with_no_quant_inbetween(self):
+        """
+        Model under test
+        conv2d -> avgpool -> hardtanh -> linear
+        Don't quantize avgpool
+        Check that quantization tags on conv2d and linear are correctly set
+        """
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                backend_string = "BackendA"
+                quantization_config = get_symmetric_quantization_config(
+                    is_per_channel=True
+                )
+                annotated_partitions = OP_TO_ANNOTATOR["linear"](
+                    gm, quantization_config
+                )
+                _tag_partitions(backend_string, "linear", annotated_partitions)
+                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                    gm, quantization_config
+                )
+                _tag_partitions(backend_string, "conv2d", annotated_partitions)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
+        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
+        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
+        node_tags = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
+        }
+        self._test_metadata_porting(
+            TestHelperModules.Conv2dWithObsSharingOps(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_tags,
+        )
+
+    @unittest.skip("Temporarily disabled")
+    def test_metadata_porting_for_dq(self):
+        """
+        Model under test
+        conv2d -> avgpool -> hardtanh -> linear
+        Statically quantize all except linear.
+        Quantize linear with dynamic quantization
+        Check that quantization tags on conv2d, avgpool and linear are correctly set
+        """
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                backend_string = "BackendA"
+                # static quantization
+                quantization_config = get_symmetric_quantization_config(
+                    is_per_channel=True
+                )
+                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                    gm, quantization_config
+                )
+                _tag_partitions(backend_string, "conv2d", annotated_partitions)
+                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
+                    gm, quantization_config
+                )
+                _tag_partitions(
+                    backend_string, "adaptive_avg_pool2d", annotated_partitions
+                )
+
+                # dynamic quantization
+                quantization_config_dynamic = get_symmetric_quantization_config(
+                    is_per_channel=True, is_dynamic=True
+                )
+                annotated_partitions = OP_TO_ANNOTATOR["linear"](
+                    gm, quantization_config_dynamic
+                )
+                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        quantize_per_tensor_tags = {
+            "BackendA_conv2d_0",
+            "BackendA_adaptive_avg_pool2d_0",
+        }
+        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
+        choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
+        dequantize_per_tensor_tags = {
+            "BackendA_adaptive_avg_pool2d_0",
+            "BackendA_conv2d_0",
+        }
+        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
+        dequantize_per_channel_tags = {
+            "BackendA_conv2d_0",
+            "BackendA_linear_dynamic_0",
+        }
+        node_tags = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
+            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
+        }
+        self._test_metadata_porting(
+            TestHelperModules.Conv2dWithObsSharingOps(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_tags,
+        )
+
+    def test_metadata_porting_for_two_dq(self):
+        """
+        Model under test
+        conv2d -> avgpool -> hardtanh -> linear
+        Quantize linear and conv2d with dynamic quantization
+        Check that quantization tags on conv2d and linear are correctly set
+        """
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                backend_string = "BackendA"
+
+                # dynamic quantization
+                quantization_config_dynamic = get_symmetric_quantization_config(
+                    is_per_channel=True, is_dynamic=True
+                )
+                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                    gm, quantization_config_dynamic
+                )
+                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
+                annotated_partitions = OP_TO_ANNOTATOR["linear"](
+                    gm, quantization_config_dynamic
+                )
+                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        choose_qparams_tensor_tags = {
+            "BackendA_conv2d_dynamic_0",
+            "BackendA_linear_dynamic_0",
+        }
+        quantize_per_tensor_tensor_tags = {
+            "BackendA_conv2d_dynamic_0",
+            "BackendA_linear_dynamic_0",
+        }
+        dequantize_per_tensor_tensor_tags = {
+            "BackendA_conv2d_dynamic_0",
+            "BackendA_linear_dynamic_0",
+        }
+        dequantize_per_channel_tags = {
+            "BackendA_conv2d_dynamic_0",
+            "BackendA_linear_dynamic_0",
+        }
+        node_tags = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
+            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
+        }
+        self._test_metadata_porting(
+            TestHelperModules.Conv2dWithObsSharingOps(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_tags,
+        )
+
+    def test_metadata_porting_for_dq_no_static_q(self):
+        """
+        Model under test
+        conv2d -> avgpool -> hardtanh -> linear
+        Don't quantize anything except linear.
+        Quantize linear with dynamic quantization
+        Check that quantization tags on linear are correctly set
+        """
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                backend_string = "BackendA"
+                # dynamic quantization
+                quantization_config_dynamic = get_symmetric_quantization_config(
+                    is_per_channel=True, is_dynamic=True
+                )
+                annotated_partitions = OP_TO_ANNOTATOR["linear"](
+                    gm, quantization_config_dynamic
+                )
+                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
+        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
+        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
+        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
+        node_tags = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
+            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
+        }
+        self._test_metadata_porting(
+            TestHelperModules.Conv2dWithObsSharingOps(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_tags,
+        )
diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py
new file mode 100644
index 00000000000..b0fe72f6787
--- /dev/null
+++ b/torch/ao/quantization/pt2e/port_metadata_pass.py
@@ -0,0 +1,181 @@
+import logging
+from typing import Optional
+
+import torch
+from torch._export.error import InternalError
+from torch._export.pass_base import _ExportPassBase
+
+from torch.ao.quantization.pt2e.utils import (
+    _filter_sym_size_users,
+    _find_q_dq_node_for_user,
+    _is_valid_annotation,
+)
+
+from torch.ao.quantization.quantizer import QuantizationSpecBase
+
+from torch.fx.passes.infra.pass_base import PassResult
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+__all__ = ["PortNodeMetaForQDQ"]
+
+_METADATA_TO_PORT = [
+    "nn_module_stack",
+    "stack_trace",
+    "quantization_tag",
+]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+]
+
+
+def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
+    from_meta = from_node.meta
+    for meta_name in _METADATA_TO_PORT:
+        if meta_name in from_meta:
+            to_node.meta[meta_name] = from_meta[meta_name]
+
+
+def _has_quant_annotation(node: torch.fx.Node) -> bool:
+    return "quantization_annotation" in node.meta
+
+
+def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
+    # BFS to look for a choose_qparams node among the users
+    from collections import deque
+
+    queue = deque(node.users.keys())
+    while queue:
+        n = queue.popleft()
+        if n.op == "output":
+            continue
+        if (
+            n.op == "call_function"
+            and n.target == torch.ops.quantized_decomposed.choose_qparams.tensor
+        ):
+            return n
+        for k in n.users.keys():
+            queue.append(k)
+    return None
+
+
+def _port_metadata_for_input_quant_nodes(
+    input_node: torch.fx.Node,
+    node: torch.fx.Node,
+    qspec: Optional[QuantizationSpecBase],
+):
+    if qspec is None:
+        return
+
+    is_dynamic_quant = getattr(qspec, "is_dynamic", None)
+    if is_dynamic_quant is not None and is_dynamic_quant is True:
+        choose_qparams_node = _find_choose_qparams_node(input_node)
+        if choose_qparams_node is None:
+            raise ValueError(f"No choose_qparams node found for {node}")
+        choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
+        if len(choose_qparam_users) != 2:
+            raise InternalError(f"Expecting exactly two users for {choose_qparams_node}")
+        scale_node = choose_qparam_users.pop()
+        dynamic_q_node = next(iter(scale_node.users.keys()))
+        dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
+        if len(dynamic_q_node_users) > 1:
+            raise InternalError(f"Expecting a single user for {dynamic_q_node}")
+        dynamic_dq_node = dynamic_q_node_users.pop()
+        _add_metadata(choose_qparams_node, node)
+        _add_metadata(dynamic_q_node, node)
+        _add_metadata(dynamic_dq_node, node)
+    else:
+        q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
+        if q_node is None or dq_node is None:
+            return
+        _add_metadata(dq_node, node)
+
+
+def _port_metadata_for_output_quant_nodes(
+    node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
+):
+    if qspec is None:
+        return
+
+    node_users = _filter_sym_size_users(node)
+    if len(node_users) != 1:
+        raise InternalError(f"Expecting {node} to have a single user")
+    q_node = node_users.pop()
+    if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
+        logger.warning(
+            f"Expecting {node} user to be a quantized op but got {q_node}"  # noqa: G004
+        )
+        return
+
+    _add_metadata(q_node, node)
+
+
+class PortNodeMetaForQDQ(_ExportPassBase):
+    """
+    Port metadata for nodes added by the quantization flow.
+    For static quant these are:
+    - quantize_per_tensor.default, dequantize_per_tensor.default
+    - quantize_per_channel.default, dequantize_per_channel.default
+    For dynamic quant these are:
+    - choose_qparams.tensor
+    - quantize_per_tensor.tensor, dequantize_per_tensor.tensor
+    - quantize_per_channel.default, dequantize_per_channel.default
+
+    Rules of porting metadata:
+    - Metadata to be ported:
+      - nn_module_stack
+      - stack_trace
+      - quantization_tag
+    - Metadata to NOT be ported:
+      - Everything else
+    - Rules:
+      - Statically quantized patterns:
+        - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
+        - Quantize nodes on the outputs inherit metadata of the producer node.
+        - Example 1:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized: [Q -> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes Q/DQ inherit metadata from:
+          - [Q -> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
+          - Note that the first Q and the last DQ do not inherit metadata from any node.
+        - Example 2:
+          - Original: [Conv -> AvgPool -> Linear]
+          - AvgPool is not quantized
+          - Quantized: [Q -> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes Q/DQ inherit metadata from:
+          - [Q -> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
+          - Note that the DQ and Q nodes around AvgPool do not inherit metadata from AvgPool,
+            because AvgPool was not supposed to be quantized. Metadata porting relies on the
+            quantization_annotation on the nodes (in this case the AvgPool node) to conclude
+            whether the node or pattern was supposed to be quantized, and subsequently to decide
+            whether the preceding Q, if any, should inherit metadata from AvgPool.
+      - Dynamically quantized patterns:
+        - Inputs that are dynamically quantized have choose_qparams, quantize and dequantize nodes.
+        - For example, below linear is dynamically quantized while the rest are statically quantized:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized: [Q -> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_qparams -> Q -> DQ -> Linear]
+          - Quantized: [Q -> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_qparams -> Q -> DQ -> Linear]]
+          - Note that the first Q does not inherit metadata from any node.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            annotation = node.meta.get("quantization_annotation", None)
+            if _is_valid_annotation(annotation):
+                input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
+                output_qspec = node.meta["quantization_annotation"].output_qspec
+                for input_node, qspec in input_qspec_map.items():
+                    _port_metadata_for_input_quant_nodes(input_node, node, qspec)
+                _port_metadata_for_output_quant_nodes(node, output_qspec)
+        return PassResult(graph_module, True)
diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py
index 7da89f483ce..610bcf51039 100644
--- a/torch/ao/quantization/quantize_pt2e.py
+++ b/torch/ao/quantization/quantize_pt2e.py
@@ -28,6 +28,7 @@ from typing import Any, Tuple
 from torch.fx.passes.infra.pass_manager import PassManager
 
 from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
+from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
 
 __all__ = [
     "prepare_pt2e",
@@ -99,6 +100,8 @@ def convert_pt2e(
 
     pm = PassManager([DuplicateDQPass()])
     model = pm(model).graph_module
 
+    pm = PassManager([PortNodeMetaForQDQ()])
+    model = pm(model).graph_module
     if use_reference_representation:
         model = reference_representation_rewrite(model)
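With PortNodeMetaForQDQ wired into convert_pt2e as above, Q/DQ nodes come out of the standard flow already carrying the ported metadata. A minimal end-to-end sketch, using the same capture/prepare/convert APIs exercised by the tests; XNNPACKQuantizer here stands in for a backend quantizer and does not set quantization_tag, so only the nn_module_stack/stack_trace porting is visible:

import torch
import torch._export as export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3)).eval()
example_inputs = (torch.randn(1, 3, 5, 5),)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = export.capture_pre_autograd_graph(model, example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate observers
m = convert_pt2e(m)  # DuplicateDQPass and PortNodeMetaForQDQ both run in here

for n in m.graph.nodes:
    if n.op == "call_function" and "quantized_decomposed" in str(n.target):
        # Q/DQ nodes now carry metadata ported from their compute neighbors.
        print(n.target, n.meta.get("nn_module_stack"))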