[Quantization] Add metadata porting for nodes added by quantization (#107107)

Summary:
This diff adds adding metadata to q-dq nodes by inferring the
quatization intent from node annotations. Annotations on the node are
way for user to specify how a node or subgraph is supposed to be
quantized. We continue to use that information to copy metadata on Q/DQ
node from appropriate nodes.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D48488416](https://our.internmc.facebook.com/intern/diff/D48488416)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107107
Approved by: https://github.com/jerryzh168
ghstack dependencies: #107105, #107106, #107899, #107900
This commit is contained in:
Kimish Patel 2023-09-01 13:19:18 -07:00 committed by PyTorch MergeBot
parent d6a9c2b4b5
commit ffc0c46092
3 changed files with 551 additions and 0 deletions

View File

@ -0,0 +1,367 @@
# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import List
import torch
import torch._export as export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
class TestHelperModules:
class Conv2dWithObsSharingOps(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.hardtanh = torch.nn.Hardtanh()
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
x = self.conv(x)
x = self.adaptive_avg_pool2d(x)
x = self.hardtanh(x)
x = x.view(-1, 3)
x = self.linear(x)
return x
def _tag_partitions(
backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
):
for index, partition_nodes in enumerate(annotated_partitions):
tag_name = backend_name + "_" + op_name + "_" + str(index)
for node in partition_nodes:
assert "quantization_tag" not in node.meta, f"{node} is already tagged"
node.meta["quantization_tag"] = tag_name
_QUANT_OPS = {
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.choose_qparams.tensor,
}
class TestMetaDataPorting(QuantizationTestCase):
def _test_metadata_porting(
self,
model,
example_inputs,
quantizer,
node_tags=None,
):
m_eager = model.eval()
# program capture
m = copy.deepcopy(m_eager)
m = export.capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)
pt2_quant_output = m(*example_inputs)
recorded_node_tags = {}
for n in m.graph.nodes:
if (
n.op == "call_function"
and n.target in _QUANT_OPS
and "quantization_tag" in n.meta
):
if n.target not in recorded_node_tags:
recorded_node_tags[n.target] = set()
if n.meta["quantization_tag"] in recorded_node_tags[n.target]:
raise ValueError(
f"{n} has tag {n.meta['quantization_tag']} that is associated with another node of the same type"
)
recorded_node_tags[n.target].add(n.meta["quantization_tag"])
self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
for k, v in recorded_node_tags.items():
self.assertEqual(v, node_tags[k])
def test_simple_metadata_porting(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config
)
_tag_partitions(
backend_string, "adaptive_avg_pool2d", annotated_partitions
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
"BackendA_linear_0",
}
dequantize_per_tensor_tags = {
"BackendA_adaptive_avg_pool2d_0",
"BackendA_conv2d_0",
"BackendA_linear_0",
}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
)
def test_metadata_porting_with_no_quant_inbetween(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Dont quantize avgpool
Check quantization tags on conv2d and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
)
@unittest.skip("Temporarily disabled")
def test_metadata_porting_for_dq(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Quantize all except linear.
Quantize linear with dynamic quantization
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
# static quantiazation
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config
)
_tag_partitions(
backend_string, "adaptive_avg_pool2d", annotated_partitions
)
# dynamic quantization
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
}
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_tensor_tags = {
"BackendA_adaptive_avg_pool2d_0",
"BackendA_conv2d_0",
}
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_channel_tags = {
"BackendA_conv2d_0",
"BackendA_linear_dynamic_0",
}
node_tags = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
)
def test_metadata_porting_for_two_dq(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Quantize linear and conv with dynamic quantization
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
# dynamic quantization
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
choose_qparams_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
quantize_per_tensor_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
dequantize_per_tensor_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
dequantize_per_channel_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
node_tags = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
)
def test_metadata_porting_for_dq_no_static_q(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Dont quantize anything except linear.
Quantize linear with dynamic quantization
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
# dynamic quantization
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
node_tags = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
)

View File

@ -0,0 +1,181 @@
import logging
from typing import Optional
import torch
from torch._export.error import InternalError
from torch._export.pass_base import _ExportPassBase
from torch.ao.quantization.pt2e.utils import (
_filter_sym_size_users,
_find_q_dq_node_for_user,
_is_valid_annotation,
)
from torch.ao.quantization.quantizer import QuantizationSpecBase
from torch.fx.passes.infra.pass_base import PassResult
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
__all__ = ["PortNodeMetaForQDQ"]
_METADATA_TO_PORT = [
"nn_module_stack",
"stack_trace",
"quantization_tag",
]
_QUANTIZE_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
from_meta = from_node.meta
for meta_name in _METADATA_TO_PORT:
if meta_name in from_meta:
to_node.meta[meta_name] = from_meta[meta_name]
def _has_quant_annotation(node: torch.fx.Node) -> bool:
return "quantization_annotation" in node.meta
def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
# BFS to look for choose qparams
from collections import deque
queue = deque(list(node.users.keys()))
while len(queue):
n = queue.popleft()
if n.op == "output":
continue
if (
n.op == "call_function"
and n.target == torch.ops.quantized_decomposed.choose_qparams.tensor
):
return n
for k in n.users.keys():
queue.append(k)
return None
def _port_metadata_for_input_quant_nodes(
input_node: torch.fx.Node,
node: torch.fx.Node,
qspec: Optional[QuantizationSpecBase],
):
if qspec is None:
return
is_dynamic_quant = getattr(qspec, "is_dynamic", None)
if is_dynamic_quant is not None and is_dynamic_quant is True:
choose_qparams_node = _find_choose_qparams_node(input_node)
if choose_qparams_node is None:
raise ValueError(f"No chose qparams node found for {node}")
choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
if len(choose_qparam_users) != 2:
raise InternalError(f"Expecting exactly two user for {choose_qparams_node}")
scale_node = choose_qparam_users.pop()
dynamic_q_node = list(scale_node.users.keys())[0]
dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
if len(dynamic_q_node_users) > 1:
raise InternalError(f"Expecting single user for {dynamic_q_node}")
dynamic_dq_node = dynamic_q_node_users.pop()
_add_metadata(choose_qparams_node, node)
_add_metadata(dynamic_q_node, node)
_add_metadata(dynamic_dq_node, node)
else:
q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
if q_node is None or dq_node is None:
return
_add_metadata(dq_node, node)
def _port_metadata_for_output_quant_nodes(
node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
):
if qspec is None:
return
node_users = _filter_sym_size_users(node)
if len(node_users) != 1:
raise InternalError(f"Expecting {node} to have single user")
q_node = node_users.pop()
if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
logger.warning(
f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004
) # noqa: G004
return
_add_metadata(q_node, node)
class PortNodeMetaForQDQ(_ExportPassBase):
"""
Port metadata for nodes added by quantization flow.
For static quant these are:
- quantizer_per_tensor.default, dequantize_per_tensor.default
- quantizer_per_channel.default, dequantize_per_channel.default
For dynamic quant these are:
- choose_qparams.tensor
- quantizer_per_tensor.tensor, dequantize_per_tensor.tensor
- quantizer_per_channel.default, dequantize_per_channel.default
Rules of porting metadata:
- Metadata to be ported:
- nn_module_stack
- stack_trace
- quantization_tag
- Metadata to NOT be ported:
- Everything else
- Rules:
- Statically quantized patterns:
- Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
- Quantize nodes on the outputs inherit metadata of the producer node.
- Example 1:
- Original: [Conv -> AvgPool -> Linear]
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
- Inner brackets specify which nodes Q/DQ inherit metdata from
- [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
- Note first Q and last DQ do not inherit metadata from any nodes
- Example 2:
- Original: [Conv -> AvgPool -> Linear]
- AvgPool is not quantized
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
- Inner brackets specify which nodes Q/DQ inherit metdata from
- [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
- Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation
on the nodes (in this case AvgPool node) to conclude if the the node or patter was
supposed to be quantized. And subsequntly decide if the preceding Q, if any, should
inherit metadata from AvgPool.
- Dynamically quantized patterns:
- Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes
- For example, below linear is dynamically quantized while rest statically:
- Original: [Conv -> AvgPool -> Linear]
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
- Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
- Note first Q does not inherit metadata from any nodes
"""
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
annotation = node.meta.get("quantization_annotation", None)
if _is_valid_annotation(annotation):
input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
output_qspec = node.meta["quantization_annotation"].output_qspec
for input_node, qspec in input_qspec_map.items():
_port_metadata_for_input_quant_nodes(input_node, node, qspec)
_port_metadata_for_output_quant_nodes(node, output_qspec)
return PassResult(graph_module, True)

View File

@ -28,6 +28,7 @@ from typing import Any, Tuple
from torch.fx.passes.infra.pass_manager import PassManager from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
__all__ = [ __all__ = [
"prepare_pt2e", "prepare_pt2e",
@ -99,6 +100,8 @@ def convert_pt2e(
pm = PassManager([DuplicateDQPass()]) pm = PassManager([DuplicateDQPass()])
model = pm(model).graph_module model = pm(model).graph_module
pm = PassManager([PortNodeMetaForQDQ()])
model = pm(model).graph_module
if use_reference_representation: if use_reference_representation:
model = reference_representation_rewrite(model) model = reference_representation_rewrite(model)