[ONNX] Annotate None inputs in symbolic ops (#150038)

Add `None` to type annotations of `torch.onnx.ops.symbolic*` ops and improve tests to test support for optional inputs. Previously it was omitted mistakenly even though the implementation supports it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150038
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu 2025-03-27 00:01:05 +00:00 committed by PyTorch MergeBot
parent 6db95ccf4c
commit 3efa211e48
2 changed files with 4 additions and 4 deletions

View File

@ -145,7 +145,7 @@ class SymbolicOpsTest(common_utils.TestCase):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(x,),
(x, None),
dict(
int_key=1,
float_key=1.0,
@ -289,7 +289,7 @@ class SymbolicOpsTest(common_utils.TestCase):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(x,),
(x, None),
dict(
int_key=1,
float_key=1.0,

View File

@ -55,7 +55,7 @@ def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
def symbolic(
domain_op: str,
/,
inputs: Sequence[torch.Tensor],
inputs: Sequence[torch.Tensor | None],
attrs: dict[
str,
int
@ -153,7 +153,7 @@ def symbolic(
def symbolic_multi_out(
domain_op: str,
/,
inputs: Sequence[torch.Tensor],
inputs: Sequence[torch.Tensor | None],
attrs: dict[
str,
int