diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index d526ea3a750..5b86ebe845b 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1755,6 +1755,40 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         with self.assertRaises(NotImplementedError):
             m.train()
 
+    def test_reentrant(self):
+        """Test that the quantization APIs can be called again on a model that already holds a quantized submodule."""
+        m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
+        example_inputs = (torch.randn(3, 3, 10, 10),)
+        # Pass 1: capture, QAT-prepare and convert only the conv_bn_relu submodule (per-channel config).
+        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=True, is_qat=True))
+        m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
+        m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
+        m(*example_inputs)  # run the full model once while the prepared submodule is in place
+        m.conv_bn_relu = convert_pt2e(m.conv_bn_relu, fold_quantize=True)
+        # Pass 2: capture, prepare and convert the whole model (per-tensor config), converted submodule included.
+        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=False))
+        m = capture_pre_autograd_graph(m, example_inputs)
+        m = prepare_pt2e(m, quantizer)  # NOTE(review): no forward pass before convert; presumably fine as only graph structure is checked
+        m = convert_pt2e(m, fold_quantize=True)
+        # Expectations handed to checkGraphModuleNodes below: per-op occurrence counts and a list of expected nodes.
+        node_occurrence = {
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,
+            # one more dequantize than quantize: the extra one is for a weight (presumably the linear weight - TODO confirm)
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 5,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,  # presumably the conv weight from pass 1 - TODO confirm
+        }
+        node_list = [
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.aten.conv2d.default),
+            ns.call_function(torch.ops.aten.relu.default),
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.aten.linear.default),
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default),
+        ]
+        self.checkGraphModuleNodes(
+            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
+        )
 
 @skipIfNoQNNPACK
 class TestQuantizePT2EOps(QuantizationTestCase):
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index c28ffcc64ab..5894c30d7f1 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -2591,3 +2591,16 @@ class TestHelperModules:
             x = x * y
             x *= y
             return x
+
+    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):  # ConvWithBNRelu(relu=True) -> permute -> Linear(3, 8, bias=False)
+        def __init__(self):
+            super().__init__()
+            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
+            self.linear = torch.nn.Linear(3, 8, bias=False)
+            self.relu = torch.nn.ReLU()  # NOTE(review): never called in forward(), despite "LinearReLU" in the class name
+
+        def forward(self, x):
+            x = self.conv_bn_relu(x)
+            permute_out = torch.permute(x, (0, 2, 3, 1))  # move dim 1 to the last position, which Linear (in_features=3) consumes
+            linear_out = self.linear(permute_out)
+            return linear_out