pytorch/test/quantization/pt2e/test_metadata_porting.py
Jerry Zhang 1b51d29b66 [quant][pt2e] Enable constant folding for quantize ops (#109343)
Summary:
This PR adds constant folding for quantize ops, so that instead of storing the fp32 weight in the
quantized model we store the int8/int16/etc. weight directly.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_fold_quantize

Will also verify in ExecuTorch later.
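
For illustration, a minimal sketch of the folding behavior (the toy module, input
shape, and quantizer setup below are assumptions for demonstration, not part of
this PR):

    import torch
    import torch._export as export
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    m = torch.nn.Linear(3, 3).eval()
    example_inputs = (torch.randn(1, 3),)
    m = export.capture_pre_autograd_graph(m, example_inputs)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibrate with example inputs
    m = convert_pt2e(m, fold_quantize=True)
    # With fold_quantize=True, the quantize op on the weight is constant-folded:
    # the graph stores an already-quantized integer weight (a get_attr node)
    # instead of an fp32 weight followed by a quantize_per_tensor call.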

Differential Revision: [D49399210](https://our.internmc.facebook.com/intern/diff/D49399210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109343
Approved by: https://github.com/kimishpatel, https://github.com/jgong5
2023-09-27 06:04:45 +00:00

# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import List

import torch
import torch._export as export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
class Conv2dWithObsSharingOps(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.hardtanh = torch.nn.Hardtanh()
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
x = self.conv(x)
x = self.adaptive_avg_pool2d(x)
x = self.hardtanh(x)
x = x.view(-1, 3)
x = self.linear(x)
            return x


def _tag_partitions(
backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
):
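    """Tag every node in each annotated partition with a unique
    "quantization_tag" of the form f"{backend_name}_{op_name}_{index}",
    e.g. "BackendA_linear_0". The tests below check that these tags are
    ported to the quantize/dequantize (and folded get_attr) nodes created
    during convert_pt2e.
    """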
for index, partition_nodes in enumerate(annotated_partitions):
tag_name = backend_name + "_" + op_name + "_" + str(index)
for node in partition_nodes:
assert "quantization_tag" not in node.meta, f"{node} is already tagged"
            node.meta["quantization_tag"] = tag_name


# Decomposed quantize/dequantize/choose_qparams ops inserted during
# convert_pt2e; the tests check that the metadata porting logic copies
# "quantization_tag" onto these nodes.
_QUANT_OPS = {
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.choose_qparams.tensor,
}


# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
def _test_metadata_porting(
self,
model,
example_inputs,
quantizer,
node_tags=None,
):
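        """Capture the model with capture_pre_autograd_graph, prepare and
        calibrate it, convert with fold_quantize=True, then collect the
        "quantization_tag" metadata found on quantize/dequantize and get_attr
        nodes and compare it against the expected ``node_tags`` mapping
        (node key -> set of expected tags).
        """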
m_eager = model.eval()
# program capture
m = copy.deepcopy(m_eager)
m = export.capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)
pt2_quant_output = m(*example_inputs)
recorded_node_tags = {}
for n in m.graph.nodes:
if "quantization_tag" not in n.meta:
continue
if n.op == "call_function" and n.target in _QUANT_OPS:
key = n.target
            elif n.op == "get_attr":
                # with fold_quantize=True, quantized weights are stored as
                # get_attr nodes, which should also carry the ported tags
                key = "get_attr"
else:
continue
if key not in recorded_node_tags:
recorded_node_tags[key] = set()
if (
n.op == "call_function"
and n.meta["quantization_tag"] in recorded_node_tags[key]
):
raise ValueError(
f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
"is associated with another node of the same type"
)
recorded_node_tags[key].add(n.meta["quantization_tag"])
self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
for k, v in recorded_node_tags.items():
            self.assertEqual(v, node_tags[k])

    def test_simple_metadata_porting(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config
)
_tag_partitions(
backend_string, "adaptive_avg_pool2d", annotated_partitions
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {
"BackendA_conv2d_0",
"BackendA_linear_0",
}
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
"BackendA_linear_0",
}
dequantize_per_tensor_tags = {
"BackendA_adaptive_avg_pool2d_0",
"BackendA_conv2d_0",
"BackendA_linear_0",
}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Dont quantize avgpool
Check quantization tags on conv2d and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
        )

    @unittest.skip("Temporarily disabled")
def test_metadata_porting_for_dq(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Quantize all except linear.
Quantize linear with dynamic quantization
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
                # static quantization
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config
)
_tag_partitions(
backend_string, "adaptive_avg_pool2d", annotated_partitions
)
# dynamic quantization
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
# TODO: add get_attr_tags when the test is re-enabled
        get_attr_tags = set()
quantize_per_tensor_tags = {
"BackendA_conv2d_0",
"BackendA_adaptive_avg_pool2d_0",
}
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_tensor_tags = {
"BackendA_adaptive_avg_pool2d_0",
"BackendA_conv2d_0",
}
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_channel_tags = {
"BackendA_conv2d_0",
"BackendA_linear_dynamic_0",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Quantize linear and conv with dynamic quantization
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
# dynamic quantization
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
choose_qparams_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
quantize_per_tensor_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
dequantize_per_tensor_tensor_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
dequantize_per_channel_tags = {
"BackendA_conv2d_dynamic_0",
"BackendA_linear_dynamic_0",
}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Dont quantize anything except linear.
Quantize linear with dynamic quantization
Check quantization tags on conv2d, avgpool and linear are correctly set
"""
class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
# dynamic quantization
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["linear"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
example_inputs = (torch.randn(1, 3, 5, 5),)
get_attr_tags = {"BackendA_linear_dynamic_0"}
choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
node_tags = {
"get_attr": get_attr_tags,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
}
self._test_metadata_porting(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
node_tags,
)