[quant][pt2e] store scale/zero_point as tensor attributes to support serialization (#105894)

Summary: Currently, scale/zero_point for per tensor quant are stored as burnt-in literals, which means these values cannot be serialized in state_dict. This PR changes them to buffers/Tensors so that they can be serialized.

Test Plan: python test/test_quantization.py TestQuantizePT2E

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D47770963](https://our.internmc.facebook.com/intern/diff/D47770963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105894
Approved by: https://github.com/kimishpatel
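In graph terms, the change looks roughly like the following sketch (not code from the PR; the literal values are illustrative, and the `_scale_0`/`_zero_point_0` buffer names follow the `m._scale_0`-style attributes the updated tests assert on):

    # before: scale/zero_point are burnt-in literals, invisible to state_dict
    x = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        x, 0.0039, 0, 0, 255, torch.uint8)

    # after: scale/zero_point are registered buffers, fetched via get_attr nodes
    # and consumed as Tensors by the .tensor overload of the same op
    x = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
        x, self._scale_0, self._zero_point_0, 0, 255, torch.uint8)

Because the buffers are module state, `m.state_dict()` now round-trips them, which is what the new `check_save_load` test path in the diff below verifies.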
This commit is contained in: parent 841b4acf1e, commit 3ca71ed735
@@ -6084,8 +6084,8 @@ class TestQuantizeFx(QuantizationTestCase):
         m_ref = convert_to_reference_fx(m_ref)
         m = _convert_to_reference_decomposed_fx(m)
         expected_occurrence = {
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
         }
         self.checkGraphModuleNodes(
             m,
@@ -6145,8 +6145,8 @@ class TestQuantizeFx(QuantizationTestCase):
         m = _convert_to_reference_decomposed_fx(m)
         expected_occurrence = {
             # for input and output activations
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
             # for weight
             ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
             ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
@@ -71,6 +71,9 @@ from torch.ao.quantization import (
     default_dynamic_qconfig,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
+from torch.testing._internal.common_utils import (
+    TemporaryFileName,
+)

 # TODO: Move to common utils or use existing quant utils to fetch model instances
 class TestHelperModules:
@@ -221,8 +224,16 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         expected_node_list=None,
         check_against_fx_quant=False,
         fx_qconfig_mapping=None,
+        fx_node_occurrence=None,
         export_with_dynamic_shape=False,
+        check_save_load=False,
+        example_inputs_for_load=None,
     ):
+        """
+        example_inputs_for_load: the example inputs used when loading the model. It should be
+        different from example_inputs so that we actually test the save/load functionality;
+        we also generate it from example_inputs if example_inputs is a single-element tuple.
+        """
         m_eager = model.eval()

         # program capture
@@ -267,14 +278,45 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
                 aten_graph=True,
                 tracing_mode="symbolic" if export_with_dynamic_shape else "real",
             )
-            node_occurrence = {}
-            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
-                if k in expected_node_occurrence:
-                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
-            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
+            if fx_node_occurrence is None:
+                fx_node_occurrence = {}
+                for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
+                    if k in expected_node_occurrence:
+                        fx_node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
+            else:
+                fx_node_occurrence = {
+                    ns.call_function(k): v for k, v in fx_node_occurrence.items()
+                }
+            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=fx_node_occurrence)
             fx_quant_output = m_fx(*example_inputs)
             self.assertEqual(fx_quant_output, pt2_quant_output)
+
+        if check_save_load:
+            if example_inputs_for_load is None:
+                assert isinstance(example_inputs, tuple) and \
+                    len(example_inputs) == 1 and \
+                    isinstance(example_inputs[0], torch.Tensor)
+                example_inputs_for_load = (example_inputs[0] * 2,)
+
+            with TemporaryFileName() as fname, torchdynamo.config.patch(dynamic_shapes=export_with_dynamic_shape):
+                torch.save(m.state_dict(), fname)
+                ref_result = m(*example_inputs)
+
+                m_loaded = copy.deepcopy(m_eager)
+                m_loaded, guards = torchdynamo.export(
+                    m_loaded,
+                    *copy.deepcopy(example_inputs_for_load),
+                    aten_graph=True,
+                    tracing_mode="symbolic" if export_with_dynamic_shape else "real",
+                )
+
+                m_loaded = prepare_pt2e(m_loaded, quantizer)
+                m_loaded(*example_inputs_for_load)
+                m_loaded = convert_pt2e(m_loaded)
+                m_loaded.load_state_dict(torch.load(fname))
+                loaded_result = m_loaded(*example_inputs)
+                self.assertTrue(torch.equal(ref_result, loaded_result))

     def _verify_symmetric_qnnpack_qat_numerics(
         self,
         model: torch.nn.Module,
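The `check_save_load` branch added above amounts to the following round trip, shown here as a condensed sketch of the test logic (same APIs as the test; `m`, `m_eager`, `quantizer`, `fname`, `example_inputs`, and `example_inputs_for_load` are assumed to be in scope):

    import copy
    import torch

    # save the quantized model's state_dict; with this PR it now carries the
    # scale/zero_point buffers created during convert_pt2e
    torch.save(m.state_dict(), fname)
    ref_result = m(*example_inputs)

    # re-capture and re-quantize a fresh copy on different example inputs,
    # then restore the saved qparams
    m_loaded, _ = torchdynamo.export(
        copy.deepcopy(m_eager), *copy.deepcopy(example_inputs_for_load), aten_graph=True
    )
    m_loaded = prepare_pt2e(m_loaded, quantizer)
    m_loaded(*example_inputs_for_load)  # calibration produces different qparams
    m_loaded = convert_pt2e(m_loaded)
    m_loaded.load_state_dict(torch.load(fname))

    # the loaded buffers override the recalibrated qparams, so outputs match exactly
    assert torch.equal(ref_result, m_loaded(*example_inputs))

Calibrating on different inputs is the point of `example_inputs_for_load`: the round trip reproduces `ref_result` exactly only because scale/zero_point now live in the state_dict.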
@@ -583,14 +625,14 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5),)
         node_occurrence = {
             # two for input of the first conv, one for output for the first conv
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
         }
         node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             torch.ops.aten.convolution.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
         ]
         self._test_quantizer(
             TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
@@ -598,6 +640,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             BackendAQuantizer(),
             node_occurrence,
             node_list,
+            check_save_load=True,
         )

     def test_wo_annotate_conv_output_quantizer(self):
@@ -666,12 +709,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         # Ensure the conv has no observer inserted at output
         node_occurrence = {
             # two for input of conv
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
         }
         node_list = [
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.convolution.default),
         ]
         self.checkGraphModuleNodes(
@@ -770,15 +813,15 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             # two for input of conv
             # one for input of maxpool
             # one for output of maxpool
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 4,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 4,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 4,
         }
         node_list = [
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.convolution.default),
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.max_pool2d_with_indices.default),
         ]
         self.checkGraphModuleNodes(
@@ -873,25 +916,25 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         node_occurrence = {
             # input, weight, bias, output for the conv
             ns.call_function(
-                torch.ops.quantized_decomposed.quantize_per_tensor.default
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor
             ): 4,
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ): 4,
         }
         node_list = [
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ),
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ),
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ),
             ns.call_function(torch.ops.aten.convolution.default),
             ns.call_function(
-                torch.ops.quantized_decomposed.quantize_per_tensor.default
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor
             ),
         ]
         self.checkGraphModuleNodes(
@@ -949,40 +992,26 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = convert_pt2e(m)
         fixed_scale = 1.0 / 256.0
         fixed_zero_point = 0
-        for n in m.graph.nodes:
-            if n.op == "call_function":
-                if (
-                    n.target
-                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
-                ):
-                    scale_0 = n.args[1]
-                    zero_point_0 = n.args[2]
-                if (
-                    n.target
-                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
-                ):
-                    scale_1 = n.args[1]
-                    zero_point_1 = n.args[2]
-        self.assertEqual(scale_0, fixed_scale)
-        self.assertEqual(zero_point_0, fixed_zero_point)
-        self.assertEqual(scale_1, fixed_scale)
-        self.assertEqual(zero_point_1, fixed_zero_point)
+        self.assertEqual(m._scale_0, fixed_scale)
+        self.assertEqual(m._zero_point_0, fixed_zero_point)
+        self.assertEqual(m._scale_1, fixed_scale)
+        self.assertEqual(m._zero_point_1, fixed_zero_point)
         node_occurrence = {
             # two for input of the first conv, one for output for the first conv
             ns.call_function(
-                torch.ops.quantized_decomposed.quantize_per_tensor.default
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor
             ): 2,
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ): 2,
         }
         node_list = [
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ),
             ns.call_function(torch.ops.aten.sigmoid.default),
             ns.call_function(
-                torch.ops.quantized_decomposed.quantize_per_tensor.default
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor
             ),
         ]
         self.checkGraphModuleNodes(
@@ -996,16 +1025,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5),)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
         }
         node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             torch.ops.quantized_decomposed.dequantize_per_channel.default,
             torch.ops.aten.convolution.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
         ]
         self._test_quantizer(
             TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
@@ -1013,6 +1042,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             quantizer,
             node_occurrence,
             node_list,
+            check_save_load=True,
         )

     def test_xnnpack_quantizer_linear(self):
@@ -1027,8 +1057,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs_4d = (torch.randn(9, 10, 11, 8),)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
         }
@@ -1043,6 +1073,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             True,
             qconfig_mapping,
+            check_save_load=True,
         )

     def test_xnnpack_quantizer_conv_linear_no_permute(self):
@@ -1051,8 +1082,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer.set_global(operator_config)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
         }
@@ -1068,6 +1099,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             True,
             qconfig_mapping,
+            check_save_load=True,
         )

     def test_xnnpack_quantizer_conv_linear(self):
@@ -1078,8 +1110,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         # Test with 2d inputs
         example_inputs = (torch.randn(2, 3, 4, 4),)
         node_occurrence = {
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
         }
@@ -1093,6 +1125,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             True,
             qconfig_mapping,
+            check_save_load=True,
         )

     def test_xnnpack_quantizer_linear_with_dynamic_shape(self):
@@ -1105,8 +1138,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs_3d = (torch.randn(9, 10, 8),)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
         }
@@ -1121,6 +1154,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             True,
             qconfig_mapping,
             export_with_dynamic_shape=True,
+            check_save_load=True,
         )

     def test_xnnpack_quantizer_obs_sharing_ops(self):
@@ -1131,28 +1165,28 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5),)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
         }
         node_list = [
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             torch.ops.quantized_decomposed.dequantize_per_channel.default,
             torch.ops.aten.convolution.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             torch.ops.aten.mean.dim,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             torch.ops.aten.hardtanh.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             torch.ops.aten.mean.default,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
         ]
-        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
+        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list, check_save_load=True,)

     def test_propagate_annotation(self):
         quantizer = XNNPACKQuantizer()
@@ -1180,10 +1214,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
             ns.call_function(
-                torch.ops.quantized_decomposed.quantize_per_tensor.default
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor
             ): 5,
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ): 5,
             ns.call_function(
                 torch.ops.quantized_decomposed.quantize_per_channel.default
@@ -1235,6 +1269,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             True,
             qconfig_mapping,
+            check_save_load=True,
         )

     def test_xnnpack_quantizer_dynamic_linear_with_conv(self):
@@ -1247,6 +1282,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
             torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
             torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
         }
+        # the weight q/dq is traced from reference quantized linear module as constant, so
+        # we see quantize_per_tensor.default instead of quantize_per_tensor.tensor for weight
+        fx_node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
@@ -1275,6 +1316,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             True,
             qconfig_mapping,
+            fx_node_occurrence=fx_node_occurrence,
+            check_save_load=True,
         )

     def test_composable_quantizer_linear_conv(self):
@@ -1295,10 +1338,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m_eager = TestHelperModules.ConvLinearWPermute().eval()

         node_occurrence = {
-            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+            # torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+            # torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
         }
@@ -1331,6 +1374,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             False,
             qconfig_mapping,
+            check_save_load=True,
         )

     def test_composable_quantizer_throw(self):
@@ -1426,6 +1470,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             node_list,
             True,
             qconfig_mapping,
+            check_save_load=True,
+            example_inputs_for_load=(torch.zeros_like(indices),)
         )

     def test_embedding_conv_linear_quantization(self):
@@ -1502,10 +1548,15 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         )

         node_occurrence = {
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
             torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
         }
+        # permute is not handled the same way in fx, so it will have 6 q/dqs
+        fx_node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 6,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 6,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 3,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+        }
@@ -1517,6 +1568,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             [],
             True,
             qconfig_mapping,
+            fx_node_occurrence=fx_node_occurrence,
+            check_save_load=True,
+            example_inputs_for_load=(torch.zeros_like(indices),)
         )

     def test_prepare_qat_conv_bn_fusion(self):
@@ -1811,10 +1865,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         }
         non_ref_node_occurrence = {
             ns.call_function(
-                torch.ops.quantized_decomposed.quantize_per_tensor.default
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor
             ): 3,
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ): 3,
         }
         self._test_representation(
@@ -76,16 +76,16 @@ class TestQuantizePT2EFX(QuantizationTestCase):
         # first conv is quantized, second conv is not quantized
         node_occurrence = {
             # two for input of the first conv, one for output for the first conv
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 3,
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ): 3,
         }
         node_list = [
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.convolution.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.convolution.default),
         ]
         self.checkGraphModuleNodes(
@@ -129,16 +129,16 @@ class TestQuantizePT2EFX(QuantizationTestCase):
         # conv is quantized, linear is not quantized
         node_occurrence = {
             # two for input and weight of the conv, one for output for the conv
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 3,
             ns.call_function(
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
             ): 3,
         }
         node_list = [
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.convolution.default),
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor),
             ns.call_function(torch.ops.aten.addmm.default),
         ]
         self.checkGraphModuleNodes(m, expected_node_list=node_list)
@@ -212,8 +212,8 @@ class TestQuantizePT2EFX(QuantizationTestCase):
         m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
         expected_occurrence = {
             # for input and output activations
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 2,
             # weight is per channel quantized
             ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
             ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
@@ -202,17 +202,17 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             )
             node_occurrence = {
                 # one for input and weight of the conv, one for output for the conv
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             }
             node_list = [
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                 torch.ops.aten.convolution.default,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             ]
             self._test_quantizer(
                 m,
@@ -239,18 +239,18 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             )
             node_occurrence = {
                 # one for input and weight of the conv, one for output for the relu
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             }
             node_list = [
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                 torch.ops.aten.convolution.default,
                 torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             ]
             self._test_quantizer(
                 m,
@@ -280,8 +280,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 # one for input and weight of the conv
                 # one for output for the add
                 # one for extra input node of add
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             }
@@ -292,18 +292,18 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 # one for output for the add
                 # 2 conv will share same input quant/dequant
                 # one for extra input node of add
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 4,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 4,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
             }
             node_list = [
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                 torch.ops.aten.convolution.default,
                 torch.ops.aten.add.Tensor,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             ]
             self._test_quantizer(
                 m,
@@ -335,8 +335,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 # one for input and weight of the conv
                 # one for output for the relu
                 # one for extra input node of add
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 3,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 3,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             }
@@ -347,18 +347,18 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 # one for output for the relu
                 # 2 conv will share same input quant/dequant
                 # one for extra input node of add
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 4,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 4,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
             }
             node_list = [
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                 torch.ops.aten.convolution.default,
                 torch.ops.aten.add.Tensor,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             ]
             self._test_quantizer(
                 m,
@@ -378,23 +378,23 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             example_inputs = (torch.randn(2, 3, 16, 16),)
             quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
             node_occurrence = {
-                torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 5,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 5,
                 torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
                 torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
             }
             node_list = [
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                 torch.ops.aten.convolution.default,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                 torch.ops.aten.convolution.default,
                 torch.ops.aten.convolution.default,
                 torch.ops.aten.add.Tensor,
                 torch.ops.aten.relu.default,
-                torch.ops.quantized_decomposed.quantize_per_tensor.default,
-                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
             ]
             self._test_quantizer(
                 m,
@@ -157,8 +157,8 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             "_dtype_": dtype_
         }
     else:
-        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
-        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
+        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
         scale = float(scale)
         zero_point = int(zero_point)
         quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
@@ -179,7 +179,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
         for key, value_or_node in qparams.items():
             # TODO: we can add the information of whether a value needs to
             # be registered as an attribute in qparams dict itself
-            if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))):
+            if key in ['_scale_', '_zero_point_']:
                 # For scale and zero_point values we register them as buffers in the root module.
                 # However, note that when the values are not tensors, as in the case of
                 # per_tensor quantization, they will be treated as literals.
@@ -379,6 +379,7 @@ def _replace_observer_with_quantize_dequantize_node(
             # be registered as an attribute in qparams dict itself
             if key in ['_scale_', '_zero_point_']:
                 # For scale and zero_point values we register them as buffers in the root module.
+                # this is needed to support serialization and deserialization for scale/zero_point
                 # TODO: maybe need more complex attr name here
                 qparam_node = create_getattr_from_value(
                     model, graph, module_path + prefix + key, value_or_node)
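Mechanically, both `_replace_observer_with_quantize_dequantize_node*` variants rely on `create_getattr_from_value`, which registers the value on the root module and returns a `get_attr` node. A condensed sketch of the effect (the attribute prefixes and values are illustrative; `model`, `graph`, and `x_node` are assumed to exist):

    # register the qparams on the root module instead of embedding literals
    scale_node = create_getattr_from_value(model, graph, "_scale_", torch.tensor(0.0039))
    zp_node = create_getattr_from_value(model, graph, "_zero_point_", torch.tensor(0))

    # the quantize call now consumes get_attr nodes, so the values are part of
    # the module's state and survive torch.save / load_state_dict
    quantized = graph.call_function(
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        (x_node, scale_node, zp_node, 0, 255, torch.uint8),
    )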
@@ -54,9 +54,8 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
     in the pattern.
     """
     kwargs = {}
-    if is_per_channel:
-        kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
-        kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
+    kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
+    kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
     if has_bias:
         kwargs["conv_bias"] = torch.randn(1)
     return kwargs
@@ -160,7 +159,7 @@ def _get_input_output_quantized_filter():
             if pattern_node.op == "placeholder":
                 if (
                     original_node.target
-                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
                 ):
                     input_dq_node = original_node
             # output node is not a separate node in the list of nodes seen in the match
@@ -172,7 +171,7 @@ def _get_input_output_quantized_filter():
                 output_node = list(original_node.users.keys())[0]
                 if (
                     output_node.target
-                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
                 ):
                     output_q_node = original_node
         return (input_dq_node is not None) and (output_q_node is not None)
@@ -217,19 +216,19 @@ def _get_quantized_qat_conv2d_bn_pattern(
         scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
         if is_per_channel:
             scaled_weight = torch.ops.quantized_decomposed.quantize_per_channel(
-                scaled_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
+                scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
                 weight_quant_min, weight_quant_max, torch.int8,
             )
             scaled_weight = torch.ops.quantized_decomposed.dequantize_per_channel(
-                scaled_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
+                scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
                 weight_quant_min, weight_quant_max, torch.int8,
             )
         else:
-            scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
-                scaled_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
+            scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
             )
-            scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                scaled_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
+            scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                scaled_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
             )
         if has_bias:
             zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
@@ -274,19 +273,19 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
     ) -> torch.Tensor:
         if is_per_channel:
             conv_weight = torch.ops.quantized_decomposed.quantize_per_channel(
-                conv_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
+                conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
                 weight_quant_min, weight_quant_max, torch.int8,
             )
             conv_weight = torch.ops.quantized_decomposed.dequantize_per_channel(
-                conv_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
+                conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], per_channel_axis,
                 weight_quant_min, weight_quant_max, torch.int8,
             )
         else:
-            conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
-                conv_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
+            conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
            )
-            conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                conv_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
+            conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                conv_weight, kwargs["weight_scale"], kwargs["weight_zero_point"], weight_quant_min, weight_quant_max, torch.int8,
             )
         if has_bias:
             x = F.conv2d(x, conv_weight, kwargs["conv_bias"])
@@ -597,7 +596,7 @@ def _duplicate_dequantize_node(m: GraphModule):
     the dequantize node has users outside the matched portion of the graph.
     Instead, we match [dequantize_1 - a], which is safe.
     """
-    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
+    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
     for n in m.graph.nodes:
         if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
             continue
@@ -615,7 +614,7 @@ def _remove_extra_dequantize(m: GraphModule):
     that can be shared across all the uses. This should be seen as the "reverse"
     of `_duplicate_dequantize_node`.
     """
-    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
+    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
     for n in m.graph.nodes:
         dq_users = [user for user in n.users if user.op == "call_function" and user.target == dq_op]
         if len(dq_users) > 1:
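A note on the one-character change in the two helpers above (a small illustration, not code from the PR): the bare op name is an overload packet, while the suffixed names are concrete overloads, and node targets in the traced graphs are the latter.

    packet = torch.ops.quantized_decomposed.dequantize_per_tensor             # OpOverloadPacket
    tensor_ov = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor   # Tensor scale/zp
    default_ov = torch.ops.quantized_decomposed.dequantize_per_tensor.default # float/int scale/zp

    # equality checks like `n.target == dq_op` only fire when dq_op names the
    # overload that actually appears in the graph, which after this PR is the
    # .tensor overload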
@@ -694,31 +693,6 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         conv_bias = conv_node.args[2]
         assert conv_bias is None or isinstance(conv_bias, Node)

-        (weight_q_node, weight_dq_node) = _get_fused_convbn_q_dq_nodes(r.replacements)
-        original_weight_q_node = None
-        original_weight_dq_node = None
-        for pattern_node, original_node in r.nodes_map.items():
-            if pattern_node.op == 'placeholder':
-                continue
-            if (
-                original_node.target
-                == torch.ops.quantized_decomposed.quantize_per_tensor.default
-            ):
-                assert original_weight_q_node is None
-                original_weight_q_node = original_node
-                weight_q_node.args = (
-                    weight_q_node.args[:1] + original_weight_q_node.args[1:]
-                )
-            if (
-                original_node.target
-                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
-            ):
-                assert original_weight_dq_node is None
-                original_weight_dq_node = original_node
-                weight_dq_node.args = (
-                    weight_dq_node.args[:1] + original_weight_dq_node.args[1:]
-                )
-
         # fold bn weights into conv
         fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)