[quant][pt2e] store scale/zero_point as tensor attributes to support serialization (#105894)

Summary:
Currently, scale/zero_point for per tensor quantization are stored as burnt-in literals in the graph, which means these values cannot be serialized in state_dict. This PR changes them to buffers/Tensors so that they can be serialized.
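Roughly, the before/after for one per-tensor quantize/dequantize pair looks like the sketch below (a minimal illustration, not the convert pass itself; the buffer names _scale_0/_zero_point_0 follow the convention visible in the tests):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers the quantized_decomposed ops)

class QuantSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Before: scale/zero_point were literal args of quantize_per_tensor.default,
        # e.g. (x, 1.0 / 256.0, 0, 0, 255, torch.uint8), so they never reached state_dict.
        # After: they are registered as buffers, so torch.save/load_state_dict round-trips them.
        self.register_buffer("_scale_0", torch.tensor(1.0 / 256.0))
        self.register_buffer("_zero_point_0", torch.tensor(0, dtype=torch.int64))

    def forward(self, x):
        # The .tensor overload takes Tensor scale/zero_point instead of literals.
        q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
            x, self._scale_0, self._zero_point_0, 0, 255, torch.uint8)
        return torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
            q, self._scale_0, self._zero_point_0, 0, 255, torch.uint8)

assert "_scale_0" in QuantSketch().state_dict()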

Test Plan:
python test/test_quantization.py TestQuantizePT2E
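The check_save_load path added to the test helper exercises roughly the round-trip below (a condensed sketch of the helper in this diff; prepare_pt2e/convert_pt2e are assumed to be imported as in the test file):

import copy
import torch
import torch._dynamo as torchdynamo

def check_save_load_roundtrip(m, m_eager, quantizer, example_inputs,
                              example_inputs_for_load, fname):
    torch.save(m.state_dict(), fname)  # scale/zero_point are buffers now, so they serialize
    ref_result = m(*example_inputs)
    # Re-export and re-quantize a fresh copy of the eager model with different inputs...
    m_loaded, _guards = torchdynamo.export(
        copy.deepcopy(m_eager), *copy.deepcopy(example_inputs_for_load), aten_graph=True)
    m_loaded = prepare_pt2e(m_loaded, quantizer)
    m_loaded(*example_inputs_for_load)  # calibrate
    m_loaded = convert_pt2e(m_loaded)
    # ...then restore the saved qparams and check numerics against the original model.
    m_loaded.load_state_dict(torch.load(fname))
    assert torch.equal(ref_result, m_loaded(*example_inputs))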

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D47770963](https://our.internmc.facebook.com/intern/diff/D47770963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105894
Approved by: https://github.com/kimishpatel
Author: Jerry Zhang, 2023-07-25 13:14:02 -07:00; committed by PyTorch MergeBot
parent 841b4acf1e
commit 3ca71ed735
6 changed files with 210 additions and 181 deletions

View File

@ -6084,8 +6084,8 @@ class TestQuantizeFx(QuantizationTestCase):
m_ref = convert_to_reference_fx(m_ref)
m = _convert_to_reference_decomposed_fx(m)
expected_occurrence = {
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
}
self.checkGraphModuleNodes(
m,
@ -6145,8 +6145,8 @@ class TestQuantizeFx(QuantizationTestCase):
m = _convert_to_reference_decomposed_fx(m)
expected_occurrence = {
# for input and output activations
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
# for weight
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,

View File

@ -71,6 +71,9 @@ from torch.ao.quantization import (
default_dynamic_qconfig,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import (
TemporaryFileName,
)
# TODO: Move to common utils or use existing quant utils to fetch model instances
class TestHelperModules:
@ -221,8 +224,16 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
expected_node_list=None,
check_against_fx_quant=False,
fx_qconfig_mapping=None,
fx_node_occurrence=None,
export_with_dynamic_shape=False,
check_save_load=False,
example_inputs_for_load=None,
):
"""
example_inputs_for_load: the example inputs used when loading the model; it should be
different from example_inputs so that the save/load functionality is actually exercised.
If not provided and example_inputs is a single-element tuple, it is generated from
example_inputs.
"""
m_eager = model.eval()
# program capture
@ -267,14 +278,45 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
aten_graph=True,
tracing_mode="symbolic" if export_with_dynamic_shape else "real",
)
node_occurrence = {}
for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
if k in expected_node_occurrence:
node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
if fx_node_occurrence is None:
fx_node_occurrence = {}
for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
if k in expected_node_occurrence:
fx_node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
else:
fx_node_occurrence = {
ns.call_function(k): v for k, v in fx_node_occurrence.items()
}
self.checkGraphModuleNodes(m_fx, expected_node_occurrence=fx_node_occurrence)
fx_quant_output = m_fx(*example_inputs)
self.assertEqual(fx_quant_output, pt2_quant_output)
if check_save_load:
if example_inputs_for_load is None:
assert isinstance(example_inputs, tuple) and \
len(example_inputs) == 1 and \
isinstance(example_inputs[0], torch.Tensor)
example_inputs_for_load = (example_inputs[0] * 2,)
with TemporaryFileName() as fname, torchdynamo.config.patch(dynamic_shapes=export_with_dynamic_shape):
torch.save(m.state_dict(), fname)
ref_result = m(*example_inputs)
m_loaded = copy.deepcopy(m_eager)
m_loaded, guards = torchdynamo.export(
m_loaded,
*copy.deepcopy(example_inputs_for_load),
aten_graph=True,
tracing_mode="symbolic" if export_with_dynamic_shape else "real",
)
m_loaded = prepare_pt2e(m_loaded, quantizer)
m_loaded(*example_inputs_for_load)
m_loaded = convert_pt2e(m_loaded)
m_loaded.load_state_dict(torch.load(fname))
loaded_result = m_loaded(*example_inputs)
self.assertTrue(torch.equal(ref_result, loaded_result))
def _verify_symmetric_qnnpack_qat_numerics(
self,
model: torch.nn.Module,
@ -583,14 +625,14 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
node_occurrence = {
# two for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
]
self._test_quantizer(
TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
@ -598,6 +640,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
BackendAQuantizer(),
node_occurrence,
node_list,
check_save_load=True,
)
def test_wo_annotate_conv_output_quantizer(self):
@ -666,12 +709,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# Ensure the conv has no observer inserted at output
node_occurrence = {
# two for input of conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
}
node_list = [
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.convolution.default),
]
self.checkGraphModuleNodes(
@ -770,15 +813,15 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# two for input of conv
# one for input of maxpool
# one for output of maxpool
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 4,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 4,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 4,
}
node_list = [
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.max_pool2d_with_indices.default),
]
self.checkGraphModuleNodes(
@ -873,25 +916,25 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
# input, weight, bias, output for the conv
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
): 4,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
): 4,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
),
]
self.checkGraphModuleNodes(
@ -949,40 +992,26 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = convert_pt2e(m)
fixed_scale = 1.0 / 256.0
fixed_zero_point = 0
for n in m.graph.nodes:
if n.op == "call_function":
if (
n.target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
):
scale_0 = n.args[1]
zero_point_0 = n.args[2]
if (
n.target
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
):
scale_1 = n.args[1]
zero_point_1 = n.args[2]
self.assertEqual(scale_0, fixed_scale)
self.assertEqual(zero_point_0, fixed_zero_point)
self.assertEqual(scale_1, fixed_scale)
self.assertEqual(zero_point_1, fixed_zero_point)
self.assertEqual(m._scale_0, fixed_scale)
self.assertEqual(m._zero_point_0, fixed_zero_point)
self.assertEqual(m._scale_1, fixed_scale)
self.assertEqual(m._zero_point_1, fixed_zero_point)
node_occurrence = {
# one for input and one for output of the sigmoid
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
): 2,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
),
ns.call_function(torch.ops.aten.sigmoid.default),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
),
]
self.checkGraphModuleNodes(
@ -996,16 +1025,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
]
self._test_quantizer(
TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
@ -1013,6 +1042,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer,
node_occurrence,
node_list,
check_save_load=True,
)
def test_xnnpack_quantizer_linear(self):
@ -1027,8 +1057,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs_4d = (torch.randn(9, 10, 11, 8),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
@ -1043,6 +1073,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
check_save_load=True,
)
def test_xnnpack_quantizer_conv_linear_no_permute(self):
@ -1051,8 +1082,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer.set_global(operator_config)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
@ -1068,6 +1099,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
check_save_load=True,
)
def test_xnnpack_quantizer_conv_linear(self):
@ -1078,8 +1110,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
# Test with 2d inputs
example_inputs = (torch.randn(2, 3, 4, 4),)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
@ -1093,6 +1125,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
check_save_load=True,
)
def test_xnnpack_quantizer_linear_with_dynamic_shape(self):
@ -1105,8 +1138,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs_3d = (torch.randn(9, 10, 8),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
@ -1121,6 +1154,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
True,
qconfig_mapping,
export_with_dynamic_shape=True,
check_save_load=True,
)
def test_xnnpack_quantizer_obs_sharing_ops(self):
@ -1131,28 +1165,28 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.mean.dim,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.hardtanh.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.mean.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]
self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list, check_save_load=True)
def test_propagate_annotation(self):
quantizer = XNNPACKQuantizer()
@ -1180,10 +1214,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
): 5,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
): 5,
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
@ -1235,6 +1269,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
check_save_load=True,
)
def test_xnnpack_quantizer_dynamic_linear_with_conv(self):
@ -1247,6 +1282,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
}
# the weight q/dq is traced from the reference quantized linear module as a constant,
# so we see quantize_per_tensor.default instead of quantize_per_tensor.tensor for the weight
fx_node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
@ -1275,6 +1316,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
fx_node_occurrence=fx_node_occurrence,
check_save_load=True,
)
def test_composable_quantizer_linear_conv(self):
@ -1295,10 +1338,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m_eager = TestHelperModules.ConvLinearWPermute().eval()
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
# torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
# torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
@ -1331,6 +1374,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
False,
qconfig_mapping,
check_save_load=True,
)
def test_composable_quantizer_throw(self):
@ -1426,6 +1470,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_list,
True,
qconfig_mapping,
check_save_load=True,
example_inputs_for_load=(torch.zeros_like(indices),)
)
def test_embedding_conv_linear_quantization(self):
@ -1502,10 +1548,15 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
# permute is not handled the same way in fx, so the fx-traced model has 6 q/dq pairs
fx_node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 6,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 6,
torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
@ -1517,6 +1568,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
[],
True,
qconfig_mapping,
fx_node_occurrence=fx_node_occurrence,
check_save_load=True,
example_inputs_for_load=(torch.zeros_like(indices),)
)
def test_prepare_qat_conv_bn_fusion(self):
@ -1811,10 +1865,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
}
non_ref_node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
torch.ops.quantized_decomposed.quantize_per_tensor.tensor
): 3,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
): 3,
}
self._test_representation(

View File

@ -76,16 +76,16 @@ class TestQuantizePT2EFX(QuantizationTestCase):
# first conv is quantized, second conv is not quantized
node_occurrence = {
# two for input of the first conv, one for output for the first conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 3,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
): 3,
}
node_list = [
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.convolution.default),
]
self.checkGraphModuleNodes(
@ -129,16 +129,16 @@ class TestQuantizePT2EFX(QuantizationTestCase):
# conv is quantized, linear is not quantized
node_occurrence = {
# two for input and weight of the conv, one for output for the conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 3,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
): 3,
}
node_list = [
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
ns.call_function(torch.ops.aten.addmm.default),
]
self.checkGraphModuleNodes(m, expected_node_list=node_list)
@ -212,8 +212,8 @@ class TestQuantizePT2EFX(QuantizationTestCase):
m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
expected_occurrence = {
# for input and output activations
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
# weight is per channel quantized
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,

View File

@ -202,17 +202,17 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
)
node_occurrence = {
# one for input and weight of the conv, one for output for the conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]
self._test_quantizer(
m,
@ -239,18 +239,18 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
)
node_occurrence = {
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]
self._test_quantizer(
m,
@ -280,8 +280,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv
# one for output for the add
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
@ -292,18 +292,18 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for output for the add
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.aten.add.Tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]
self._test_quantizer(
m,
@ -335,8 +335,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for input and weight of the conv
# one for output for the relu
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
@ -347,18 +347,18 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# one for output for the relu
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.aten.add.Tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]
self._test_quantizer(
m,
@ -378,23 +378,23 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
example_inputs = (torch.randn(2, 3, 16, 16),)
quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.aten.convolution.default,
torch.ops.aten.convolution.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.relu.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]
self._test_quantizer(
m,

View File

@ -157,8 +157,8 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
"_dtype_": dtype_
}
else:
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
scale = float(scale)
zero_point = int(zero_point)
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
@ -179,7 +179,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
for key, value_or_node in qparams.items():
# TODO: we can add the information of whether a value needs to
# be registered as an attribute in qparams dict itself
if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))):
if key in ['_scale_', '_zero_point_']:
# For scale and zero_point values we register them as buffers in the root module.
# This now also covers per_tensor quantization, whose scale/zero_point used to be
# burnt into the graph as literals and therefore could not be serialized.
@ -379,6 +379,7 @@ def _replace_observer_with_quantize_dequantize_node(
# be registered as an attribute in qparams dict itself
if key in ['_scale_', '_zero_point_']:
# For scale and zero_point values we register them as buffers in the root module.
# this is needed to support serialization and deserialization for scale/zero_point
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(
model, graph, module_path + prefix + key, value_or_node)
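# create_getattr_from_value registers the value as a buffer on the root module
# (named e.g. _scale_0 / _zero_point_0, matching the test assertions above) and
# returns a get_attr node, so the q/dq call consumes a Tensor from state_dict.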

View File

@ -54,9 +54,8 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
in the pattern.
"""
kwargs = {}
if is_per_channel:
kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
return kwargs
@ -160,7 +159,7 @@ def _get_input_output_quantized_filter():
if pattern_node.op == "placeholder":
if (
original_node.target
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
== torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
):
input_dq_node = original_node
# output node is not a separate node in the list of nodes seen in the match
@ -172,7 +171,7 @@ def _get_input_output_quantized_filter():
output_node = list(original_node.users.keys())[0]
if (
output_node.target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
== torch.ops.quantized_decomposed.quantize_per_tensor.tensor
):
output_q_node = original_node
return (input_dq_node is not None) and (output_q_node is not None)
@ -217,19 +216,19 @@ def _get_quantized_qat_conv2d_bn_pattern(
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
if is_per_channel:
scaled_weight = torch.ops.quantized_decomposed.quantize_per_channel(
scaled_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
weight_quant_min, weight_quant_max, torch.int8,
)
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_channel(
scaled_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
weight_quant_min, weight_quant_max, torch.int8,
)
else:
scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
scaled_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
)
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
scaled_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
)
if has_bias:
zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
@ -274,19 +273,19 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
) -> torch.Tensor:
if is_per_channel:
conv_weight = torch.ops.quantized_decomposed.quantize_per_channel(
conv_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
weight_quant_min, weight_quant_max, torch.int8,
)
conv_weight = torch.ops.quantized_decomposed.dequantize_per_channel(
conv_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
weight_quant_min, weight_quant_max, torch.int8,
)
else:
conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
conv_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
)
conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
conv_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
)
if has_bias:
x = F.conv2d(x, conv_weight, kwargs["conv_bias"])
@ -597,7 +596,7 @@ def _duplicate_dequantize_node(m: GraphModule):
the dequantize node has users outside the matched portion of the graph.
Instead, we match [dequantize_1 - a], which is safe.
"""
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
for n in m.graph.nodes:
if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
continue
@ -615,7 +614,7 @@ def _remove_extra_dequantize(m: GraphModule):
that can be shared across all the uses. This should be seen as the "reverse"
of `_duplicate_dequantize_node`.
"""
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
for n in m.graph.nodes:
dq_users = [user for user in n.users if user.op == "call_function" and user.target == dq_op]
if len(dq_users) > 1:
@ -694,31 +693,6 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
conv_bias = conv_node.args[2]
assert conv_bias is None or isinstance(conv_bias, Node)
(weight_q_node, weight_dq_node) = _get_fused_convbn_q_dq_nodes(r.replacements)
original_weight_q_node = None
original_weight_dq_node = None
for pattern_node, original_node in r.nodes_map.items():
if pattern_node.op == 'placeholder':
continue
if (
original_node.target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
):
assert original_weight_q_node is None
original_weight_q_node = original_node
weight_q_node.args = (
weight_q_node.args[:1] + original_weight_q_node.args[1:]
)
if (
original_node.target
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
):
assert original_weight_dq_node is None
original_weight_dq_node = original_node
weight_dq_node.args = (
weight_dq_node.args[:1] + original_weight_dq_node.args[1:]
)
# fold bn weights into conv
fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)