mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Add complex constant support (#138279)
Transform complex python constant to float representation as well, like what we have with tensors. PS: I find it's not reasonable to add "complex->float" in IR side, so I put it here. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138279 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
parent
c7a20939b4
commit
a71723bf12
|
|
@ -33,6 +33,20 @@ class DynamoExporterTest(common_utils.TestCase):
|
|||
)
|
||||
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
|
||||
|
||||
def test_constant_complex(self):
|
||||
class MulModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = 2 + 3j
|
||||
return torch.ops.aten.mul(x, y)
|
||||
|
||||
# Example usage with complex inputs
|
||||
x = torch.tensor(
|
||||
[[1.0 + 2.0j, 3.0 + 4.0j], [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True)
|
||||
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
|
|
|||
|
|
@ -207,6 +207,8 @@ def _determine_input_dtype(
|
|||
return ir.DataType.STRING
|
||||
if isinstance(arg, (ir.Tensor, ir.TensorProtocol)):
|
||||
return arg.dtype
|
||||
if isinstance(arg, complex):
|
||||
return ir.DataType.FLOAT
|
||||
if arg is None:
|
||||
return ir.DataType.UNDEFINED
|
||||
|
||||
|
|
@ -261,9 +263,15 @@ def _get_or_create_constant(
|
|||
dtype: ir.DataType,
|
||||
opset: onnxscript.values.Opset,
|
||||
) -> ir.Value:
|
||||
# float representation of complex numbers
|
||||
if isinstance(arg, complex):
|
||||
# Convert the complex number to a float
|
||||
arg = (arg.real, arg.imag)
|
||||
|
||||
if isinstance(arg, list):
|
||||
# Make the arg hashable
|
||||
arg = tuple(arg) # type: ignore[assignment]
|
||||
|
||||
constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type]
|
||||
if constant_value is None:
|
||||
constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -54,6 +54,11 @@ def assert_onnx_program(
|
|||
kwargs = {}
|
||||
torch_module = exported_program.module()
|
||||
torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs))
|
||||
# ONNX outputs are always real, so we need to convert torch complex outputs to real representations
|
||||
torch_outputs = [
|
||||
torch.view_as_real(output) if torch.is_complex(output) else output
|
||||
for output in torch_outputs
|
||||
]
|
||||
onnx_outputs = program(*args, **kwargs)
|
||||
# TODO(justinchuby): Include output names in the error message
|
||||
torch.testing.assert_close(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user