[quant][pt2e] Add reference representation for quantized adaptive_avg_pool2d (#105709)
Summary: Implement the reference representation for quantized adaptive_avg_pool2d, one of the quantized ops whose reference representations were decided on in https://docs.google.com/document/d/17h-OEtD4o_hoVuPqUFsdm5uo7psiNMY8ThN03F9ZZwg/edit#heading=h.ov8z39149wy8

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_representation_adaptive_avg_pool2d
(Note: the test does not fully exercise the representation yet because of a problem with dynamo export.)

Reviewers:
Subscribers:
Tasks:
Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105709
Approved by: https://github.com/andrewor14
ghstack dependencies: #105708
This commit is contained in:
parent 3e6da46aff
commit 2156f0434c
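For context, here is a minimal, self-contained sketch (not part of the commit) of the idea the rewrite expresses: the quantize/dequantize ("q/dq") pattern dequantizes to fp32, runs the pool, and re-quantizes, while the reference representation stays in integer arithmetic and folds the two scales into a single multiplier. The sketch assumes symmetric int8 quantization (zero points of 0) and emulates the decomposed quantize/dequantize ops with plain tensor arithmetic so it runs in eager mode; the function names are illustrative only.

import torch
import torch.nn.functional as F

def qdq_avg_pool(x_i8, x_scale, out_scale):
    # q/dq pattern: dequantize to fp32, run the fp op, re-quantize.
    x_fp32 = x_i8.to(torch.float32) * x_scale
    out_fp32 = F.adaptive_avg_pool2d(x_fp32, (1, 1))
    return torch.clamp(torch.round(out_fp32 / out_scale), -128, 127).to(torch.int8)

def reference_avg_pool(x_i8, x_scale, out_scale):
    # Reference representation: accumulate in integer, rescale once by x_scale / out_scale.
    acc = x_i8.to(torch.int32).sum(dim=(-2, -1), keepdim=True)
    count = x_i8.shape[-2] * x_i8.shape[-1]
    out_fp32 = acc * (x_scale / (out_scale * count))
    return torch.clamp(torch.round(out_fp32), -128, 127).to(torch.int8)

x_i8 = torch.randint(-128, 127, (1, 3, 6, 6), dtype=torch.int8)
a = qdq_avg_pool(x_i8, 0.02, 0.02)
b = reference_avg_pool(x_i8, 0.02, 0.02)
assert (a.to(torch.int32) - b.to(torch.int32)).abs().max() <= 1  # agree up to rounding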
@@ -157,6 +157,17 @@ class TestHelperModules:
             x = self.pool(x)
             return x
 
+    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.adaptive_avg_pool2d(x)
+            return x
+
     class ConvWithBNRelu(torch.nn.Module):
         def __init__(self, relu, bn=True, bias=True):
             super().__init__()
@@ -1920,6 +1931,22 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             non_ref_node_occurrence={}
         )
 
+    def test_representation_adaptive_avg_pool2d(self):
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(operator_config)
+        m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval()
+
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+
+        self._test_representation(
+            m_eager,
+            example_inputs,
+            quantizer,
+            ref_node_occurrence={},
+            non_ref_node_occurrence={}
+        )
+
     def test_representation_quantize_dequantize(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -146,6 +146,38 @@ def _reference_quantized_max_pool2d(
     out_i8 = out_fp32.to(torch.int8)
     return out_i8
 
+_QUANTIZED_ADAPTIVE_AVG_POOL2D_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_adaptive_avg_pool2d(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
+    output_size = (3, 3)
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.adaptive_avg_pool2d(x_fp32, output_size)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+def _reference_quantized_adaptive_avg_pool2d(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
+    output_size = (3, 3)
+    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
+    x_i32 = x_i8.to(torch.int32)
+    out_i32 = torch.ops.aten.adaptive_avg_pool2d(x_i32, output_size)
+    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
+    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
+    out_i8 = out_fp32.to(torch.int8)
+    return out_i8
+
 _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
     torch.randn(1, 3, 3, 3, dtype=torch.float),
     torch.randn(1, dtype=torch.float),
@@ -201,6 +233,10 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
     (_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, _qdq_quantized_add_relu, _reference_quantized_add_relu, _DONT_REPLACE_LITERAL),
     (_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, _qdq_quantized_add, _reference_quantized_add, _DONT_REPLACE_LITERAL),
     (_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, _qdq_quantized_max_pool2d, _reference_quantized_max_pool2d, _REPLACE_LITERAL),
+    (_QUANTIZED_ADAPTIVE_AVG_POOL2D_EXAMPLE_INPUTS,
+     _qdq_quantized_adaptive_avg_pool2d,
+     _reference_quantized_adaptive_avg_pool2d,
+     _REPLACE_LITERAL),
     (_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
      _quantize_per_tensor_int8,
      _reference_quantize_per_tensor_int8,
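The table above pairs each q/dq pattern with its reference replacement, plus the example inputs used to trace them. As an illustration only (toy pattern/replacement functions, not the commit's actual rewrite pass), the general mechanism of swapping a matched pattern for a replacement in a traced module can be sketched with torch.fx's subgraph rewriter:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace, subgraph_rewriter

class ToyModel(torch.nn.Module):
    def forward(self, x, scale):
        # Stand-in for a q/dq-style pattern: scale, pool, unscale.
        y = x * scale
        y = F.adaptive_avg_pool2d(y, (1, 1))
        return y / scale

def pattern(x, scale):
    y = x * scale
    y = F.adaptive_avg_pool2d(y, (1, 1))
    return y / scale

def replacement(x, scale):
    # Average pooling is linear, so the two rescales cancel and can be folded away.
    return F.adaptive_avg_pool2d(x, (1, 1))

gm = symbolic_trace(ToyModel())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)  # rewrites matched subgraphs in place
print(gm.code)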