[Quant][PT2E] Add cat and avg_pool2d recipe into x86InductorQuantizer (#106836)

**Summary**
Add `cat` and `avg_pool2d` quantization recipes to `X86InductorQuantizer`. Both ops are annotated as int8-in/int8-out ops whose inputs and outputs share the same observer.
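
For reference, a minimal end-to-end sketch of how the new recipes are exercised through the PT2E flow. The import paths, the `capture_pre_autograd_graph` entry point, and the toy model are illustrative assumptions and may differ across PyTorch versions:
```
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ConvCatAvgPool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        # cat and avg_pool2d follow quantized convs, so the new recipes apply
        return self.avgpool(torch.cat((self.conv(x), self.conv2(x)), dim=1))


m = ConvCatAvgPool().eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_pt2e(exported, quantizer)  # cat/avg_pool2d inputs and outputs now share observers
prepared(*example_inputs)                     # calibration
converted = convert_pt2e(prepared)
```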

**Test Plan**
```
clear && python -m pytest test_x86inductor_quantizer.py -k test_cat_recipe
clear && python -m pytest test_x86inductor_quantizer.py -k test_cat_recipe_same_inputs
clear && python -m pytest test_x86inductor_quantizer.py -k test_cat_recipe_single_input
clear && python -m pytest test_x86inductor_quantizer.py -k test_avg_pool2d_recipe
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106836
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
leslie-fang-intel 2023-08-26 16:50:08 +08:00 committed by PyTorch MergeBot
parent 15d4dedbbf
commit 1147a28b0b
2 changed files with 382 additions and 39 deletions

test_x86inductor_quantizer.py

@@ -139,7 +139,6 @@ class TestHelperModules:
x = self.pool(x)
return torch.pow(x, 2)
class SerialsConv2dAddReLUModule(torch.nn.Module):
""" Serials of 2 Conv2d -> Add -> ReLU Pattern.
"""
@@ -166,6 +165,55 @@ class TestHelperModules:
res2 = self.relu2(self.conv4(res1) + res1)
return res2
class Conv2dCatMaxpool2d(torch.nn.Module):
def __init__(self,):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
self.conv2 = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(3, stride=2, padding=1)
self.conv3 = torch.nn.Conv2d(32, 32, 7, bias=True, stride=2, padding=3, dilation=1)
def forward(self, x):
temp1 = self.relu(self.conv(x))
temp2 = self.conv2(x + 1)
temp3 = torch.cat((temp1, temp2), 1)
temp4 = self.maxpool(temp3)
temp5 = self.conv3(temp4)
return temp5
class Conv2dAvgPool2d(torch.nn.Module):
def __init__(self,):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1)
def forward(self, x):
temp1 = self.avgpool(self.conv(x))
return temp1
class Conv2dCatSameInputs(torch.nn.Module):
def __init__(self,):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
self.relu = torch.nn.ReLU()
def forward(self, x):
temp1 = self.relu(self.conv(x))
temp3 = torch.cat((temp1, temp1), 1)
return temp3
class Conv2dCatSingleInput(torch.nn.Module):
def __init__(self,):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
self.relu = torch.nn.ReLU()
def forward(self, x):
temp1 = self.relu(self.conv(x))
temp3 = torch.cat((temp1,), 1)
return temp3
class X86InductorQuantTestCase(QuantizationTestCase):
def _test_quantizer(
self,
@@ -479,3 +527,244 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv)
@skipIfNoX86
def test_cat_recipe(self):
r"""
Test pattern: conv -> cat -> maxpool2d
        Since cat and maxpool are int8_in_int8_out ops, their inputs and outputs should share the same observer.
"""
m = TestHelperModules.Conv2dCatMaxpool2d().eval()
x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
example_inputs = (x,)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.cat.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.max_pool2d_with_indices.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
_, prepare_model, _ = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
        # Check that cat and maxpool2d share observers at input and output
for node in prepare_model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.cat.default
):
cat_act_obs0 = getattr(
prepare_model, node.all_input_nodes[0].target
)
cat_act_obs1 = getattr(
prepare_model, node.all_input_nodes[1].target
)
cat_out_obs = getattr(
prepare_model, list(node.users)[0].target
)
elif (
node.op == "call_function"
and node.target is torch.ops.aten.max_pool2d_with_indices.default
):
maxpool_node = node
input_obs_of_maxpool = getattr(
prepare_model, maxpool_node.args[0].target
)
elif node.op == "call_function" and node.target is operator.getitem:
output_obs_of_maxpool = getattr(
prepare_model, list(node.users)[0].target
)
self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
self.assertTrue(isinstance(cat_act_obs1, ObserverBase))
self.assertTrue(isinstance(cat_out_obs, ObserverBase))
self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
self.assertTrue(cat_act_obs0 is cat_act_obs1)
self.assertTrue(cat_act_obs0 is cat_out_obs)
self.assertTrue(cat_out_obs is input_obs_of_maxpool)
self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
@skipIfNoX86
def test_cat_recipe_same_inputs(self):
r"""
Test pattern: conv -> cat([input0, input0])
        Since both cat inputs are the same tensor, they should also share the same observer.
"""
m = TestHelperModules.Conv2dCatSameInputs().eval()
x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
example_inputs = (x,)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.cat.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
_, prepare_model, _ = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
        # Check that cat shares an observer between its inputs and output
for node in prepare_model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.cat.default
):
cat_act_obs0 = getattr(
prepare_model, node.args[0][0].target
)
cat_act_obs1 = getattr(
prepare_model, node.args[0][1].target
)
cat_out_obs = getattr(
prepare_model, list(node.users)[0].target
)
self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
self.assertTrue(isinstance(cat_act_obs1, ObserverBase))
self.assertTrue(isinstance(cat_out_obs, ObserverBase))
self.assertTrue(cat_act_obs0 is cat_act_obs1)
self.assertTrue(cat_act_obs0 is cat_out_obs)
@skipIfNoX86
def test_cat_recipe_single_input(self):
r"""
Test pattern: conv -> cat([input0,])
        Since cat has a single input node, its input and output should still share the same observer.
"""
m = TestHelperModules.Conv2dCatSingleInput().eval()
x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
example_inputs = (x,)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.cat.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
_, prepare_model, _ = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
        # Check that cat shares an observer between its input and output
for node in prepare_model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.cat.default
):
cat_act_obs0 = getattr(
prepare_model, node.args[0][0].target
)
cat_out_obs = getattr(
prepare_model, list(node.users)[0].target
)
self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
self.assertTrue(isinstance(cat_out_obs, ObserverBase))
self.assertTrue(cat_act_obs0 is cat_out_obs)
@skipIfNoX86
def test_avg_pool2d_recipe(self):
r"""
Test pattern: conv -> AvgPool2d
        Since AvgPool2d is an int8_in_int8_out op, its input and output should share the same observer.
"""
m = TestHelperModules.Conv2dAvgPool2d().eval()
x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
example_inputs = (x,)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.avg_pool2d.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
_, prepare_model, _ = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
for node in prepare_model.graph.nodes:
if (
node.op == "call_function"
and node.target is torch.ops.aten.avg_pool2d.default
):
avgpool_node = node
input_obs_of_avgpool = getattr(
prepare_model, avgpool_node.args[0].target
)
output_obs_of_avgpool = getattr(
prepare_model, list(avgpool_node.users)[0].target
)
elif (
node.op == "call_function"
and node.target is torch.ops.aten.convolution.default
):
conv_node = node
output_obs_of_conv = getattr(prepare_model, list(conv_node.users)[0].target)
self.assertTrue(isinstance(input_obs_of_avgpool, ObserverBase))
self.assertTrue(isinstance(output_obs_of_avgpool, ObserverBase))
self.assertTrue(isinstance(output_obs_of_conv, ObserverBase))
self.assertTrue(input_obs_of_avgpool is output_obs_of_avgpool)
self.assertTrue(input_obs_of_avgpool is output_obs_of_conv)

x86_inductor_quantizer.py

@@ -3,7 +3,7 @@ import functools
import itertools
import operator
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
import torch
import torch.nn.functional as F
@@ -54,6 +56,8 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
# Ops support int8 data type and excludes ops like conv, linear.
quantizable_ops_pt2e: Set = {
torch.ops.aten.max_pool2d_with_indices.default,
torch.ops.aten.cat.default,
torch.ops.aten.avg_pool2d.default,
}
@@ -62,16 +64,21 @@ quantizable_ops_pt2e: Set = {
# 2. Ops don't support int8 in and fp32 out.
int8_in_int8_out_ops_pt2e: Set = {
torch.ops.aten.max_pool2d_with_indices.default,
torch.ops.aten.cat.default,
torch.ops.aten.avg_pool2d.default,
}
QUANT_ANNOTATION_KEY = "quantization_annotation"
def _is_node_annotated(_node):
"""
return True if the node is annotated, otherwise return False
"""
return (
"quantization_annotation" in _node.meta
and _node.meta["quantization_annotation"]._annotated
QUANT_ANNOTATION_KEY in _node.meta
and _node.meta[QUANT_ANNOTATION_KEY]._annotated
)
@@ -101,7 +108,7 @@ def _is_quantized_op_pt2e(node: torch.fx.Node):
if not _is_any_annotated([node]):
# The node has not been annotated, directly return False
return False
quantization_annotation = node.meta.get("quantization_annotation", None)
quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
return quantization_annotation._is_output_of_quantized_pattern
@@ -258,9 +265,7 @@ class X86InductorQuantizer(Quantizer):
if isinstance(bias_node, Node):
input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
if annotate_output:
conv_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config),
@@ -268,9 +273,7 @@ class X86InductorQuantizer(Quantizer):
_is_output_of_quantized_pattern=True,
)
else:
conv_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
@@ -398,15 +401,11 @@ class X86InductorQuantizer(Quantizer):
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
quantization_config
)
binary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
unary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
@@ -447,9 +446,7 @@ class X86InductorQuantizer(Quantizer):
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
quantization_config
)
binary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map,
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@@ -476,9 +473,7 @@ class X86InductorQuantizer(Quantizer):
if _is_annotated([unary_node, conv_node]):
continue
self._annotate_conv_node_helper(conv_node, False, quantization_config)
unary_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
@@ -522,15 +517,38 @@ class X86InductorQuantizer(Quantizer):
assert isinstance(input_node, Node)
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
maxpool_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
getitem_node.meta[
"quantization_annotation"
] = _X86InductorQuantizationAnnotation(
getitem_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def _annotate_cat(
self, node: Node, quantization_config: QuantizationConfig
) -> None:
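        # All inputs of cat reuse the first input's observer: the first input gets a
        # fresh input-activation qspec, and every other distinct input shares its
        # quantization parameters via SharedQuantizationSpec on the first input edge.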
cat_node = node
input_nodes = cat_node.args[0]
assert isinstance(input_nodes, Sequence)
first_input_node = input_nodes[0]
input_qspec_map = {}
assert isinstance(first_input_node, Node)
assert isinstance(cat_node, Node)
input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, cat_node)
)
for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
                # Handle the case where cat takes the same node more than once, e.g. torch.cat([input0, input0], 1)
assert isinstance(input_node, Node)
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
_is_output_of_quantized_pattern=True,
)
@@ -559,9 +577,44 @@ class X86InductorQuantizer(Quantizer):
return
self._annotate_maxpool2d(node, quantization_config)
return
elif node.target is torch.ops.aten.cat.default:
input_nodes_to_check = node.all_input_nodes
if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
return
self._annotate_cat(node, quantization_config)
else:
# TODO <leslie>: Enable recipes for more single quantizable op such as view and relu.
pass
input_node = node.all_input_nodes[0]
if not is_all_inputs_connected_to_quantized_op(
[
input_node,
]
):
return
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
_is_output_of_quantized_pattern=True,
)
return
def _annotate_output_share_observer_as_input(
self, input_node: Node, source_node: Node
):
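        # If source_node is the output of a quantized pattern, set its output_qspec to
        # share the observer of the (input_node, source_node) edge, so the op's input
        # and output end up with the same quantization parameters.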
source_node_quantization_annotation = (
source_node.meta[QUANT_ANNOTATION_KEY]
if QUANT_ANNOTATION_KEY in source_node.meta
else None
)
if (
source_node_quantization_annotation
and source_node_quantization_annotation._is_output_of_quantized_pattern
):
edge_or_node = (input_node, source_node)
source_node_quantization_annotation.output_qspec = SharedQuantizationSpec(
edge_or_node
)
return
def _annotate_output_for_int8_in_int8_out_pattern(
@@ -572,6 +625,7 @@ class X86InductorQuantizer(Quantizer):
Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
"""
edge_or_node: Tuple[Node, Node]
if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])):
if node.target == torch.ops.aten.max_pool2d_with_indices.default:
maxpool_node = node
@@ -583,8 +637,8 @@ class X86InductorQuantizer(Quantizer):
return
# Get the quantization_annotation from getitem_node
getitem_quantization_annotation = (
getitem_node.meta["quantization_annotation"]
if "quantization_annotation" in getitem_node.meta
getitem_node.meta[QUANT_ANNOTATION_KEY]
if QUANT_ANNOTATION_KEY in getitem_node.meta
else None
)
if (
@@ -595,13 +649,13 @@ class X86InductorQuantizer(Quantizer):
input_act = maxpool_node.args[0]
assert isinstance(input_act, Node)
assert isinstance(maxpool_node, Node)
edge_or_node: Tuple[Node, Node] = (input_act, maxpool_node)
getitem_node.meta[
"quantization_annotation"
].output_qspec = SharedQuantizationSpec(edge_or_node)
edge_or_node = (input_act, maxpool_node)
getitem_quantization_annotation.output_qspec = (
SharedQuantizationSpec(edge_or_node)
)
else:
# TODO <leslie>: Enable recipes for more int8_in_int8_out_ops
pass
input_node = node.all_input_nodes[0]
self._annotate_output_share_observer_as_input(input_node, node)
return
def validate(self, model: torch.fx.GraphModule) -> None: