[quant][pt2e] Add quant API re-entrant test (#110125)

Summary:
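Add a test to make sure the PT2E quantization APIs can be called multiple times on the same model: a submodule is first prepared and converted with the QAT flow, then the whole model is captured, prepared and converted again.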

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_reentrant

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110125
Approved by: https://github.com/kimishpatel
ghstack dependencies: #110097
This commit is contained in:
Jerry Zhang 2023-09-27 12:56:21 -07:00 committed by PyTorch MergeBot
parent bbb95878e9
commit 3de42995e4
2 changed files with 47 additions and 0 deletions

View File

@ -1755,6 +1755,40 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
with self.assertRaises(NotImplementedError):
m.train()
def test_reentrant(self):
    """Check that the PT2E quantization APIs can be applied more than once.

    The ``conv_bn_relu`` submodule is captured, prepared with a per-channel
    QAT config, run once on the example inputs and converted. Afterwards the
    whole model, which now holds the already-converted submodule, is
    captured, prepared with a per-tensor config and converted again. The
    final graph is then checked for the expected node counts and ordering.
    """
    m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
    example_inputs = (torch.randn(3, 3, 10, 10),)

    # First pass: quantize only the conv-bn-relu submodule (QAT, per-channel).
    qat_config = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
    qat_quantizer = XNNPACKQuantizer().set_global(qat_config)
    captured_submodule = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
    m.conv_bn_relu = prepare_qat_pt2e(captured_submodule, qat_quantizer)
    # Run the model once with the prepared submodule in place before converting.
    m(*example_inputs)
    m.conv_bn_relu = convert_pt2e(m.conv_bn_relu, fold_quantize=True)

    # Second pass: quantize the whole model (per-tensor).
    ptq_config = get_symmetric_quantization_config(is_per_channel=False)
    ptq_quantizer = XNNPACKQuantizer().set_global(ptq_config)
    m = convert_pt2e(
        prepare_pt2e(capture_pre_autograd_graph(m, example_inputs), ptq_quantizer),
        fold_quantize=True,
    )

    decomposed_ops = torch.ops.quantized_decomposed
    quantize_node = ns.call_function(decomposed_ops.quantize_per_tensor.default)
    dequantize_node = ns.call_function(decomposed_ops.dequantize_per_tensor.default)
    dequantize_per_channel_node = ns.call_function(
        decomposed_ops.dequantize_per_channel.default
    )

    expected_occurrence = {
        quantize_node: 4,
        # one for weight
        dequantize_node: 5,
        dequantize_per_channel_node: 1,
    }
    expected_order = [
        dequantize_node,
        ns.call_function(torch.ops.aten.conv2d.default),
        ns.call_function(torch.ops.aten.relu.default),
        quantize_node,
        dequantize_node,
        ns.call_function(torch.ops.aten.linear.default),
        quantize_node,
    ]
    self.checkGraphModuleNodes(
        m,
        expected_node_occurrence=expected_occurrence,
        expected_node_list=expected_order,
    )
@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):

View File

@ -2591,3 +2591,16 @@ class TestHelperModules:
x = x * y
x *= y
return x
class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
    """Helper model: a ``ConvWithBNRelu`` block followed by a bias-free linear layer.

    The output of ``conv_bn_relu`` is permuted with ``(0, 2, 3, 1)`` so that
    dim 1 becomes the last dim before it is passed to ``linear``.

    Note: ``self.relu`` is registered as a submodule, but ``forward`` does
    not call it.
    """

    def __init__(self):
        super().__init__()
        self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
        self.linear = torch.nn.Linear(in_features=3, out_features=8, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply ``conv_bn_relu``, permute dims to ``(0, 2, 3, 1)``, then ``linear``."""
        features = self.conv_bn_relu(x)
        # The linear layer acts on the last dim, so move dim 1 there first.
        return self.linear(torch.permute(features, (0, 2, 3, 1)))