[ONNX] Add complex constant support (#138279)

Transform complex python constant to float representation as well, like what we have with tensors.

PS: I find it's not reasonable to add "complex->float" in IR side, so I put it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138279
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
Ti-Tai Wang 2024-10-22 19:42:56 +00:00 committed by PyTorch MergeBot
parent c7a20939b4
commit a71723bf12
3 changed files with 27 additions and 0 deletions

View File

@ -33,6 +33,20 @@ class DynamoExporterTest(common_utils.TestCase):
)
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
def test_constant_complex(self):
class MulModule(torch.nn.Module):
def forward(self, x):
y = 2 + 3j
return torch.ops.aten.mul(x, y)
# Example usage with complex inputs
x = torch.tensor(
[[1.0 + 2.0j, 3.0 + 4.0j], [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64
)
onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True)
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -207,6 +207,8 @@ def _determine_input_dtype(
return ir.DataType.STRING
if isinstance(arg, (ir.Tensor, ir.TensorProtocol)):
return arg.dtype
if isinstance(arg, complex):
return ir.DataType.FLOAT
if arg is None:
return ir.DataType.UNDEFINED
@ -261,9 +263,15 @@ def _get_or_create_constant(
dtype: ir.DataType,
opset: onnxscript.values.Opset,
) -> ir.Value:
# float representation of complex numbers
if isinstance(arg, complex):
# Convert the complex number to a float
arg = (arg.real, arg.imag)
if isinstance(arg, list):
# Make the arg hashable
arg = tuple(arg) # type: ignore[assignment]
constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type]
if constant_value is None:
constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type]

View File

@ -54,6 +54,11 @@ def assert_onnx_program(
kwargs = {}
torch_module = exported_program.module()
torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs))
# ONNX outputs are always real, so we need to convert torch complex outputs to real representations
torch_outputs = [
torch.view_as_real(output) if torch.is_complex(output) else output
for output in torch_outputs
]
onnx_outputs = program(*args, **kwargs)
# TODO(justinchuby): Include output names in the error message
torch.testing.assert_close(