Set correct output dtype for dequantize op during convert_pt2e in decomposed mode (#128953)

The signature of the dequantize ops for decomposed quantized tensors was previously changed to cover wider use cases, where the output dtype can differ from torch.float and therefore needs to be passed in during dequantization.
Please refer to: https://github.com/pytorch/pytorch/pull/121450

However, the convert_pt2e flow was still not setting the correct output dtype on the dequantize ops it inserts.

This change enables users to run the PT2E quantization flow with a non-torch.float unquantized dtype, such as torch.bfloat16.
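
As an illustration, here is a minimal sketch (not part of this commit) of the decomposed quantize/dequantize round-trip that convert_pt2e now emits for a bfloat16 activation. It assumes the quantized_decomposed ops accept non-float32 inputs and an out_dtype keyword, as introduced in #121450; the scale/zero-point values are arbitrary.

import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

x = torch.randn(1, 3, 3, 3, dtype=torch.bfloat16)
scale, zero_point = 0.05, 0

# Quantize the bfloat16 activation to int16.
q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, -32768, 32767, torch.int16
)

# Without out_dtype the dequantize op would produce torch.float32; passing the
# original unquantized dtype keeps the round-trip in bfloat16, which is what
# convert_pt2e now does when the observed value is not float32.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, -32768, 32767, torch.int16, out_dtype=torch.bfloat16
)
assert dq.dtype == torch.bfloat16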

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128953
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
kausik 2024-07-19 04:58:00 +00:00 committed by PyTorch MergeBot
parent d59803fb67
commit 4f60a2e39c
3 changed files with 54 additions and 13 deletions


@@ -1181,15 +1181,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         self.assertIsNot(observers[0], observers[2])
         self.assertIsNot(observers[1], observers[2])
 
-    @parametrize("dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
-    def test_quantization_dtype(self, dtype):
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
+    @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
+    def test_quantization_dtype(self, dtype, quant_dtype):
         class DtypeActQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-                info_fun = torch.iinfo if dtype == torch.int16 else torch.finfo
+                info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo
                 activate_qspec = QuantizationSpec(
-                    dtype=dtype,
-                    quant_min=int(info_fun(dtype).min),
-                    quant_max=int(info_fun(dtype).max),
+                    dtype=quant_dtype,
+                    quant_min=int(info_fun(quant_dtype).min),
+                    quant_max=int(info_fun(quant_dtype).max),
                     qscheme=torch.per_tensor_affine,
                     is_dynamic=False,
                     observer_or_fake_quant_ctr=observer.default_observer,
@@ -1214,9 +1215,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 pass
 
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dtype):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
+                self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
 
             def forward(self, x):
                 return self.conv(x)
@@ -1233,15 +1234,46 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             torch.ops.aten.conv2d.default,
             torch.ops.quantized_decomposed.quantize_per_tensor.default,
         ]
-        example_inputs = (torch.randn(1, 3, 3, 3),)
-        self._test_quantizer(
-            M().eval(),
+        example_inputs = (torch.randn(1, 3, 3, 3, dtype=dtype),)
+        m = self._test_quantizer(
+            M(dtype).eval(),
             example_inputs,
             quantizer,
             node_occurrence,
             node_list,
         )
 
+        def verify_quant_dequant_iotypes(m):
+            for node in m.graph.nodes:
+                if (
+                    node.op == "call_function"
+                    and node.target.__name__ == "dequantize_per_tensor.default"
+                ):
+                    # Check dequantize node
+                    dequant_node = node
+                    dequant_in_dtype = dequant_node.args[5]
+                    dequant_out_dtype = torch.float32
+                    if "out_dtype" in dequant_node.kwargs:
+                        dequant_out_dtype = dequant_node.kwargs["out_dtype"]
+
+                    # Check preceding quantize node
+                    # Depending on fold_quantize flag, quantize node may be absent
+                    quant_node = node.args[0]
+                    if (
+                        quant_node.op == "call_function"
+                        and quant_node.target.__name__ == "quantize_per_tensor.default"
+                    ):
+                        quant_in_dtype = torch.float32
+                        if "val" in quant_node.args[0].meta:
+                            quant_in_dtype = quant_node.args[0].meta["val"].dtype
+                        quant_out_dtype = quant_node.args[5]
+                        assert (
+                            quant_in_dtype == dequant_out_dtype
+                            and quant_out_dtype == dequant_in_dtype
+                        ), "quant dequant io dtype check failed!"
+
+        verify_quant_dequant_iotypes(m)
+
     def test_input_edge_sanity_check(self):
         class M(torch.nn.Module):
             def forward(self, x):
def test_input_edge_sanity_check(self):
class M(torch.nn.Module):
def forward(self, x):


@@ -149,6 +149,14 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     if hasattr(activation_post_process, "is_dynamic"):
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]
 
+    def add_dequantize_op_kwargs(dequantize_op, input_node):
+        dequantize_op_kwargs = {}
+        if 'val' in input_node.meta:
+            dq_out_dtype = input_node.meta['val'].dtype
+            if dq_out_dtype != torch.float32:
+                dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
+        return dequantize_op_kwargs
+
     if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
         # TODO: probably should cleanup this condition check, it's hard
         # to reason about this if and the following elif
@@ -219,7 +227,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             dequantized_node = graph.call_function(
                 dequantize_op,
                 tuple(dq_inputs),
-                {}
+                add_dequantize_op_kwargs(dequantize_op, input_node)
             )
 
             node.replace_all_uses_with(dequantized_node)
@@ -322,7 +330,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             dequantized_node = graph.call_function(
                 dequantize_op,
                 tuple(dq_inputs),
-                {}
+                add_dequantize_op_kwargs(dequantize_op, input_node)
            )
 
             def remap_fn(x):

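For context, a rough standalone sketch (not the PR's code path) of where the out_dtype comes from: the new helper reads the observed node's example value from node.meta["val"], which an exported graph populates with fake tensors carrying the original dtype. The module and names below are illustrative.

import torch
from torch.export import export

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Export a bfloat16 model; each node's meta["val"] records the activation dtype
# that convert_pt2e would use to decide whether to attach out_dtype.
ep = export(Tiny(), (torch.randn(2, dtype=torch.bfloat16),))
for node in ep.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor) and val.dtype != torch.float32:
        # For such nodes the helper would return {"out_dtype": val.dtype}.
        print(node.name, val.dtype)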

@@ -1306,6 +1306,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
             self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
             fx_quant_output = m_fx(*example_inputs)
             self.assertEqual(fx_quant_output, pt2_quant_output)
+        return m
 
     def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
         # resetting dynamo cache