[quant][pt2e] Add reference representation for quantized adaptive_avg_pool2d (#105709)

Summary:
Implements the reference representation for quantized ops that we decided on in https://docs.google.com/document/d/17h-OEtD4o_hoVuPqUFsdm5uo7psiNMY8ThN03F9ZZwg/edit#heading=h.ov8z39149wy8

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_representation_adaptive_avg_pool2d

Note: right now this test is not really testing anything, because there is a problem with dynamo export.

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105709
Approved by: https://github.com/andrewor14
ghstack dependencies: #105708
This commit is contained in:
Jerry Zhang 2023-08-03 18:24:12 -07:00 committed by PyTorch MergeBot
parent 3e6da46aff
commit 2156f0434c
2 changed files with 63 additions and 0 deletions

View File

@ -157,6 +157,17 @@ class TestHelperModules:
x = self.pool(x)
return x
class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
    """Test helper module: a ``Conv2d`` followed by ``AdaptiveAvgPool2d((1, 1))``."""

    def __init__(self):
        super().__init__()
        # 3 input channels, 3 output channels, 3x3 kernel.
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        # Pools each channel down to a single 1x1 value.
        self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, x):
        """Apply the convolution, then the adaptive average pooling."""
        return self.adaptive_avg_pool2d(self.conv(x))
class ConvWithBNRelu(torch.nn.Module):
def __init__(self, relu, bn=True, bias=True):
super().__init__()
@ -1920,6 +1931,22 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
non_ref_node_occurrence={}
)
def test_representation_adaptive_avg_pool2d(self):
    """Run the representation check on a conv + adaptive_avg_pool2d model.

    Both node-occurrence dicts are passed empty, so no specific node counts
    are requested from ``_test_representation``.
    """
    # Global per-channel symmetric quantization config for every operator.
    xnnpack_quantizer = XNNPACKQuantizer()
    xnnpack_quantizer.set_global(
        get_symmetric_quantization_config(is_per_channel=True)
    )

    model = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval()
    inputs = (torch.randn(1, 3, 3, 3),)

    self._test_representation(
        model,
        inputs,
        xnnpack_quantizer,
        ref_node_occurrence={},
        non_ref_node_occurrence={},
    )
def test_representation_quantize_dequantize(self):
class M(torch.nn.Module):
def __init__(self):

View File

@ -146,6 +146,38 @@ def _reference_quantized_max_pool2d(
out_i8 = out_fp32.to(torch.int8)
return out_i8
_QUANTIZED_ADAPTIVE_AVG_POOL2D_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
def _qdq_quantized_adaptive_avg_pool2d(
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
output_size = (3, 3)
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
out_fp32 = torch.ops.aten.adaptive_avg_pool2d(x_fp32, output_size)
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
return out_i8
def _reference_quantized_adaptive_avg_pool2d(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
    """Reference representation of quantized adaptive_avg_pool2d.

    Clamps the int8 input to ``[x_quant_min, x_quant_max]``, casts it to
    int32, applies ``aten.adaptive_avg_pool2d`` with a fixed ``(3, 3)``
    output size, rescales by ``x_scale / out_scale``, adds
    ``out_zero_point``, clamps to ``[out_quant_min, out_quant_max]`` and
    casts the result to int8.

    NOTE(review): ``x_zero_point`` is accepted but never used in this body,
    whereas ``_qdq_quantized_adaptive_avg_pool2d`` passes it to
    ``dequantize_per_tensor`` - confirm that ignoring it here is intended.
    """
    output_size = (3, 3)
    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
    x_i32 = x_i8.to(torch.int32)
    # The pooling op is applied to the int32 tensor, not to a float tensor.
    out_i32 = torch.ops.aten.adaptive_avg_pool2d(x_i32, output_size)
    # Rescale from the input scale to the output scale, then shift by the
    # output zero point.
    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
    # NOTE(review): the cast to int8 is done without an explicit rounding
    # step - confirm this matches the behavior of quantize_per_tensor.
    out_i8 = out_fp32.to(torch.int8)
    return out_i8
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(1, dtype=torch.float),
@ -201,6 +233,10 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
(_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, _qdq_quantized_add_relu, _reference_quantized_add_relu, _DONT_REPLACE_LITERAL),
(_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, _qdq_quantized_add, _reference_quantized_add, _DONT_REPLACE_LITERAL),
(_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, _qdq_quantized_max_pool2d, _reference_quantized_max_pool2d, _REPLACE_LITERAL),
(_QUANTIZED_ADAPTIVE_AVG_POOL2D_EXAMPLE_INPUTS,
_qdq_quantized_adaptive_avg_pool2d,
_reference_quantized_adaptive_avg_pool2d,
_REPLACE_LITERAL),
(_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
_quantize_per_tensor_int8,
_reference_quantize_per_tensor_int8,