pytorch/torch/onnx/ops/__init__.py
Justin Chu 3efa211e48 [ONNX] Annotate None inputs in symbolic ops (#150038)
Add `None` to type annotations of `torch.onnx.ops.symbolic*` ops and improve tests to test support for optional inputs. Previously it was omitted mistakenly even though the implementation supports it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150038
Approved by: https://github.com/titaiwangms
2025-03-27 00:01:09 +00:00

256 lines
8.9 KiB
Python

"""ONNX operators as native torch.fx operators.
This module provides a set of functions to create ONNX operators in the FX graph
which are exportable to ONNX.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch.onnx.ops import _symbolic_impl
if TYPE_CHECKING:
from collections.abc import Sequence
# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
_TORCH_DTYPE_TO_ONNX_DTYPE = {
torch.float32: 1, # FLOAT
torch.uint8: 2, # UINT8
torch.int8: 3, # INT8
torch.uint16: 4, # UINT16
torch.int16: 5, # INT16
torch.int32: 6, # INT32
torch.int64: 7, # INT64
str: 8, # STRING
torch.bool: 9, # BOOL
torch.float16: 10, # FLOAT16
torch.double: 11, # DOUBLE
torch.uint32: 12, # UINT32
torch.uint64: 13, # UINT64
torch.complex64: 14, # COMPLEX64
torch.complex128: 15, # COMPLEX128
torch.bfloat16: 16, # BFLOAT16
torch.float8_e4m3fn: 17, # FLOAT8E4M3FN
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
torch.float8_e5m2: 19, # FLOAT8E5M2
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
}
def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
    """Split a ``"domain::op_type"`` string into ``(domain, op_type)``.

    Only the first ``"::"`` acts as the separator; anything after it,
    including further ``"::"`` occurrences, belongs to the op type. When the
    string contains no ``"::"``, the domain is the empty string and the whole
    string is the op type.
    """
    head, separator, tail = domain_op.partition("::")
    if not separator:
        # No "::" present: the entire string names the operator.
        return "", head
    return head, tail
def symbolic(
    domain_op: str,
    /,
    inputs: Sequence[torch.Tensor | None],
    attrs: dict[
        str,
        int
        | float
        | str
        | bool
        | Sequence[int]
        | Sequence[float]
        | Sequence[str]
        | Sequence[bool],
    ]
    | None = None,
    *,
    dtype: torch.dtype | int,
    shape: Sequence[int | torch.SymInt],
    version: int | None = None,
    metadata_props: dict[str, str] | None = None,
) -> torch.Tensor:
    """Create a symbolic FX operator to represent an arbitrary ONNX operator.

    This function is used to create a symbolic operator with a single output.
    To create an operator with multiple outputs, use :func:`symbolic_multi_out`.

    Example::

        class CustomOp(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Normal torch operators can interleave with the symbolic ops during ONNX export
                x = x + 1

                # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
                # The output tensor will have the specified dtype and shape
                val = torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x,),
                    dict(attr_key="attr_value"),
                    dtype=x.dtype,
                    shape=x.shape,
                    version=1,
                )

                # The result of the symbolic op can be used in normal torch operations during ONNX export
                return torch.nn.functional.relu(val)


        # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).

    Args:
        domain_op: The domain and operator name, separated by "::". For example,
            "custom_domain::CustomOp". When no "::" is present, the domain is
            the empty string and the whole string is the operator name.
        inputs: The input tensors to the operator. An entry may be ``None``
            for an optional input that is not provided.
        attrs: The attributes of the operator. The keys are attribute names and
            the values are attribute values. Valid attribute types are int, float,
            str, bool, and lists of int, float, str, and bool. Tensor attributes
            are unsupported.
        dtype: The data type of the output tensor. This can be either a torch.dtype
            or an integer representing the ONNX data type.
        shape: The shape of the output tensor. This can be a list of integers or
            SymInt values.
        version: The version of the opset used for the operator.
        metadata_props: Metadata properties for the ONNX node.
            This is a dictionary of str-str pairs.

    Returns:
        The output tensor of the operator.

    Raises:
        RuntimeError: If ``dtype`` is not an integer and has no entry in the
            torch-to-ONNX dtype table (raised through ``torch._check``).
    """
    if not isinstance(dtype, int):
        # Translate a torch dtype to its ONNX integer code; integer inputs are
        # forwarded unchanged.
        torch._check(
            dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
        )
        dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
    domain, op_type = _parse_domain_op_type(domain_op)
    if attrs is None:
        attrs = {}
    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
    # Dict views are not sequences, so hand over plain lists. Keys and values
    # of the same dict iterate in the same order, which keeps them aligned.
    if metadata_props:
        metadata_props_keys = list(metadata_props.keys())
        metadata_props_values = list(metadata_props.values())
    else:
        metadata_props_keys = []
        metadata_props_values = []
    return _symbolic_impl._symbolic(
        inputs,
        op_type,
        dtype,
        shape=shape,
        attr_keys=encoded_attrs.attr_keys,
        attr_types=encoded_attrs.attr_types,
        attr_pos=encoded_attrs.attr_pos,
        attr_ints=encoded_attrs.attr_ints,
        attr_floats=encoded_attrs.attr_floats,
        attr_strs=encoded_attrs.attr_strs,
        metadata_props_keys=metadata_props_keys,
        metadata_props_values=metadata_props_values,
        domain=domain,
        version=version,
    )
def symbolic_multi_out(
    domain_op: str,
    /,
    inputs: Sequence[torch.Tensor | None],
    attrs: dict[
        str,
        int
        | float
        | str
        | bool
        | Sequence[int]
        | Sequence[float]
        | Sequence[str]
        | Sequence[bool],
    ]
    | None = None,
    *,
    dtypes: Sequence[torch.dtype | int],
    shapes: Sequence[Sequence[int | torch.SymInt]],
    version: int | None = None,
    metadata_props: dict[str, str] | None = None,
) -> Sequence[torch.Tensor]:
    """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.

    To create an operator with a single output, use :func:`symbolic`.

    Example::

        class CustomOp(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Normal torch operators can interleave with the symbolic ops during ONNX export
                x = x + 1

                # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
                # The output tensors will have the specified dtypes and shapes
                (out1, out2) = torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x,),
                    dict(attr_key="attr_value"),
                    dtypes=(x.dtype, torch.float32),
                    shapes=(x.shape, [1, 2, 3]),
                    version=1,
                )

                # The result of the symbolic op can be used in normal torch operations during ONNX export
                return torch.nn.functional.relu(out1 + out2)


        # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).

    Args:
        domain_op: The domain and operator name, separated by "::". For example,
            "custom_domain::CustomOp". When no "::" is present, the domain is
            the empty string and the whole string is the operator name.
        inputs: The input tensors to the operator. An entry may be ``None``
            for an optional input that is not provided.
        attrs: The attributes of the operator. The keys are attribute names and
            the values are attribute values. Valid attribute types are int, float,
            str, bool, and lists of int, float, str, and bool. Tensor attributes
            are unsupported.
        dtypes: The data types of the output tensors. This can be a list of
            torch.dtype or integers representing the ONNX data types. The length
            of this list must be the number of outputs.
        shapes: The shapes of the output tensors. This can be a list of lists of
            integers or SymInt values. The length of this list must be the number of outputs.
        version: The version of the opset used for the operator.
        metadata_props: Metadata properties for the ONNX node.
            This is a dictionary of str-str pairs.

    Returns:
        A list of output tensors of the operator.

    Raises:
        RuntimeError: If ``shapes`` and ``dtypes`` differ in length, or if a
            non-integer entry of ``dtypes`` has no entry in the torch-to-ONNX
            dtype table (both raised through ``torch._check``).
    """
    torch._check(
        len(shapes) == len(dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
    )
    # Normalize every requested dtype to its ONNX integer code; integers are
    # taken as-is.
    onnx_dtypes: list[int] = []
    for dtype in dtypes:
        if not isinstance(dtype, int):
            torch._check(
                dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
                lambda: f"Unsupported dtype: {dtype}",
            )
            onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])
        else:
            onnx_dtypes.append(dtype)
    domain, op_type = _parse_domain_op_type(domain_op)
    if attrs is None:
        attrs = {}
    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
    # Dict views are not sequences, so hand over plain lists. Keys and values
    # of the same dict iterate in the same order, which keeps them aligned.
    if metadata_props:
        metadata_props_keys = list(metadata_props.keys())
        metadata_props_values = list(metadata_props.values())
    else:
        metadata_props_keys = []
        metadata_props_values = []
    # Use the size of dtypes to determine the number of outputs
    return _symbolic_impl._symbolic_multi_out(
        inputs,
        op_type,
        onnx_dtypes,
        shapes=shapes,
        attr_keys=encoded_attrs.attr_keys,
        attr_types=encoded_attrs.attr_types,
        attr_pos=encoded_attrs.attr_pos,
        attr_ints=encoded_attrs.attr_ints,
        attr_floats=encoded_attrs.attr_floats,
        attr_strs=encoded_attrs.attr_strs,
        metadata_props_keys=metadata_props_keys,
        metadata_props_values=metadata_props_values,
        domain=domain,
        version=version,
    )