mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][pt2e] Add quant API re-entrant test (#110125)
Summary: Add a test to make sure we can call the quantization APIs multiple times on the same model. Test Plan: python test/test_quantization.py TestQuantizePT2E.test_reentrant Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/110125 Approved by: https://github.com/kimishpatel ghstack dependencies: #110097
This commit is contained in:
parent
bbb95878e9
commit
3de42995e4
|
|
@ -1755,6 +1755,40 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
|
def test_reentrant(self):
    """Test we can safely call quantization apis multiple times"""
    m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
    example_inputs = (torch.randn(3, 3, 10, 10),)

    # Round 1: QAT-style per-channel quantization applied only to the
    # conv-bn-relu submodule.
    qat_quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
    )
    m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
    m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, qat_quantizer)
    m(*example_inputs)
    m.conv_bn_relu = convert_pt2e(m.conv_bn_relu, fold_quantize=True)

    # Round 2: re-enter the capture/prepare/convert pipeline, this time on
    # the whole model with per-tensor PTQ settings.
    ptq_quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False)
    )
    m = capture_pre_autograd_graph(m, example_inputs)
    m = prepare_pt2e(m, ptq_quantizer)
    m = convert_pt2e(m, fold_quantize=True)

    # Hoist the decomposed quantize/dequantize ops so the expectations below
    # stay readable.
    quantize_per_tensor = torch.ops.quantized_decomposed.quantize_per_tensor.default
    dequantize_per_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.default
    dequantize_per_channel = torch.ops.quantized_decomposed.dequantize_per_channel.default

    expected_occurrence = {
        ns.call_function(quantize_per_tensor): 4,
        # one for weight
        ns.call_function(dequantize_per_tensor): 5,
        ns.call_function(dequantize_per_channel): 1,
    }
    expected_order = [
        ns.call_function(dequantize_per_tensor),
        ns.call_function(torch.ops.aten.conv2d.default),
        ns.call_function(torch.ops.aten.relu.default),
        ns.call_function(quantize_per_tensor),
        ns.call_function(dequantize_per_tensor),
        ns.call_function(torch.ops.aten.linear.default),
        ns.call_function(quantize_per_tensor),
    ]
    self.checkGraphModuleNodes(
        m,
        expected_node_occurrence=expected_occurrence,
        expected_node_list=expected_order,
    )
@skipIfNoQNNPACK
|
@skipIfNoQNNPACK
|
||||||
class TestQuantizePT2EOps(QuantizationTestCase):
|
class TestQuantizePT2EOps(QuantizationTestCase):
|
||||||
|
|
|
||||||
|
|
@ -2591,3 +2591,16 @@ class TestHelperModules:
|
||||||
x = x * y
|
x = x * y
|
||||||
x *= y
|
x *= y
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
    """A conv/BN/ReLU helper block followed by a bias-free linear layer.

    The activation is permuted from NCHW to NHWC before the linear layer so
    that the projection is applied over the channel dimension.
    """

    def __init__(self):
        super().__init__()
        self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
        self.linear = torch.nn.Linear(3, 8, bias=False)
        # NOTE(review): self.relu is declared but not used in forward —
        # presumably kept deliberately so the module has an extra child;
        # preserved as-is.
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        features = self.conv_bn_relu(x)
        # NCHW -> NHWC so the last dimension (channels) feeds the linear layer.
        nhwc = features.permute(0, 2, 3, 1)
        return self.linear(nhwc)
|
||||||
Loading…
Reference in New Issue
Block a user