mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Annotate None inputs in symbolic ops (#150038)
Add `None` to type annotations of `torch.onnx.ops.symbolic*` ops and improve tests to test support for optional inputs. Previously it was omitted mistakenly even though the implementation supports it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150038 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
6db95ccf4c
commit
3efa211e48
|
|
@ -145,7 +145,7 @@ class SymbolicOpsTest(common_utils.TestCase):
|
|||
def forward(self, x: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic(
|
||||
"custom_domain::CustomOp",
|
||||
(x,),
|
||||
(x, None),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
|
|
@ -289,7 +289,7 @@ class SymbolicOpsTest(common_utils.TestCase):
|
|||
def forward(self, x: torch.Tensor):
|
||||
return torch.onnx.ops.symbolic_multi_out(
|
||||
"custom_domain::CustomOp",
|
||||
(x,),
|
||||
(x, None),
|
||||
dict(
|
||||
int_key=1,
|
||||
float_key=1.0,
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
|
|||
def symbolic(
|
||||
domain_op: str,
|
||||
/,
|
||||
inputs: Sequence[torch.Tensor],
|
||||
inputs: Sequence[torch.Tensor | None],
|
||||
attrs: dict[
|
||||
str,
|
||||
int
|
||||
|
|
@ -153,7 +153,7 @@ def symbolic(
|
|||
def symbolic_multi_out(
|
||||
domain_op: str,
|
||||
/,
|
||||
inputs: Sequence[torch.Tensor],
|
||||
inputs: Sequence[torch.Tensor | None],
|
||||
attrs: dict[
|
||||
str,
|
||||
int
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user