[Quant][PT2E] Enable X86InductorQuantizer single quantizable op(maxpool2d) (#105639)
**Summary**
This PR mainly enables two things:
- Enable the skeleton of the quantization recipe for single quantizable operators in `X86InductorQuantizer`.
- Add the quantization recipe for `maxpool2d`, annotating its input and output to share an observer.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_maxpool2d_recipe
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105639
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456
This commit is contained in:
parent
c5ad44be1d
commit
70ca18f8a0
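
Before the diff, a minimal usage sketch of the recipe this PR enables. The `ConvMaxpool` toy model is made up for illustration (it mirrors the new test helper below), and the capture call mirrors the `aten_graph=True` export used by the test harness at the time of this PR; the export API shape may differ across PyTorch versions.

```python
import copy

import torch
import torch.nn as nn
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ConvMaxpool(nn.Module):
    # Hypothetical toy model: conv -> maxpool -> pow, same shape as the new test helper.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1)
        self.pool = nn.MaxPool2d(1, 1)

    def forward(self, x):
        return torch.pow(self.pool(self.conv(x)), 2)


example_inputs = (torch.rand(1, 2, 14, 14),)
m = ConvMaxpool().eval()

# Capture an ATen-level FX graph (export API varies across PyTorch versions).
m, _ = torch._dynamo.export(m, *copy.deepcopy(example_inputs), aten_graph=True)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
m = prepare_pt2e(m, quantizer)  # maxpool2d input and output get one shared observer
m(*example_inputs)              # calibrate
m = convert_pt2e(m)             # observers become quantize/dequantize ops around maxpool
```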
test_x86inductor_quantizer.py
@@ -20,7 +20,8 @@ from torch.testing._internal.common_quantized import override_quantized_engine
from enum import Enum
import itertools
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

import operator
from torch.ao.quantization import ObserverBase

class Conv2DType(Enum):
    left = 1
@@ -127,6 +128,18 @@ class TestHelperModules:
            else:
                return self.relu2(self.conv(x) + self.conv2(x))

    class Conv2dMaxpoolPowModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)
            self.pool = nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return torch.pow(x, 2)


    class SerialsConv2dAddReLUModule(torch.nn.Module):
        """ Serials of 2 Conv2d -> Add -> ReLU Pattern.
        """
@@ -171,10 +184,13 @@ class X86InductorQuantTestCase(QuantizationTestCase):
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )
        export_model = copy.deepcopy(m)
        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        prepare_model = copy.deepcopy(m)
        m = convert_pt2e(m)
        convert_model = copy.deepcopy(m)
        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
@@ -185,6 +201,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )
        return export_model, prepare_model, convert_model

@skipIfNoDynamoSupport
class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@@ -400,3 +417,65 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
            node_occurrence,
            node_list,
        )

    @skipIfNoX86
    def test_maxpool2d_recipe(self):
        r"""
        Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow)
        Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow.
        """
        m = TestHelperModules.Conv2dMaxpoolPowModule().eval()
        x = torch.rand(1, 2, 14, 14)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        example_inputs = (x,)
        node_occurrence = {
            # one for input and weight of the conv, two for input/output for the maxpool2d
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        node_list = [
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.convolution.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.max_pool2d_with_indices.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]
        _, prepare_model, _ = self._test_quantizer(
            m,
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )
        # Check Maxpool2d has share observer at input and output
        for node in prepare_model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.aten.max_pool2d_with_indices.default
            ):
                maxpool_node = node
                input_obs_of_maxpool = getattr(
                    prepare_model, maxpool_node.args[0].target
                )
            elif node.op == "call_function" and node.target is operator.getitem:
                output_obs_of_maxpool = getattr(
                    prepare_model, list(node.users)[0].target
                )
            elif (
                node.op == "call_function"
                and node.target is torch.ops.aten.convolution.default
            ):
                conv_node = node
                input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
        self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
        self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
        self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
        self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
        self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv)
torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -2,7 +2,8 @@ import copy
import functools
import itertools
import operator
from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F
@@ -13,6 +14,12 @@ from torch.ao.quantization.observer import (
)
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _is_annotated,
    get_bias_qspec,
@@ -28,7 +35,6 @@ from torch.fx.passes.utils.source_matcher_utils import (
    get_source_partitions,
    SourcePartition,
)
from .quantizer import QuantizationAnnotation, QuantizationSpec, Quantizer

__all__ = [
    "X86InductorQuantizer",
@@ -36,6 +42,70 @@ __all__ = [
]


@dataclass
class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
    # _is_output_of_quantized_pattern:
    #  * Node as output node of a fusion pattern.
    #  * The fusion pattern supports int8 data type.
    #  * The fusion pattern has inputs annotated to insert observer.
    _is_output_of_quantized_pattern: bool = False


# Ops support int8 data type and excludes ops like conv, linear.
quantizable_ops_pt2e: Set = {
    torch.ops.aten.max_pool2d_with_indices.default,
}


# Ops that:
# 1. Ops prefer to run with int8 when int8 input is given.
# 2. Ops don't support int8 in and fp32 out.
int8_in_int8_out_ops_pt2e: Set = {
    torch.ops.aten.max_pool2d_with_indices.default,
}


def _is_node_annotated(_node):
    """
    return True if the node is annotated, otherwise return False
    """
    return (
        "quantization_annotation" in _node.meta
        and _node.meta["quantization_annotation"]._annotated
    )


def _is_any_annotated(nodes: List[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
    check if any of the node is annotated, return True if any of the node
    is annotated, otherwise return False.
    """
    return any(_is_node_annotated(node) for node in nodes)


def _is_all_annotated(nodes: List[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
    return True if all of the node is annotated, otherwise return False.
    """
    return all(_is_node_annotated(node) for node in nodes)


def _is_quantized_op_pt2e(node: torch.fx.Node):
    """
    Used for pt2e flow to check if the node is a quantized node:
    Case1: the node has been annotated as output node of a fusion pattern.
    Case2: the node has been annotated as single quantized node.
    """
    if not _is_any_annotated([node]):
        # The node has not been annotated, directly return False
        return False
    quantization_annotation = node.meta.get("quantization_annotation", None)
    assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
    return quantization_annotation._is_output_of_quantized_pattern


def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    # TODO: Add more supported operators here.
    supported_operators: Dict[str, List[OperatorPatternType]] = {
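
As a rough illustration of how these module-private helpers read `node.meta`, here is a sketch assuming it runs inside (or privately imports from) this module; the `M` module below is made up for the example.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


gm = torch.fx.symbolic_trace(M())
relu = next(n for n in gm.graph.nodes if n.op == "call_function")

print(_is_any_annotated([relu]))    # False: nothing attached to node.meta yet

# A quantizer pass marks a node by attaching an annotation with _annotated=True.
relu.meta["quantization_annotation"] = _X86InductorQuantizationAnnotation(_annotated=True)
print(_is_any_annotated([relu]))    # True
print(_is_quantized_op_pt2e(relu))  # False: not flagged as the output of a quantized pattern
```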
@@ -188,15 +258,21 @@ class X86InductorQuantizer(Quantizer):
        if isinstance(bias_node, Node):
            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
        if annotate_output:
            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            conv_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),
                _annotated=True,
                _is_output_of_quantized_pattern=True,
            )
        else:
            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map, _annotated=True
            conv_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                _annotated=True,
            )

    def _get_output_nodes_of_partitions(
@@ -249,16 +325,45 @@ class X86InductorQuantizer(Quantizer):
    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        # annotate the nodes from last to first since the matching is in the reversed order
        # and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv)
        # and we will mark the matched node with "_annoated" so fusion operator pattern
        # can take precedence over single operator pattern in this way
        r"""
        High-level description of quantization recipe for X86 Inductor Backend:
        Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively.
        Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model
        from start to the end. If a pattern supports computation with int8 data type and inputs connected to
        quantized patterns, annotate its inputs as quantized pattern.
        Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns,
        such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type,
        we need to annotate the output of this pattern.
        """

        config = self.global_config

        # Step1: Recipe of fusion patterns like conv/linear.
        self._annotate_conv2d_fusion_pattern(model, config)

        # Step2: Recipe to propagate annotation for patterns beside conv/linear.
        # Go through all the nodes from start to end.
        # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/
        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538
        for node in model.graph.nodes:
            self._annotation_propagation_quantizable_pattern(node, config)

        # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized
        # in inputs. So, we can fuse dq-operator-q into a quantized op.
        # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/
        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487
        for node in model.graph.nodes:
            self._annotate_output_for_int8_in_int8_out_pattern(node, config)

        return model

    def _annotate_conv2d_fusion_pattern(
        self, model: torch.fx.GraphModule, config: QuantizationConfig
    ):
        self._annotate_conv2d_binary_unary(model, config)
        self._annotate_conv2d_binary(model, config)
        self._annotate_conv2d_unary(model, config)
        self._annotate_conv2d(model, config)
        return model

    def _annotate_conv2d_binary_unary(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
@@ -293,13 +398,19 @@ class X86InductorQuantizer(Quantizer):
            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
                quantization_config
            )
            binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map=binary_node_input_qspec_map, _annotated=True
            binary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
                input_qspec_map=binary_node_input_qspec_map,
                _annotated=True,
            )
            unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
            unary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                _annotated=True,
                _is_output_of_quantized_pattern=True,
            )

    def _annotate_conv2d_binary(
@@ -336,11 +447,14 @@ class X86InductorQuantizer(Quantizer):
            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
                quantization_config
            )
            binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
            binary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
                input_qspec_map=binary_node_input_qspec_map,
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                _annotated=True,
                _is_output_of_quantized_pattern=True,
            )

    def _annotate_conv2d_unary(
@@ -362,10 +476,13 @@ class X86InductorQuantizer(Quantizer):
            if _is_annotated([unary_node, conv_node]):
                continue
            self._annotate_conv_node_helper(conv_node, False, quantization_config)
            unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
            unary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                _annotated=True,
                _is_output_of_quantized_pattern=True,
            )

    def _annotate_conv2d(
@@ -389,6 +506,104 @@ class X86InductorQuantizer(Quantizer):
                continue
            self._annotate_conv_node_helper(conv_node, True, quantization_config)

    def _annotate_maxpool2d(
        self, node: Node, quantization_config: QuantizationConfig
    ) -> None:
        if node.target is not torch.ops.aten.max_pool2d_with_indices.default or not (
            len(list(node.users)) == 1
            and (list(node.users)[0].target == operator.getitem)
        ):
            return
        maxpool_node = node
        getitem_node = list(node.users)[0]
        if _is_any_annotated([getitem_node, maxpool_node]):
            return
        input_node = maxpool_node.args[0]
        assert isinstance(input_node, Node)
        input_qspec_map = {}
        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
        maxpool_node.meta[
            "quantization_annotation"
        ] = _X86InductorQuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        getitem_node.meta[
            "quantization_annotation"
        ] = _X86InductorQuantizationAnnotation(
            _annotated=True,
            _is_output_of_quantized_pattern=True,
        )

    def _annotation_propagation_quantizable_pattern(
        self, node: Node, quantization_config: QuantizationConfig
    ) -> None:
        # Propagate annotation to quantizable patterns.
        if (
            (node.target in quantizable_ops_pt2e)
            and (not _is_any_annotated([node]))
            and (node.op == "call_function")
        ):

            def is_all_inputs_connected_to_quantized_op(input_nodes):
                # Ensure all the inputs connect to fusion pattern or quantized node
                for input_node in input_nodes:
                    if not _is_quantized_op_pt2e(input_node):
                        return False
                return True

            if node.target is torch.ops.aten.max_pool2d_with_indices.default:
                # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not
                input_nodes_to_check = [node.all_input_nodes[0]]
                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
                    return
                self._annotate_maxpool2d(node, quantization_config)
                return
            else:
                # TODO <leslie>: Enable recipes for more single quantizable op such as view and relu.
                pass
        return

    def _annotate_output_for_int8_in_int8_out_pattern(
        self, node: Node, quantization_config: QuantizationConfig
    ) -> None:
        r"""
        Check and insert observer at output of node in int8_in_int8_out_ops_pt2e if needed.
        Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
        90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
        """
        if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])):
            if node.target == torch.ops.aten.max_pool2d_with_indices.default:
                maxpool_node = node
                assert len(list(maxpool_node.users)) == 1 and (
                    list(maxpool_node.users)[0].target == operator.getitem
                )
                getitem_node = list(node.users)[0]
                if not _is_all_annotated([getitem_node, maxpool_node]):
                    return
                # Get the quantization_annotation from getitem_node
                getitem_quantization_annotation = (
                    getitem_node.meta["quantization_annotation"]
                    if "quantization_annotation" in getitem_node.meta
                    else None
                )
                if (
                    getitem_quantization_annotation
                    and getitem_quantization_annotation._is_output_of_quantized_pattern
                ):
                    # Annotate the output_qspec of getitem_node
                    input_act = maxpool_node.args[0]
                    assert isinstance(input_act, Node)
                    assert isinstance(maxpool_node, Node)
                    edge_or_node: Tuple[Node, Node] = (input_act, maxpool_node)
                    getitem_node.meta[
                        "quantization_annotation"
                    ].output_qspec = SharedQuantizationSpec(edge_or_node)
            else:
                # TODO <leslie>: Enable recipes for more int8_in_int8_out_ops
                pass
        return

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
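
The shared observer asserted by the new test comes from the `SharedQuantizationSpec` above, which is keyed on the (input, maxpool) edge. Below is a standalone sketch of constructing such a spec on a tiny traced graph; the `Pool` module is made up for illustration, whereas the quantizer builds the spec from the nodes of the captured ATen graph.

```python
import torch
import torch.fx
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec


class Pool(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(x, 2)


gm = torch.fx.symbolic_trace(Pool())
placeholder, pool_node = list(gm.graph.nodes)[:2]  # input node, then the max_pool2d call

# Keying the spec on the (input, pool) edge tells prepare_pt2e to reuse the observer
# created for that edge as the pool output's observer, so input and output share one.
shared = SharedQuantizationSpec((placeholder, pool_node))
print(shared.edge_or_node)
```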