[Quant][PT2E] Add cat and avg_pool2d recipe into x86InductorQuantizer (#106836)
**Summary**
Add `cat` and `avg_pool2d` quantization recipes to `X86InductorQuantizer`, annotating each op so that its inputs and outputs share the same observer.

**Test Plan**
```
clear && python -m pytest test_x86inductor_quantizer.py -k test_cat_recipe
clear && python -m pytest test_x86inductor_quantizer.py -k test_cat_recipe_same_inputs
clear && python -m pytest test_x86inductor_quantizer.py -k test_cat_recipe_single_input
clear && python -m pytest test_x86inductor_quantizer.py -k test_avg_pool2d_recipe
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106836
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
parent 15d4dedbbf
commit 1147a28b0b
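For orientation, a minimal sketch (not part of the diff) of how the new recipes are exercised end to end. The toy module, tensor shapes, and the `capture_pre_autograd_graph` entry point reflect the PT2E flow of this release and are illustrative only; later releases expose a different capture API.

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class M(torch.nn.Module):
    # Toy model hitting the new recipes: conv -> cat -> avg_pool2d.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        return self.avgpool(torch.cat((self.conv(x), self.conv2(x)), dim=1))


example_inputs = (torch.randn(1, 3, 16, 16),)
m = capture_pre_autograd_graph(M().eval(), example_inputs)

# With the recipes below, cat and avg_pool2d are annotated so that their
# input and output activations reuse one observer.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(m, quantizer)
prepared(*example_inputs)  # calibration
converted = convert_pt2e(prepared)
```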
test_x86inductor_quantizer.py

@@ -139,7 +139,6 @@ class TestHelperModules:
            x = self.pool(x)
            return torch.pow(x, 2)

    class SerialsConv2dAddReLUModule(torch.nn.Module):
        """ Serials of 2 Conv2d -> Add -> ReLU Pattern.
        """
@@ -166,6 +165,55 @@ class TestHelperModules:
            res2 = self.relu2(self.conv4(res1) + res1)
            return res2

    class Conv2dCatMaxpool2d(torch.nn.Module):
        def __init__(self,):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
            self.conv2 = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(3, stride=2, padding=1)
            self.conv3 = torch.nn.Conv2d(32, 32, 7, bias=True, stride=2, padding=3, dilation=1)

        def forward(self, x):
            temp1 = self.relu(self.conv(x))
            temp2 = self.conv2(x + 1)
            temp3 = torch.cat((temp1, temp2), 1)
            temp4 = self.maxpool(temp3)
            temp5 = self.conv3(temp4)
            return temp5

    class Conv2dAvgPool2d(torch.nn.Module):
        def __init__(self,):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
            self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1)

        def forward(self, x):
            temp1 = self.avgpool(self.conv(x))
            return temp1

    class Conv2dCatSameInputs(torch.nn.Module):
        def __init__(self,):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            temp1 = self.relu(self.conv(x))
            temp3 = torch.cat((temp1, temp1), 1)
            return temp3

    class Conv2dCatSingleInput(torch.nn.Module):
        def __init__(self,):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 7, bias=True, stride=2, padding=3, dilation=1)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            temp1 = self.relu(self.conv(x))
            temp3 = torch.cat((temp1,), 1)
            return temp3


class X86InductorQuantTestCase(QuantizationTestCase):
    def _test_quantizer(
        self,
@@ -479,3 +527,244 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
        self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
        self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
        self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv)

    @skipIfNoX86
    def test_cat_recipe(self):
        r"""
        Test pattern: conv -> cat -> maxpool2d
        Since cat and maxpool2d are int8_in_int8_out ops, their inputs and outputs should share the same observer.
        """
        m = TestHelperModules.Conv2dCatMaxpool2d().eval()
        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        example_inputs = (x,)
        node_occurrence = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 7,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.convolution.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.cat.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.max_pool2d_with_indices.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]
        _, prepare_model, _ = self._test_quantizer(
            m,
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )
        # Check that cat and maxpool2d share one observer at input and output
        for node in prepare_model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.cat.default
            ):
                cat_act_obs0 = getattr(
                    prepare_model, node.all_input_nodes[0].target
                )
                cat_act_obs1 = getattr(
                    prepare_model, node.all_input_nodes[1].target
                )
                cat_out_obs = getattr(
                    prepare_model, list(node.users)[0].target
                )
            elif (
                node.op == "call_function"
                and node.target is torch.ops.aten.max_pool2d_with_indices.default
            ):
                maxpool_node = node
                input_obs_of_maxpool = getattr(
                    prepare_model, maxpool_node.args[0].target
                )
            elif node.op == "call_function" and node.target is operator.getitem:
                output_obs_of_maxpool = getattr(
                    prepare_model, list(node.users)[0].target
                )
        self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
        self.assertTrue(isinstance(cat_act_obs1, ObserverBase))
        self.assertTrue(isinstance(cat_out_obs, ObserverBase))
        self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
        self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
        self.assertTrue(cat_act_obs0 is cat_act_obs1)
        self.assertTrue(cat_act_obs0 is cat_out_obs)
        self.assertTrue(cat_out_obs is input_obs_of_maxpool)
        self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)

    @skipIfNoX86
    def test_cat_recipe_same_inputs(self):
        r"""
        Test pattern: conv -> cat([input0, input0])
        Since cat has two input nodes that are the same tensor, they should also share the same observer.
        """
        m = TestHelperModules.Conv2dCatSameInputs().eval()
        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        example_inputs = (x,)
        node_occurrence = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        node_list = [
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.convolution.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.cat.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]
        _, prepare_model, _ = self._test_quantizer(
            m,
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )
        # Check that cat shares one observer at input and output
        for node in prepare_model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.cat.default
            ):
                cat_act_obs0 = getattr(
                    prepare_model, node.args[0][0].target
                )
                cat_act_obs1 = getattr(
                    prepare_model, node.args[0][1].target
                )
                cat_out_obs = getattr(
                    prepare_model, list(node.users)[0].target
                )
        self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
        self.assertTrue(isinstance(cat_act_obs1, ObserverBase))
        self.assertTrue(isinstance(cat_out_obs, ObserverBase))
        self.assertTrue(cat_act_obs0 is cat_act_obs1)
        self.assertTrue(cat_act_obs0 is cat_out_obs)

    @skipIfNoX86
    def test_cat_recipe_single_input(self):
        r"""
        Test pattern: conv -> cat([input0,])
        Since cat has only one input node, its input and output should also share the same observer.
        """
        m = TestHelperModules.Conv2dCatSingleInput().eval()
        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        example_inputs = (x,)
        node_occurrence = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        node_list = [
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.convolution.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.cat.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]
        _, prepare_model, _ = self._test_quantizer(
            m,
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )
        # Check that cat shares one observer at input and output
        for node in prepare_model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.cat.default
            ):
                cat_act_obs0 = getattr(
                    prepare_model, node.args[0][0].target
                )
                cat_out_obs = getattr(
                    prepare_model, list(node.users)[0].target
                )
        self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
        self.assertTrue(isinstance(cat_out_obs, ObserverBase))
        self.assertTrue(cat_act_obs0 is cat_out_obs)

    @skipIfNoX86
    def test_avg_pool2d_recipe(self):
        r"""
        Test pattern: conv -> AvgPool2d
        Since AvgPool2d is an int8_in_int8_out op, its inputs and outputs should share the same observer.
        """
        m = TestHelperModules.Conv2dAvgPool2d().eval()
        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        example_inputs = (x,)
        node_occurrence = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        node_list = [
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.convolution.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.avg_pool2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]
        _, prepare_model, _ = self._test_quantizer(
            m,
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )
        for node in prepare_model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.aten.avg_pool2d.default
            ):
                avgpool_node = node
                input_obs_of_avgpool = getattr(
                    prepare_model, avgpool_node.args[0].target
                )
                output_obs_of_avgpool = getattr(
                    prepare_model, list(avgpool_node.users)[0].target
                )
            elif (
                node.op == "call_function"
                and node.target is torch.ops.aten.convolution.default
            ):
                conv_node = node
                output_obs_of_conv = getattr(prepare_model, list(conv_node.users)[0].target)
        self.assertTrue(isinstance(input_obs_of_avgpool, ObserverBase))
        self.assertTrue(isinstance(output_obs_of_avgpool, ObserverBase))
        self.assertTrue(isinstance(output_obs_of_conv, ObserverBase))
        self.assertTrue(input_obs_of_avgpool is output_obs_of_avgpool)
        self.assertTrue(input_obs_of_avgpool is output_obs_of_conv)

x86_inductor_quantizer.py

@@ -3,7 +3,7 @@ import functools
import itertools
import operator
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

import torch
import torch.nn.functional as F
@@ -54,6 +54,8 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
# Ops support int8 data type and excludes ops like conv, linear.
quantizable_ops_pt2e: Set = {
    torch.ops.aten.max_pool2d_with_indices.default,
    torch.ops.aten.cat.default,
    torch.ops.aten.avg_pool2d.default,
}

@@ -62,16 +64,21 @@ quantizable_ops_pt2e: Set = {
# 2. Ops don't support int8 in and fp32 out.
int8_in_int8_out_ops_pt2e: Set = {
    torch.ops.aten.max_pool2d_with_indices.default,
    torch.ops.aten.cat.default,
    torch.ops.aten.avg_pool2d.default,
}

QUANT_ANNOTATION_KEY = "quantization_annotation"


def _is_node_annotated(_node):
    """
    return True if the node is annotated, otherwise return False
    """
    return (
        "quantization_annotation" in _node.meta
        and _node.meta["quantization_annotation"]._annotated
        QUANT_ANNOTATION_KEY in _node.meta
        and _node.meta[QUANT_ANNOTATION_KEY]._annotated
    )

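The sets above drive the new recipes: an op that both consumes and produces int8 gets its output tied to the quantization parameters observed on its input edge via `SharedQuantizationSpec`. A minimal sketch of that idea follows; it uses the public `QuantizationAnnotation` rather than this module's internal `_X86InductorQuantizationAnnotation`, and the helper name `annotate_int8_in_int8_out` is hypothetical, not an API of this file.

```python
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)


def annotate_int8_in_int8_out(node, input_act_qspec):
    # node is an FX node such as aten.avg_pool2d.default; its single
    # activation input is args[0]. Observing the (input_node, node) edge
    # and reusing it as the output qspec keeps the op's input and output
    # on one shared observer, which is what the recipes below verify.
    input_node = node.args[0]
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_node: input_act_qspec},
        output_qspec=SharedQuantizationSpec((input_node, node)),
        _annotated=True,
    )
```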
@@ -101,7 +108,7 @@ def _is_quantized_op_pt2e(node: torch.fx.Node):
    if not _is_any_annotated([node]):
        # The node has not been annotated, directly return False
        return False
    quantization_annotation = node.meta.get("quantization_annotation", None)
    quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
    assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
    return quantization_annotation._is_output_of_quantized_pattern
@@ -258,9 +265,7 @@ class X86InductorQuantizer(Quantizer):
        if isinstance(bias_node, Node):
            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
        if annotate_output:
            conv_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),
@@ -268,9 +273,7 @@ class X86InductorQuantizer(Quantizer):
                _is_output_of_quantized_pattern=True,
            )
        else:
            conv_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                _annotated=True,
            )
@@ -398,15 +401,11 @@ class X86InductorQuantizer(Quantizer):
            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
                quantization_config
            )
            binary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                input_qspec_map=binary_node_input_qspec_map,
                _annotated=True,
            )
            unary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                _annotated=True,
@@ -447,9 +446,7 @@ class X86InductorQuantizer(Quantizer):
            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
                quantization_config
            )
            binary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                input_qspec_map=binary_node_input_qspec_map,
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
@@ -476,9 +473,7 @@ class X86InductorQuantizer(Quantizer):
            if _is_annotated([unary_node, conv_node]):
                continue
            self._annotate_conv_node_helper(conv_node, False, quantization_config)
            unary_node.meta[
                "quantization_annotation"
            ] = _X86InductorQuantizationAnnotation(
            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
                _annotated=True,
@@ -522,15 +517,38 @@ class X86InductorQuantizer(Quantizer):
        assert isinstance(input_node, Node)
        input_qspec_map = {}
        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
        maxpool_node.meta[
            "quantization_annotation"
        ] = _X86InductorQuantizationAnnotation(
        maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        getitem_node.meta[
            "quantization_annotation"
        ] = _X86InductorQuantizationAnnotation(
        getitem_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
            _annotated=True,
            _is_output_of_quantized_pattern=True,
        )

    def _annotate_cat(
        self, node: Node, quantization_config: QuantizationConfig
    ) -> None:
        cat_node = node
        input_nodes = cat_node.args[0]
        assert isinstance(input_nodes, Sequence)
        first_input_node = input_nodes[0]
        input_qspec_map = {}
        assert isinstance(first_input_node, Node)
        assert isinstance(cat_node, Node)
        input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, cat_node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                # Handle the case of cat on the same node: torch.cat([input0, input0], 1)
                assert isinstance(input_node, Node)
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
            _is_output_of_quantized_pattern=True,
        )
@@ -559,9 +577,44 @@ class X86InductorQuantizer(Quantizer):
                    return
                self._annotate_maxpool2d(node, quantization_config)
                return
            elif node.target is torch.ops.aten.cat.default:
                input_nodes_to_check = node.all_input_nodes
                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
                    return
                self._annotate_cat(node, quantization_config)
            else:
                # TODO <leslie>: Enable recipes for more single quantizable op such as view and relu.
                pass
                input_node = node.all_input_nodes[0]
                if not is_all_inputs_connected_to_quantized_op(
                    [
                        input_node,
                    ]
                ):
                    return
                input_qspec_map = {}
                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
                    input_qspec_map=input_qspec_map,
                    _annotated=True,
                    _is_output_of_quantized_pattern=True,
                )
            return

    def _annotate_output_share_observer_as_input(
        self, input_node: Node, source_node: Node
    ):
        source_node_quantization_annotation = (
            source_node.meta[QUANT_ANNOTATION_KEY]
            if QUANT_ANNOTATION_KEY in source_node.meta
            else None
        )
        if (
            source_node_quantization_annotation
            and source_node_quantization_annotation._is_output_of_quantized_pattern
        ):
            edge_or_node = (input_node, source_node)
            source_node_quantization_annotation.output_qspec = SharedQuantizationSpec(
                edge_or_node
            )
        return

    def _annotate_output_for_int8_in_int8_out_pattern(
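The TODO in the `else` branch above points at how further single-input ops would be enabled: because that branch is generic, the registration step is mostly adding the op to both op sets. A hedged sketch of such an extension, using `torch.ops.aten.relu.default` purely as a hypothetical candidate; this is not part of the PR, and the Inductor lowering side would also need to support the extra op before it could be enabled for real.

```python
# Hypothetical extension, not part of this commit: registering another
# single-input op in both sets lets the generic else-branch annotate it and
# lets the int8-in/int8-out recipe share its input observer with its output.
quantizable_ops_pt2e.add(torch.ops.aten.relu.default)
int8_in_int8_out_ops_pt2e.add(torch.ops.aten.relu.default)
```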
@@ -572,6 +625,7 @@ class X86InductorQuantizer(Quantizer):
        Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
        90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
        """
        edge_or_node: Tuple[Node, Node]
        if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])):
            if node.target == torch.ops.aten.max_pool2d_with_indices.default:
                maxpool_node = node
@@ -583,8 +637,8 @@ class X86InductorQuantizer(Quantizer):
                    return
                # Get the quantization_annotation from getitem_node
                getitem_quantization_annotation = (
                    getitem_node.meta["quantization_annotation"]
                    if "quantization_annotation" in getitem_node.meta
                    getitem_node.meta[QUANT_ANNOTATION_KEY]
                    if QUANT_ANNOTATION_KEY in getitem_node.meta
                    else None
                )
                if (
@@ -595,13 +649,13 @@ class X86InductorQuantizer(Quantizer):
                    input_act = maxpool_node.args[0]
                    assert isinstance(input_act, Node)
                    assert isinstance(maxpool_node, Node)
                    edge_or_node: Tuple[Node, Node] = (input_act, maxpool_node)
                    getitem_node.meta[
                        "quantization_annotation"
                    ].output_qspec = SharedQuantizationSpec(edge_or_node)
                    edge_or_node = (input_act, maxpool_node)
                    getitem_quantization_annotation.output_qspec = (
                        SharedQuantizationSpec(edge_or_node)
                    )
            else:
                # TODO <leslie>: Enable recipes for more int8_in_int8_out_ops
                pass
                input_node = node.all_input_nodes[0]
                self._annotate_output_share_observer_as_input(input_node, node)
        return

    def validate(self, model: torch.fx.GraphModule) -> None: