Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Set correct output dtype for dequantize op during convert_pt2e in decomposed mode (#128953)
The signature of the dequantize ops for decomposed quantized tensors was previously extended to cover wider use cases in which the output dtype can differ from torch.float and must be passed in at dequantization time; see https://github.com/pytorch/pytorch/pull/121450. However, the convert_pt2e flow was still not setting the correct output dtype on the dequantize ops it inserts. This change lets users run the PT2E quantization flow with a non-torch.float unquantized dtype such as torch.bfloat16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128953
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
parent d59803fb67
commit 4f60a2e39c
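For context, a minimal sketch of the behavior this change wires up (illustrative only; the scale, zero_point, and int8 range below are made-up values, not taken from this PR): convert_pt2e now emits the decomposed dequantize call with an out_dtype kwarg whenever the observed activation's dtype is not torch.float32, so a bfloat16 model gets bfloat16 activations back after dequantization.

import torch

x = torch.randn(1, 3, 3, 3)  # float32 activation, quantized to int8 below
scale, zero_point = 0.05, 0

q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, -128, 127, torch.int8
)
# Before this PR, convert_pt2e always emitted the dequantize call with empty
# kwargs, so the output came back as torch.float32. With this change the call
# carries out_dtype when the unquantized activation dtype is not float32:
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, -128, 127, torch.int8, out_dtype=torch.bfloat16
)
assert dq.dtype == torch.bfloat16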
@@ -1181,15 +1181,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         self.assertIsNot(observers[0], observers[2])
         self.assertIsNot(observers[1], observers[2])
 
-    @parametrize("dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
-    def test_quantization_dtype(self, dtype):
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
+    @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
+    def test_quantization_dtype(self, dtype, quant_dtype):
         class DtypeActQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-                info_fun = torch.iinfo if dtype == torch.int16 else torch.finfo
+                info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo
                 activate_qspec = QuantizationSpec(
-                    dtype=dtype,
-                    quant_min=int(info_fun(dtype).min),
-                    quant_max=int(info_fun(dtype).max),
+                    dtype=quant_dtype,
+                    quant_min=int(info_fun(quant_dtype).min),
+                    quant_max=int(info_fun(quant_dtype).max),
                     qscheme=torch.per_tensor_affine,
                     is_dynamic=False,
                     observer_or_fake_quant_ctr=observer.default_observer,
@@ -1214,9 +1215,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 pass
 
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dtype):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
+                self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
 
             def forward(self, x):
                 return self.conv(x)
@@ -1233,15 +1234,46 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             torch.ops.aten.conv2d.default,
             torch.ops.quantized_decomposed.quantize_per_tensor.default,
         ]
-        example_inputs = (torch.randn(1, 3, 3, 3),)
-        self._test_quantizer(
-            M().eval(),
+        example_inputs = (torch.randn(1, 3, 3, 3, dtype=dtype),)
+        m = self._test_quantizer(
+            M(dtype).eval(),
             example_inputs,
             quantizer,
             node_occurrence,
             node_list,
         )
+
+        def verify_quant_dequant_iotypes(m):
+            for node in m.graph.nodes:
+                if (
+                    node.op == "call_function"
+                    and node.target.__name__ == "dequantize_per_tensor.default"
+                ):
+                    # Check dequantize node
+                    dequant_node = node
+                    dequant_in_dtype = dequant_node.args[5]
+                    dequant_out_dtype = torch.float32
+                    if "out_dtype" in dequant_node.kwargs:
+                        dequant_out_dtype = dequant_node.kwargs["out_dtype"]
+
+                    # Check preceding quantize node
+                    # Depending on fold_quantize flag, quantize node may be absent
+                    quant_node = node.args[0]
+                    if (
+                        quant_node.op == "call_function"
+                        and quant_node.target.__name__ == "quantize_per_tensor.default"
+                    ):
+                        quant_in_dtype = torch.float32
+                        if "val" in quant_node.args[0].meta:
+                            quant_in_dtype = quant_node.args[0].meta["val"].dtype
+                        quant_out_dtype = quant_node.args[5]
+                        assert (
+                            quant_in_dtype == dequant_out_dtype
+                            and quant_out_dtype == dequant_in_dtype
+                        ), "quant dequant io dtype check failed!"
+
+        verify_quant_dequant_iotypes(m)
 
     def test_input_edge_sanity_check(self):
         class M(torch.nn.Module):
             def forward(self, x):
@@ -149,6 +149,14 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     if hasattr(activation_post_process, "is_dynamic"):
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]
 
+    def add_dequantize_op_kwargs(dequantize_op, input_node):
+        dequantize_op_kwargs = {}
+        if 'val' in input_node.meta:
+            dq_out_dtype = input_node.meta['val'].dtype
+            if dq_out_dtype != torch.float32:
+                dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
+        return dequantize_op_kwargs
+
     if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
         # TODO: probably should cleanup this condition check, it's hard
         # to reason about this if and the following elif
@@ -219,7 +227,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             dequantized_node = graph.call_function(
                 dequantize_op,
                 tuple(dq_inputs),
-                {}
+                add_dequantize_op_kwargs(dequantize_op, input_node)
             )
 
             node.replace_all_uses_with(dequantized_node)
@@ -322,7 +330,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             dequantized_node = graph.call_function(
                 dequantize_op,
                 tuple(dq_inputs),
-                {}
+                add_dequantize_op_kwargs(dequantize_op, input_node)
             )
 
         def remap_fn(x):
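For anyone inspecting a converted module by hand, here is a small hedged helper (not part of this PR; the function name is made up) that collects the effective output dtype of every decomposed dequantize node, mirroring what add_dequantize_op_kwargs produces: an absent out_dtype kwarg means the pre-existing float32 behavior, otherwise the kwarg holds the model's unquantized dtype such as torch.bfloat16.

import torch

def dequantize_out_dtypes(gm: torch.fx.GraphModule):
    # Collect the output dtype of each decomposed dequantize_per_tensor node;
    # a missing out_dtype kwarg falls back to the default torch.float32.
    dtypes = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in (
            torch.ops.quantized_decomposed.dequantize_per_tensor,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ):
            dtypes.append(node.kwargs.get("out_dtype", torch.float32))
    return dtypes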
@@ -1306,6 +1306,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
             self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
             fx_quant_output = m_fx(*example_inputs)
             self.assertEqual(fx_quant_output, pt2_quant_output)
+        return m
 
     def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
         # resetting dynamo cache