[Quant][PT2E] Enable X86InductorQuantizer single quantizable op(maxpool2d) (#105639)

**Summary**
In this PR, we mainly enable two things:

- Enable the skeleton of the quantization recipe for single quantizable operators in `X86InductorQuantizer`.
- Add the quantization recipe for `maxpool2d` and annotate it so that its input and output share an observer (a usage sketch follows this list).
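
For context, here is a minimal end-to-end sketch of how the new recipe is exercised. It mirrors the capture and PT2E APIs used by the updated test helper (`torchdynamo.export` with `aten_graph=True`, `prepare_pt2e`, `convert_pt2e`); the `ConvMaxpool` module and the variable names are illustrative and not part of this PR.

```
import copy

import torch
import torch._dynamo as torchdynamo
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ConvMaxpool(torch.nn.Module):
    # Illustrative module: conv -> maxpool2d, similar to the helper added in the test.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 2, 1)
        self.pool = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        return self.pool(self.conv(x))


example_inputs = (torch.rand(1, 2, 14, 14),)
m = ConvMaxpool().eval()

# Capture an ATen-level graph, as the updated test helper does.
m, guards = torchdynamo.export(m, *copy.deepcopy(example_inputs), aten_graph=True)

# Configure the quantizer with the default X86 Inductor config.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)

m = prepare_pt2e(m, quantizer)  # inserts observers; maxpool2d input/output share one
m(*example_inputs)              # calibrate
m = convert_pt2e(m)             # lowers observers to quantize/dequantize ops
```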

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_maxpool2d_recipe
```
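
The essence of what the new test asserts, distilled into a helper sketch (the function name is illustrative; `prepare_model` stands for the graph module returned by `prepare_pt2e` for a conv -> maxpool2d model such as the one above):

```
import operator

import torch
from torch.ao.quantization import ObserverBase


def check_maxpool_shares_observer(prepare_model: torch.fx.GraphModule) -> None:
    # `prepare_model` is the module returned by prepare_pt2e, before convert_pt2e.
    input_obs = output_obs = None
    for node in prepare_model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target is torch.ops.aten.max_pool2d_with_indices.default
        ):
            # Observer module inserted on maxpool2d's input activation.
            input_obs = getattr(prepare_model, node.args[0].target)
        elif node.op == "call_function" and node.target is operator.getitem:
            # Observer module inserted after the getitem extracting maxpool2d's output.
            output_obs = getattr(prepare_model, list(node.users)[0].target)
    assert isinstance(input_obs, ObserverBase)
    assert isinstance(output_obs, ObserverBase)
    # The recipe annotates maxpool2d so its input and output share one observer,
    # hence both lookups resolve to the same module instance.
    assert input_obs is output_obs
```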

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105639
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456
Author: leslie-fang-intel, 2023-08-25 20:54:20 +08:00 (committed by PyTorch MergeBot)
Commit: 70ca18f8a0 (parent c5ad44be1d)
2 changed files with 310 additions and 16 deletions


@@ -20,7 +20,8 @@ from torch.testing._internal.common_quantized import override_quantized_engine
from enum import Enum
import itertools
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import operator
from torch.ao.quantization import ObserverBase
class Conv2DType(Enum):
left = 1
@@ -127,6 +128,18 @@ class TestHelperModules:
else:
return self.relu2(self.conv(x) + self.conv2(x))
class Conv2dMaxpoolPowModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)
self.pool = nn.MaxPool2d(1, 1)
def forward(self, x):
x = self.conv(x)
x = self.pool(x)
return torch.pow(x, 2)
class SerialsConv2dAddReLUModule(torch.nn.Module):
""" Serials of 2 Conv2d -> Add -> ReLU Pattern.
"""
@@ -171,10 +184,13 @@ class X86InductorQuantTestCase(QuantizationTestCase):
*copy.deepcopy(example_inputs),
aten_graph=True,
)
export_model = copy.deepcopy(m)
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
prepare_model = copy.deepcopy(m)
m = convert_pt2e(m)
convert_model = copy.deepcopy(m)
pt2_quant_output = m(*example_inputs)
node_occurrence = {
ns.call_function(k): v for k, v in expected_node_occurrence.items()
@@ -185,6 +201,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
self.checkGraphModuleNodes(
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
)
return export_model, prepare_model, convert_model
@skipIfNoDynamoSupport
class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@@ -400,3 +417,65 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence,
node_list,
)
@skipIfNoX86
def test_maxpool2d_recipe(self):
r"""
Test pattern: int8_in_int8_out_op (maxpool) -> non-quantizable op (pow).
Since maxpool is an int8_in_int8_out op, an observer is inserted between maxpool and pow.
"""
m = TestHelperModules.Conv2dMaxpoolPowModule().eval()
x = torch.rand(1, 2, 14, 14)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
example_inputs = (x,)
node_occurrence = {
# one per-tensor quantize/dequantize pair for the conv input and two for the maxpool2d input/output;
# the conv weight uses the per-channel pair
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.max_pool2d_with_indices.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
_, prepare_model, _ = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
# Check that maxpool2d's input and output share the same observer
for node in prepare_model.graph.nodes:
if (
node.op == "call_function"
and node.target is torch.ops.aten.max_pool2d_with_indices.default
):
maxpool_node = node
input_obs_of_maxpool = getattr(
prepare_model, maxpool_node.args[0].target
)
elif node.op == "call_function" and node.target is operator.getitem:
output_obs_of_maxpool = getattr(
prepare_model, list(node.users)[0].target
)
elif (
node.op == "call_function"
and node.target is torch.ops.aten.convolution.default
):
conv_node = node
input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv)


@@ -2,7 +2,8 @@ import copy
import functools
import itertools
import operator
from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
import torch.nn.functional as F
@@ -13,6 +14,12 @@ from torch.ao.quantization.observer import (
)
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
_is_annotated,
get_bias_qspec,
@@ -28,7 +35,6 @@ from torch.fx.passes.utils.source_matcher_utils import (
get_source_partitions,
SourcePartition,
)
from .quantizer import QuantizationAnnotation, QuantizationSpec, Quantizer
__all__ = [
"X86InductorQuantizer",
@@ -36,6 +42,70 @@ __all__ = [
]
@dataclass
class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
# _is_output_of_quantized_pattern:
# * The node is the output node of a fusion pattern.
# * The fusion pattern supports the int8 data type.
# * The fusion pattern has its inputs annotated to insert observers.
_is_output_of_quantized_pattern: bool = False
# Ops that support the int8 data type, excluding ops like conv and linear.
quantizable_ops_pt2e: Set = {
torch.ops.aten.max_pool2d_with_indices.default,
}
# Ops that:
# 1. Prefer to run with int8 when an int8 input is given.
# 2. Don't support int8 input with fp32 output.
int8_in_int8_out_ops_pt2e: Set = {
torch.ops.aten.max_pool2d_with_indices.default,
}
def _is_node_annotated(_node):
"""
return True if the node is annotated, otherwise return False
"""
return (
"quantization_annotation" in _node.meta
and _node.meta["quantization_annotation"]._annotated
)
def _is_any_annotated(nodes: List[Node]):
"""
Given a list of nodes (that represents an operator pattern),
return True if any of the nodes is annotated, otherwise return False.
"""
return any(_is_node_annotated(node) for node in nodes)
def _is_all_annotated(nodes: List[Node]):
"""
Given a list of nodes (that represents an operator pattern),
return True if all of the nodes are annotated, otherwise return False.
"""
return all(_is_node_annotated(node) for node in nodes)
def _is_quantized_op_pt2e(node: torch.fx.Node):
"""
Used in the PT2E flow to check whether the node is a quantized node:
Case 1: the node has been annotated as the output node of a fusion pattern.
Case 2: the node has been annotated as a single quantized node.
"""
if not _is_any_annotated([node]):
# The node has not been annotated, directly return False
return False
quantization_annotation = node.meta.get("quantization_annotation", None)
assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
return quantization_annotation._is_output_of_quantized_pattern
def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
# TODO: Add more supported operators here.
supported_operators: Dict[str, List[OperatorPatternType]] = {
@@ -188,15 +258,21 @@ class X86InductorQuantizer(Quantizer):
if isinstance(bias_node, Node):
input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
if annotate_output:
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
conv_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
# TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True,
_is_output_of_quantized_pattern=True,
)
else:
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
conv_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
def _get_output_nodes_of_partitions(
@@ -249,16 +325,45 @@ class X86InductorQuantizer(Quantizer):
def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
# annotate the nodes from last to first since the matching is in the reversed order
# and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv)
# and we will mark the matched node with "_annotated" so fusion operator pattern
# can take precedence over single operator pattern in this way
r"""
High-level description of the quantization recipe for the X86 Inductor backend:
Step 1: Apply the quantization recipe to conv/linear fusion patterns to actively enable the int8 data type.
Step 2: Propagate quantization annotations to patterns other than conv/linear. Go through the patterns in the
model from start to end. If a pattern supports computation with the int8 data type and its inputs are connected
to quantized patterns, annotate its inputs as a quantized pattern.
Step 3: Step 2 only annotates the inputs of quantized patterns. For some quantized patterns, such as maxpool2d,
which only support an int8 output when the input is int8, we also need to annotate the output of the pattern.
"""
config = self.global_config
# Step 1: Apply the recipe to fusion patterns like conv/linear.
self._annotate_conv2d_fusion_pattern(model, config)
# Step 2: Propagate annotations to patterns other than conv/linear.
# Go through all the nodes from start to end.
# The recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
# 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538
for node in model.graph.nodes:
self._annotation_propagation_quantizable_pattern(node, config)
# Step 3: For quantizable ops such as maxpool2d, we need to quantize the output if the input is quantized,
# so that we can fuse dq-operator-q into a single quantized op.
# Refer to https://github.com/intel/intel-extension-for-pytorch/blob/
# 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487
for node in model.graph.nodes:
self._annotate_output_for_int8_in_int8_out_pattern(node, config)
return model
def _annotate_conv2d_fusion_pattern(
self, model: torch.fx.GraphModule, config: QuantizationConfig
):
self._annotate_conv2d_binary_unary(model, config)
self._annotate_conv2d_binary(model, config)
self._annotate_conv2d_unary(model, config)
self._annotate_conv2d(model, config)
return model
def _annotate_conv2d_binary_unary(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
@@ -293,13 +398,19 @@ class X86InductorQuantizer(Quantizer):
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
quantization_config
)
binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map, _annotated=True
binary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
unary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
# TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def _annotate_conv2d_binary(
@@ -336,11 +447,14 @@ class X86InductorQuantizer(Quantizer):
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
quantization_config
)
binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
binary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map,
# TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def _annotate_conv2d_unary(
@@ -362,10 +476,13 @@ class X86InductorQuantizer(Quantizer):
if _is_annotated([unary_node, conv_node]):
continue
self._annotate_conv_node_helper(conv_node, False, quantization_config)
unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
unary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
# TODO<leslie> Remove the output annotation once oneDNN qconv supports fp32 output.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def _annotate_conv2d(
@@ -389,6 +506,104 @@ class X86InductorQuantizer(Quantizer):
continue
self._annotate_conv_node_helper(conv_node, True, quantization_config)
def _annotate_maxpool2d(
self, node: Node, quantization_config: QuantizationConfig
) -> None:
if node.target is not torch.ops.aten.max_pool2d_with_indices.default or not (
len(list(node.users)) == 1
and (list(node.users)[0].target == operator.getitem)
):
return
maxpool_node = node
getitem_node = list(node.users)[0]
if _is_any_annotated([getitem_node, maxpool_node]):
return
input_node = maxpool_node.args[0]
assert isinstance(input_node, Node)
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
maxpool_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
getitem_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def _annotation_propagation_quantizable_pattern(
self, node: Node, quantization_config: QuantizationConfig
) -> None:
# Propagate annotation to quantizable patterns.
if (
(node.target in quantizable_ops_pt2e)
and (not _is_any_annotated([node]))
and (node.op == "call_function")
):
def is_all_inputs_connected_to_quantized_op(input_nodes):
# Ensure all the inputs are connected to a fusion pattern or a quantized node
for input_node in input_nodes:
if not _is_quantized_op_pt2e(input_node):
return False
return True
if node.target is torch.ops.aten.max_pool2d_with_indices.default:
# Recipe for maxpool2d: check whether input arg[0] of maxpool2d is quantized
input_nodes_to_check = [node.all_input_nodes[0]]
if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
return
self._annotate_maxpool2d(node, quantization_config)
return
else:
# TODO <leslie>: Enable recipes for more single quantizable ops such as view and relu.
pass
return
def _annotate_output_for_int8_in_int8_out_pattern(
self, node: Node, quantization_config: QuantizationConfig
) -> None:
r"""
Check the output of a node in int8_in_int8_out_ops_pt2e and insert an observer at it if needed.
The recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
"""
if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])):
if node.target == torch.ops.aten.max_pool2d_with_indices.default:
maxpool_node = node
assert len(list(maxpool_node.users)) == 1 and (
list(maxpool_node.users)[0].target == operator.getitem
)
getitem_node = list(node.users)[0]
if not _is_all_annotated([getitem_node, maxpool_node]):
return
# Get the quantization_annotation from getitem_node
getitem_quantization_annotation = (
getitem_node.meta["quantization_annotation"]
if "quantization_annotation" in getitem_node.meta
else None
)
if (
getitem_quantization_annotation
and getitem_quantization_annotation._is_output_of_quantized_pattern
):
# Annotate the output_qspec of getitem_node
input_act = maxpool_node.args[0]
assert isinstance(input_act, Node)
assert isinstance(maxpool_node, Node)
edge_or_node: Tuple[Node, Node] = (input_act, maxpool_node)
getitem_node.meta[
"quantization_annotation"
].output_qspec = SharedQuantizationSpec(edge_or_node)
else:
# TODO <leslie>: Enable recipes for more int8_in_int8_out_ops
pass
return
def validate(self, model: torch.fx.GraphModule) -> None:
pass