mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Quantization] Add metadata porting for nodes added by quantization (#107107)
Summary: This diff adds adding metadata to q-dq nodes by inferring the quatization intent from node annotations. Annotations on the node are way for user to specify how a node or subgraph is supposed to be quantized. We continue to use that information to copy metadata on Q/DQ node from appropriate nodes. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D48488416](https://our.internmc.facebook.com/intern/diff/D48488416) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107107 Approved by: https://github.com/jerryzh168 ghstack dependencies: #107105, #107106, #107899, #107900
This commit is contained in:
parent
d6a9c2b4b5
commit
ffc0c46092
367
test/quantization/pt2e/test_metadata_porting.py
Normal file
367
test/quantization/pt2e/test_metadata_porting.py
Normal file
|
|
@ -0,0 +1,367 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
import copy
|
||||
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._export as export
|
||||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
||||
from torch.ao.quantization.quantizer import Quantizer
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||
get_symmetric_quantization_config,
|
||||
)
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
|
||||
|
||||
from torch.fx import Node
|
||||
|
||||
from torch.testing._internal.common_quantization import QuantizationTestCase
|
||||
|
||||
|
||||
class TestHelperModules:
|
||||
class Conv2dWithObsSharingOps(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
self.hardtanh = torch.nn.Hardtanh()
|
||||
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.linear = torch.nn.Linear(3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.adaptive_avg_pool2d(x)
|
||||
x = self.hardtanh(x)
|
||||
x = x.view(-1, 3)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def _tag_partitions(
|
||||
backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
|
||||
):
|
||||
for index, partition_nodes in enumerate(annotated_partitions):
|
||||
tag_name = backend_name + "_" + op_name + "_" + str(index)
|
||||
for node in partition_nodes:
|
||||
assert "quantization_tag" not in node.meta, f"{node} is already tagged"
|
||||
node.meta["quantization_tag"] = tag_name
|
||||
|
||||
|
||||
_QUANT_OPS = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
torch.ops.quantized_decomposed.choose_qparams.tensor,
|
||||
}
|
||||
|
||||
|
||||
class TestMetaDataPorting(QuantizationTestCase):
|
||||
def _test_metadata_porting(
|
||||
self,
|
||||
model,
|
||||
example_inputs,
|
||||
quantizer,
|
||||
node_tags=None,
|
||||
):
|
||||
m_eager = model.eval()
|
||||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = export.capture_pre_autograd_graph(
|
||||
m,
|
||||
example_inputs,
|
||||
)
|
||||
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
m(*example_inputs)
|
||||
m = convert_pt2e(m)
|
||||
|
||||
pt2_quant_output = m(*example_inputs)
|
||||
recorded_node_tags = {}
|
||||
for n in m.graph.nodes:
|
||||
if (
|
||||
n.op == "call_function"
|
||||
and n.target in _QUANT_OPS
|
||||
and "quantization_tag" in n.meta
|
||||
):
|
||||
if n.target not in recorded_node_tags:
|
||||
recorded_node_tags[n.target] = set()
|
||||
if n.meta["quantization_tag"] in recorded_node_tags[n.target]:
|
||||
raise ValueError(
|
||||
f"{n} has tag {n.meta['quantization_tag']} that is associated with another node of the same type"
|
||||
)
|
||||
recorded_node_tags[n.target].add(n.meta["quantization_tag"])
|
||||
self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
|
||||
for k, v in recorded_node_tags.items():
|
||||
self.assertEqual(v, node_tags[k])
|
||||
|
||||
def test_simple_metadata_porting(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> avgpool -> hardtanh -> linear
|
||||
Check quantization tags on conv2d, avgpool and linear are correctly set
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["linear"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(backend_string, "linear", annotated_partitions)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(backend_string, "conv2d", annotated_partitions)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(
|
||||
backend_string, "adaptive_avg_pool2d", annotated_partitions
|
||||
)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
quantize_per_tensor_tags = {
|
||||
"BackendA_conv2d_0",
|
||||
"BackendA_adaptive_avg_pool2d_0",
|
||||
"BackendA_linear_0",
|
||||
}
|
||||
dequantize_per_tensor_tags = {
|
||||
"BackendA_adaptive_avg_pool2d_0",
|
||||
"BackendA_conv2d_0",
|
||||
"BackendA_linear_0",
|
||||
}
|
||||
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
|
||||
node_tags = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
|
||||
}
|
||||
self._test_metadata_porting(
|
||||
TestHelperModules.Conv2dWithObsSharingOps(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
node_tags,
|
||||
)
|
||||
|
||||
def test_metadata_porting_with_no_quant_inbetween(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> avgpool -> hardtanh -> linear
|
||||
Dont quantize avgpool
|
||||
Check quantization tags on conv2d and linear are correctly set
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["linear"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(backend_string, "linear", annotated_partitions)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(backend_string, "conv2d", annotated_partitions)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
|
||||
dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
|
||||
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
|
||||
node_tags = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
|
||||
}
|
||||
self._test_metadata_porting(
|
||||
TestHelperModules.Conv2dWithObsSharingOps(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
node_tags,
|
||||
)
|
||||
|
||||
@unittest.skip("Temporarily disabled")
|
||||
def test_metadata_porting_for_dq(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> avgpool -> hardtanh -> linear
|
||||
Quantize all except linear.
|
||||
Quantize linear with dynamic quantization
|
||||
Check quantization tags on conv2d, avgpool and linear are correctly set
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
# static quantiazation
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(backend_string, "conv2d", annotated_partitions)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
|
||||
gm, quantization_config
|
||||
)
|
||||
_tag_partitions(
|
||||
backend_string, "adaptive_avg_pool2d", annotated_partitions
|
||||
)
|
||||
|
||||
# dynamic quantization
|
||||
quantization_config_dynamic = get_symmetric_quantization_config(
|
||||
is_per_channel=True, is_dynamic=True
|
||||
)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["linear"](
|
||||
gm, quantization_config_dynamic
|
||||
)
|
||||
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
quantize_per_tensor_tags = {
|
||||
"BackendA_conv2d_0",
|
||||
"BackendA_adaptive_avg_pool2d_0",
|
||||
}
|
||||
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
|
||||
choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
|
||||
dequantize_per_tensor_tags = {
|
||||
"BackendA_adaptive_avg_pool2d_0",
|
||||
"BackendA_conv2d_0",
|
||||
}
|
||||
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
|
||||
dequantize_per_channel_tags = {
|
||||
"BackendA_conv2d_0",
|
||||
"BackendA_linear_dynamic_0",
|
||||
}
|
||||
node_tags = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
|
||||
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
|
||||
}
|
||||
self._test_metadata_porting(
|
||||
TestHelperModules.Conv2dWithObsSharingOps(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
node_tags,
|
||||
)
|
||||
|
||||
def test_metadata_porting_for_two_dq(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> avgpool -> hardtanh -> linear
|
||||
Quantize linear and conv with dynamic quantization
|
||||
Check quantization tags on conv2d, avgpool and linear are correctly set
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
|
||||
# dynamic quantization
|
||||
quantization_config_dynamic = get_symmetric_quantization_config(
|
||||
is_per_channel=True, is_dynamic=True
|
||||
)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
|
||||
gm, quantization_config_dynamic
|
||||
)
|
||||
_tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["linear"](
|
||||
gm, quantization_config_dynamic
|
||||
)
|
||||
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
choose_qparams_tensor_tags = {
|
||||
"BackendA_conv2d_dynamic_0",
|
||||
"BackendA_linear_dynamic_0",
|
||||
}
|
||||
quantize_per_tensor_tensor_tags = {
|
||||
"BackendA_conv2d_dynamic_0",
|
||||
"BackendA_linear_dynamic_0",
|
||||
}
|
||||
dequantize_per_tensor_tensor_tags = {
|
||||
"BackendA_conv2d_dynamic_0",
|
||||
"BackendA_linear_dynamic_0",
|
||||
}
|
||||
dequantize_per_channel_tags = {
|
||||
"BackendA_conv2d_dynamic_0",
|
||||
"BackendA_linear_dynamic_0",
|
||||
}
|
||||
node_tags = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
|
||||
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
|
||||
}
|
||||
self._test_metadata_porting(
|
||||
TestHelperModules.Conv2dWithObsSharingOps(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
node_tags,
|
||||
)
|
||||
|
||||
def test_metadata_porting_for_dq_no_static_q(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> avgpool -> hardtanh -> linear
|
||||
Dont quantize anything except linear.
|
||||
Quantize linear with dynamic quantization
|
||||
Check quantization tags on conv2d, avgpool and linear are correctly set
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
# dynamic quantization
|
||||
quantization_config_dynamic = get_symmetric_quantization_config(
|
||||
is_per_channel=True, is_dynamic=True
|
||||
)
|
||||
annotated_partitions = OP_TO_ANNOTATOR["linear"](
|
||||
gm, quantization_config_dynamic
|
||||
)
|
||||
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
|
||||
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
|
||||
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
|
||||
dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
|
||||
node_tags = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
|
||||
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
|
||||
}
|
||||
self._test_metadata_porting(
|
||||
TestHelperModules.Conv2dWithObsSharingOps(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
node_tags,
|
||||
)
|
||||
181
torch/ao/quantization/pt2e/port_metadata_pass.py
Normal file
181
torch/ao/quantization/pt2e/port_metadata_pass.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch._export.error import InternalError
|
||||
from torch._export.pass_base import _ExportPassBase
|
||||
|
||||
from torch.ao.quantization.pt2e.utils import (
|
||||
_filter_sym_size_users,
|
||||
_find_q_dq_node_for_user,
|
||||
_is_valid_annotation,
|
||||
)
|
||||
|
||||
from torch.ao.quantization.quantizer import QuantizationSpecBase
|
||||
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
__all__ = ["PortNodeMetaForQDQ"]
|
||||
|
||||
_METADATA_TO_PORT = [
|
||||
"nn_module_stack",
|
||||
"stack_trace",
|
||||
"quantization_tag",
|
||||
]
|
||||
|
||||
_QUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
]
|
||||
|
||||
_DEQUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
|
||||
from_meta = from_node.meta
|
||||
for meta_name in _METADATA_TO_PORT:
|
||||
if meta_name in from_meta:
|
||||
to_node.meta[meta_name] = from_meta[meta_name]
|
||||
|
||||
|
||||
def _has_quant_annotation(node: torch.fx.Node) -> bool:
|
||||
return "quantization_annotation" in node.meta
|
||||
|
||||
|
||||
def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
|
||||
# BFS to look for choose qparams
|
||||
from collections import deque
|
||||
|
||||
queue = deque(list(node.users.keys()))
|
||||
while len(queue):
|
||||
n = queue.popleft()
|
||||
if n.op == "output":
|
||||
continue
|
||||
if (
|
||||
n.op == "call_function"
|
||||
and n.target == torch.ops.quantized_decomposed.choose_qparams.tensor
|
||||
):
|
||||
return n
|
||||
for k in n.users.keys():
|
||||
queue.append(k)
|
||||
return None
|
||||
|
||||
|
||||
def _port_metadata_for_input_quant_nodes(
|
||||
input_node: torch.fx.Node,
|
||||
node: torch.fx.Node,
|
||||
qspec: Optional[QuantizationSpecBase],
|
||||
):
|
||||
if qspec is None:
|
||||
return
|
||||
|
||||
is_dynamic_quant = getattr(qspec, "is_dynamic", None)
|
||||
if is_dynamic_quant is not None and is_dynamic_quant is True:
|
||||
choose_qparams_node = _find_choose_qparams_node(input_node)
|
||||
if choose_qparams_node is None:
|
||||
raise ValueError(f"No chose qparams node found for {node}")
|
||||
choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
|
||||
if len(choose_qparam_users) != 2:
|
||||
raise InternalError(f"Expecting exactly two user for {choose_qparams_node}")
|
||||
scale_node = choose_qparam_users.pop()
|
||||
dynamic_q_node = list(scale_node.users.keys())[0]
|
||||
dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
|
||||
if len(dynamic_q_node_users) > 1:
|
||||
raise InternalError(f"Expecting single user for {dynamic_q_node}")
|
||||
dynamic_dq_node = dynamic_q_node_users.pop()
|
||||
_add_metadata(choose_qparams_node, node)
|
||||
_add_metadata(dynamic_q_node, node)
|
||||
_add_metadata(dynamic_dq_node, node)
|
||||
else:
|
||||
q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
|
||||
if q_node is None or dq_node is None:
|
||||
return
|
||||
_add_metadata(dq_node, node)
|
||||
|
||||
|
||||
def _port_metadata_for_output_quant_nodes(
|
||||
node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
|
||||
):
|
||||
if qspec is None:
|
||||
return
|
||||
|
||||
node_users = _filter_sym_size_users(node)
|
||||
if len(node_users) != 1:
|
||||
raise InternalError(f"Expecting {node} to have single user")
|
||||
q_node = node_users.pop()
|
||||
if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
|
||||
logger.warning(
|
||||
f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004
|
||||
) # noqa: G004
|
||||
return
|
||||
|
||||
_add_metadata(q_node, node)
|
||||
|
||||
|
||||
class PortNodeMetaForQDQ(_ExportPassBase):
|
||||
"""
|
||||
Port metadata for nodes added by quantization flow.
|
||||
For static quant these are:
|
||||
- quantizer_per_tensor.default, dequantize_per_tensor.default
|
||||
- quantizer_per_channel.default, dequantize_per_channel.default
|
||||
For dynamic quant these are:
|
||||
- choose_qparams.tensor
|
||||
- quantizer_per_tensor.tensor, dequantize_per_tensor.tensor
|
||||
- quantizer_per_channel.default, dequantize_per_channel.default
|
||||
|
||||
Rules of porting metadata:
|
||||
- Metadata to be ported:
|
||||
- nn_module_stack
|
||||
- stack_trace
|
||||
- quantization_tag
|
||||
- Metadata to NOT be ported:
|
||||
- Everything else
|
||||
- Rules:
|
||||
- Statically quantized patterns:
|
||||
- Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
|
||||
- Quantize nodes on the outputs inherit metadata of the producer node.
|
||||
- Example 1:
|
||||
- Original: [Conv -> AvgPool -> Linear]
|
||||
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
|
||||
- Inner brackets specify which nodes Q/DQ inherit metdata from
|
||||
- [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
|
||||
- Note first Q and last DQ do not inherit metadata from any nodes
|
||||
- Example 2:
|
||||
- Original: [Conv -> AvgPool -> Linear]
|
||||
- AvgPool is not quantized
|
||||
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
|
||||
- Inner brackets specify which nodes Q/DQ inherit metdata from
|
||||
- [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
|
||||
- Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
|
||||
AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation
|
||||
on the nodes (in this case AvgPool node) to conclude if the the node or patter was
|
||||
supposed to be quantized. And subsequntly decide if the preceding Q, if any, should
|
||||
inherit metadata from AvgPool.
|
||||
- Dynamically quantized patterns:
|
||||
- Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes
|
||||
- For example, below linear is dynamically quantized while rest statically:
|
||||
- Original: [Conv -> AvgPool -> Linear]
|
||||
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
|
||||
- Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
|
||||
- Note first Q does not inherit metadata from any nodes
|
||||
"""
|
||||
|
||||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
||||
for node in graph_module.graph.nodes:
|
||||
annotation = node.meta.get("quantization_annotation", None)
|
||||
if _is_valid_annotation(annotation):
|
||||
input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
|
||||
output_qspec = node.meta["quantization_annotation"].output_qspec
|
||||
for input_node, qspec in input_qspec_map.items():
|
||||
_port_metadata_for_input_quant_nodes(input_node, node, qspec)
|
||||
_port_metadata_for_output_quant_nodes(node, output_qspec)
|
||||
return PassResult(graph_module, True)
|
||||
|
|
@ -28,6 +28,7 @@ from typing import Any, Tuple
|
|||
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
|
||||
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
|
||||
|
||||
__all__ = [
|
||||
"prepare_pt2e",
|
||||
|
|
@ -99,6 +100,8 @@ def convert_pt2e(
|
|||
pm = PassManager([DuplicateDQPass()])
|
||||
model = pm(model).graph_module
|
||||
|
||||
pm = PassManager([PortNodeMetaForQDQ()])
|
||||
model = pm(model).graph_module
|
||||
if use_reference_representation:
|
||||
model = reference_representation_rewrite(model)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user