[Quant][PT2E] Enable X86InductorQuantizer single quantizable op(maxpool2d) (#105639)
**Summary**
In this PR, we mainly enable two things:
- Enable the skeleton of the quantization recipe for single quantizable operators in `X86InductorQuantizer`.
- Add the quantization recipe for `maxpool2d` and annotate it so that its input and output share the same observer.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_maxpool2d_recipe
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105639
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #104580, #104581, #104588, #104590, #105455, #105456
This commit is contained in:
parent
c5ad44be1d
commit
70ca18f8a0
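For context, a minimal sketch of how the new `maxpool2d` recipe is exercised end to end, mirroring the flow of the new test below. The graph-capture call and the `prepare_pt2e`/`convert_pt2e` import path follow the pt2e APIs of this period and are assumptions for illustration, not part of this commit.

```python
# Hedged sketch (not this PR's code): quantize a conv -> maxpool2d model with
# X86InductorQuantizer so that the maxpool2d input and output share one observer.
import torch
import torch.nn as nn
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e  # import path may differ at this commit


class ConvMaxpool(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1)
        self.pool = nn.MaxPool2d(1, 1)

    def forward(self, x):
        return torch.pow(self.pool(self.conv(x)), 2)


m = ConvMaxpool().eval()
example_inputs = (torch.rand(1, 2, 14, 14),)

# Capture an ATen-level graph; the test helper below uses an aten_graph=True export for this step.
m, _ = torch._dynamo.export(m, *example_inputs, aten_graph=True)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
m = prepare_pt2e(m, quantizer)  # inserts observers; maxpool2d input/output get the same one
m(*example_inputs)              # calibrate
m = convert_pt2e(m)             # replaces observers with quantize/dequantize nodes
```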
test_x86inductor_quantizer.py

@@ -20,7 +20,8 @@ from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+import operator
+from torch.ao.quantization import ObserverBase

 class Conv2DType(Enum):
     left = 1
@@ -127,6 +128,18 @@ class TestHelperModules:
             else:
                 return self.relu2(self.conv(x) + self.conv2(x))

+    class Conv2dMaxpoolPowModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(2, 2, 1)
+            self.pool = nn.MaxPool2d(1, 1)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.pool(x)
+            return torch.pow(x, 2)
+
+
     class SerialsConv2dAddReLUModule(torch.nn.Module):
         """ Serials of 2 Conv2d -> Add -> ReLU Pattern.
         """
@@ -171,10 +184,13 @@ class X86InductorQuantTestCase(QuantizationTestCase):
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )
+        export_model = copy.deepcopy(m)
        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
+        prepare_model = copy.deepcopy(m)
        m = convert_pt2e(m)
+        convert_model = copy.deepcopy(m)
        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
@@ -185,6 +201,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )
+        return export_model, prepare_model, convert_model

 @skipIfNoDynamoSupport
 class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@@ -400,3 +417,65 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
            node_occurrence,
            node_list,
        )
+
+    @skipIfNoX86
+    def test_maxpool2d_recipe(self):
+        r"""
+        Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow)
+        Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow.
+        """
+        m = TestHelperModules.Conv2dMaxpoolPowModule().eval()
+        x = torch.rand(1, 2, 14, 14)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        example_inputs = (x,)
+        node_occurrence = {
+            # one for input and weight of the conv, two for input/output for the maxpool2d
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.convolution.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.max_pool2d_with_indices.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        _, prepare_model, _ = self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+        # Check Maxpool2d has share observer at input and output
+        for node in prepare_model.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target is torch.ops.aten.max_pool2d_with_indices.default
+            ):
+                maxpool_node = node
+                input_obs_of_maxpool = getattr(
+                    prepare_model, maxpool_node.args[0].target
+                )
+            elif node.op == "call_function" and node.target is operator.getitem:
+                output_obs_of_maxpool = getattr(
+                    prepare_model, list(node.users)[0].target
+                )
+            elif (
+                node.op == "call_function"
+                and node.target is torch.ops.aten.convolution.default
+            ):
+                conv_node = node
+                input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
+        self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
+        self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
+        self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
+        self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
+        self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv)
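The assertions at the end of `test_maxpool2d_recipe` check that the observer feeding `max_pool2d_with_indices` and the one attached to its `getitem` output are the same object. A minimal, self-contained sketch of what that sharing means in practice; `HistogramObserver` is only a stand-in for whatever activation observer the quantization config actually selects:

```python
import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import HistogramObserver

shared_obs = HistogramObserver()  # one observer instance for both tensors
x = torch.rand(1, 2, 14, 14)
x = shared_obs(x)                 # record statistics of the maxpool input
y = F.max_pool2d(x, kernel_size=1)
y = shared_obs(y)                 # record the maxpool output with the SAME observer
scale, zero_point = shared_obs.calculate_qparams()
print(scale, zero_point)          # a single set of qparams serves both input and output
```

Because maxpool only selects values that already appear in its input, sharing one observer keeps the input and output quantization parameters identical, which is what lets the backend keep the op entirely in int8.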
torch/ao/quantization/quantizer/x86_inductor_quantizer.py

@@ -2,7 +2,8 @@ import copy
 import functools
 import itertools
 import operator
-from typing import Any, Dict, List, Optional, Set
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple

 import torch
 import torch.nn.functional as F
@@ -13,6 +14,12 @@ from torch.ao.quantization.observer import (
 )
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    Quantizer,
+    SharedQuantizationSpec,
+)
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
     _is_annotated,
     get_bias_qspec,
@@ -28,7 +35,6 @@ from torch.fx.passes.utils.source_matcher_utils import (
     get_source_partitions,
     SourcePartition,
 )
-from .quantizer import QuantizationAnnotation, QuantizationSpec, Quantizer

 __all__ = [
     "X86InductorQuantizer",
@@ -36,6 +42,70 @@ __all__ = [
 ]


+@dataclass
+class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
+    # _is_output_of_quantized_pattern:
+    # * Node as output node of a fusion pattern.
+    # * The fusion pattern supports int8 data type.
+    # * The fusion pattern has inputs annotated to insert observer.
+    _is_output_of_quantized_pattern: bool = False
+
+
+# Ops support int8 data type and excludes ops like conv, linear.
+quantizable_ops_pt2e: Set = {
+    torch.ops.aten.max_pool2d_with_indices.default,
+}
+
+
+# Ops that:
+# 1. Ops prefer to run with int8 when int8 input is given.
+# 2. Ops don't support int8 in and fp32 out.
+int8_in_int8_out_ops_pt2e: Set = {
+    torch.ops.aten.max_pool2d_with_indices.default,
+}
+
+
+def _is_node_annotated(_node):
+    """
+    return True if the node is annotated, otherwise return False
+    """
+    return (
+        "quantization_annotation" in _node.meta
+        and _node.meta["quantization_annotation"]._annotated
+    )
+
+
+def _is_any_annotated(nodes: List[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    check if any of the node is annotated, return True if any of the node
+    is annotated, otherwise return False.
+    """
+    return any(_is_node_annotated(node) for node in nodes)
+
+
+def _is_all_annotated(nodes: List[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    return True if all of the node is annotated, otherwise return False.
+    """
+    return all(_is_node_annotated(node) for node in nodes)
+
+
+def _is_quantized_op_pt2e(node: torch.fx.Node):
+    """
+    Used for pt2e flow to check if the node is a quantized node:
+    Case1: the node has been annotated as output node of a fusion pattern.
+    Case2: the node has been annotated as single quantized node.
+    """
+    if not _is_any_annotated([node]):
+        # The node has not been annotated, directly return False
+        return False
+    quantization_annotation = node.meta.get("quantization_annotation", None)
+    assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
+    return quantization_annotation._is_output_of_quantized_pattern
+
+
 def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     # TODO: Add more supported operators here.
     supported_operators: Dict[str, List[OperatorPatternType]] = {
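The helper functions above all key off the `quantization_annotation` entry in `node.meta`. A small self-contained sketch of that mechanism on a plain `torch.fx` graph; the `FakeAnnotation` dataclass is a hypothetical stand-in for `_X86InductorQuantizationAnnotation`:

```python
from dataclasses import dataclass

import torch
import torch.fx as fx


@dataclass
class FakeAnnotation:  # stand-in for _X86InductorQuantizationAnnotation
    _annotated: bool = True
    _is_output_of_quantized_pattern: bool = False


class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(x, kernel_size=2)


gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["quantization_annotation"] = FakeAnnotation()

# Mirrors what _is_node_annotated / _is_any_annotated look for.
annotated = [
    n.name
    for n in gm.graph.nodes
    if "quantization_annotation" in n.meta
    and n.meta["quantization_annotation"]._annotated
]
print(annotated)  # expected to contain the max_pool2d call node
```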
torch/ao/quantization/quantizer/x86_inductor_quantizer.py (continued)

@@ -188,15 +258,21 @@ class X86InductorQuantizer(Quantizer):
         if isinstance(bias_node, Node):
             input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
         if annotate_output:
-            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            conv_node.meta[
+                "quantization_annotation"
+            ] = _X86InductorQuantizationAnnotation(
                 input_qspec_map=input_qspec_map,
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),
                 _annotated=True,
+                _is_output_of_quantized_pattern=True,
             )
         else:
-            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map, _annotated=True
+            conv_node.meta[
+                "quantization_annotation"
+            ] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
             )

     def _get_output_nodes_of_partitions(
@@ -249,16 +325,45 @@ class X86InductorQuantizer(Quantizer):
     def _annotate_for_static_quantization_config(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        # annotate the nodes from last to first since the matching is in the reversed order
-        # and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv)
-        # and we will mark the matched node with "_annoated" so fusion operator pattern
-        # can take precedence over single operator pattern in this way
+        r"""
+        High-level description of quantization recipe for X86 Inductor Backend:
+        Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively.
+        Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model
+        from start to the end. If a pattern supports computation with int8 data type and inputs connected to
+        quantized patterns, annotate its inputs as quantized pattern.
+        Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns,
+        such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type,
+        we need to annotate the output of this pattern.
+        """
+
         config = self.global_config
+
+        # Step1: Recipe of fusion patterns like conv/linear.
+        self._annotate_conv2d_fusion_pattern(model, config)
+
+        # Step2: Recipe to propagate annotation for patterns beside conv/linear.
+        # Go through all the nodes from start to end.
+        # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/
+        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538
+        for node in model.graph.nodes:
+            self._annotation_propagation_quantizable_pattern(node, config)
+
+        # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized
+        # in inputs. So, we can fuse dq-operator-q into a quantized op.
+        # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/
+        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487
+        for node in model.graph.nodes:
+            self._annotate_output_for_int8_in_int8_out_pattern(node, config)
+
+        return model
+
+    def _annotate_conv2d_fusion_pattern(
+        self, model: torch.fx.GraphModule, config: QuantizationConfig
+    ):
         self._annotate_conv2d_binary_unary(model, config)
         self._annotate_conv2d_binary(model, config)
         self._annotate_conv2d_unary(model, config)
         self._annotate_conv2d(model, config)
-        return model

     def _annotate_conv2d_binary_unary(
         self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
@@ -293,13 +398,19 @@ class X86InductorQuantizer(Quantizer):
             binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
                 quantization_config
             )
-            binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=binary_node_input_qspec_map, _annotated=True
+            binary_node.meta[
+                "quantization_annotation"
+            ] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                _annotated=True,
             )
-            unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            unary_node.meta[
+                "quantization_annotation"
+            ] = _X86InductorQuantizationAnnotation(
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                 _annotated=True,
+                _is_output_of_quantized_pattern=True,
             )

     def _annotate_conv2d_binary(
@@ -336,11 +447,14 @@ class X86InductorQuantizer(Quantizer):
             binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
                 quantization_config
             )
-            binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            binary_node.meta[
+                "quantization_annotation"
+            ] = _X86InductorQuantizationAnnotation(
                 input_qspec_map=binary_node_input_qspec_map,
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                 _annotated=True,
+                _is_output_of_quantized_pattern=True,
             )

     def _annotate_conv2d_unary(
@@ -362,10 +476,13 @@ class X86InductorQuantizer(Quantizer):
             if _is_annotated([unary_node, conv_node]):
                 continue
             self._annotate_conv_node_helper(conv_node, False, quantization_config)
-            unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            unary_node.meta[
+                "quantization_annotation"
+            ] = _X86InductorQuantizationAnnotation(
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                 _annotated=True,
+                _is_output_of_quantized_pattern=True,
             )

     def _annotate_conv2d(
@@ -389,6 +506,104 @@ class X86InductorQuantizer(Quantizer):
                 continue
             self._annotate_conv_node_helper(conv_node, True, quantization_config)

+    def _annotate_maxpool2d(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        if node.target is not torch.ops.aten.max_pool2d_with_indices.default or not (
+            len(list(node.users)) == 1
+            and (list(node.users)[0].target == operator.getitem)
+        ):
+            return
+        maxpool_node = node
+        getitem_node = list(node.users)[0]
+        if _is_any_annotated([getitem_node, maxpool_node]):
+            return
+        input_node = maxpool_node.args[0]
+        assert isinstance(input_node, Node)
+        input_qspec_map = {}
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        maxpool_node.meta[
+            "quantization_annotation"
+        ] = _X86InductorQuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        getitem_node.meta[
+            "quantization_annotation"
+        ] = _X86InductorQuantizationAnnotation(
+            _annotated=True,
+            _is_output_of_quantized_pattern=True,
+        )
+
+    def _annotation_propagation_quantizable_pattern(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        # Propagate annotation to quantizable patterns.
+        if (
+            (node.target in quantizable_ops_pt2e)
+            and (not _is_any_annotated([node]))
+            and (node.op == "call_function")
+        ):
+
+            def is_all_inputs_connected_to_quantized_op(input_nodes):
+                # Ensure all the inputs connect to fusion pattern or quantized node
+                for input_node in input_nodes:
+                    if not _is_quantized_op_pt2e(input_node):
+                        return False
+                return True
+
+            if node.target is torch.ops.aten.max_pool2d_with_indices.default:
+                # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not
+                input_nodes_to_check = [node.all_input_nodes[0]]
+                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
+                    return
+                self._annotate_maxpool2d(node, quantization_config)
+                return
+            else:
+                # TODO <leslie>: Enable recipes for more single quantizable op such as view and relu.
+                pass
+        return
+
+    def _annotate_output_for_int8_in_int8_out_pattern(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        r"""
+        Check and insert observer at output of node in int8_in_int8_out_ops_pt2e if needed.
+        Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
+        90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
+        """
+        if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])):
+            if node.target == torch.ops.aten.max_pool2d_with_indices.default:
+                maxpool_node = node
+                assert len(list(maxpool_node.users)) == 1 and (
+                    list(maxpool_node.users)[0].target == operator.getitem
+                )
+                getitem_node = list(node.users)[0]
+                if not _is_all_annotated([getitem_node, maxpool_node]):
+                    return
+                # Get the quantization_annotation from getitem_node
+                getitem_quantization_annotation = (
+                    getitem_node.meta["quantization_annotation"]
+                    if "quantization_annotation" in getitem_node.meta
+                    else None
+                )
+                if (
+                    getitem_quantization_annotation
+                    and getitem_quantization_annotation._is_output_of_quantized_pattern
+                ):
+                    # Annotate the output_qspec of getitem_node
+                    input_act = maxpool_node.args[0]
+                    assert isinstance(input_act, Node)
+                    assert isinstance(maxpool_node, Node)
+                    edge_or_node: Tuple[Node, Node] = (input_act, maxpool_node)
+                    getitem_node.meta[
+                        "quantization_annotation"
+                    ].output_qspec = SharedQuantizationSpec(edge_or_node)
+            else:
+                # TODO <leslie>: Enable recipes for more int8_in_int8_out_ops
+                pass
+        return
+
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass