[ONNX] Add scaffolding for onnx decomp and logic for op tests (#147392)

This PR creates the scaffolding for ONNX op test data and common test logic, supporting the new ONNX decomp functions described in https://github.com/pytorch/pytorch/issues/139301. It adds two ops, abs and add, and enables the related tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147392
Approved by: https://github.com/titaiwangms
ghstack dependencies: #147396
Commit 41ae15faa3 (parent 24738768a8), authored by Justin Chu on 2025-02-19 11:23:01 -08:00 and committed by PyTorch MergeBot.
7 changed files with 1953 additions and 2 deletions

File: test/onnx/torchlib/README.md

@ -0,0 +1,80 @@
# Test op correctness by comparing with PyTorch results using OpInfo
`OpInfo` is PyTorch's standard mechanism for composing test data for operators.
Read more about them at https://github.com/pytorch/pytorch/blob/ce4a097bf769d753712a1fd969b446c59e29d8b9/torch/testing/_internal/opinfo/core.py#L362.
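As a quick illustration (a minimal sketch; the `abs` OpInfo is chosen arbitrarily), an OpInfo yields ready-made sample inputs for a given device and dtype:
```py
import torch
from torch.testing._internal.common_methods_invocations import op_db

opinfo = next(info for info in op_db if info.name == "abs")
for sample in opinfo.sample_inputs("cpu", torch.float32):
    # Each SampleInput carries the primary input plus extra args/kwargs
    print(sample.input.shape, sample.args, sample.kwargs)
```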
## Usage
```bash
# All
python -m pytest test_ops.py
# To run tests on a specific operator (e.g. torch.ceil):
python -m pytest test_ops.py -k ceil
# To run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):
python -m pytest test_ops.py -k nn_functional_scaled_dot_product_attention
```
### Environment variables
1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
in onnxruntime by running the inference sessions in a separate process.
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g.
```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/torchlib/test_ops.py -k div_mode_int
```
## How to add a new operator test
See _usage_ in [`ops_test_data.py`](./ops_test_data.py)
## How to add custom OpInfo tests
Sometimes no existing OpInfo fits our need for testing an operator. In that case, you can create a custom OpInfo for it.
Follow the steps below to create new OpInfo tests:
1. Use the implementation for `ops.aten.slice_scatter` as a reference (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2412-L2418) to declare an OpInfo in [`extra_opinfo.py`](./extra_opinfo.py)
```py
opinfo_core.OpInfo(
"ops.aten.slice_scatter",
aten_name="slice_scatter",
dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
sample_inputs_func=sample_inputs_slice_scatter,
supports_out=False,
),
```
- The first argument should be the operator name under the `torch.ops` namespace. For example, if you want to test the `prims.var` op, then put `"ops.prims.var"`. It should almost always start with `ops.`.
- Follow existing examples to specify the `dtypes` you want to test the op on.
- Specify `op=` if the target operator is not the same as the OpInfo name (first arg). For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2065-L2068.
```py
opinfo_core.OpInfo(
"ops.aten.bernoulli.p_deterministic",
op=torch.ops.aten.bernoulli.p,
```
The op is `torch.ops.aten.bernoulli.p`, which is different from the name `ops.aten.bernoulli.p_deterministic`. OpInfo names need to be globally unique in a test suite. When `op` is not specified, the OpInfo looks up the op under the `torch` namespace using its name.
2. Implement the `sample_inputs_func`. (Ref: https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L1242-L1268)
1. Copy the function and decide what the input shapes should be. Use `make_arg` to generate a torch.Tensor, or use `torch.tensor` to create the tensor yourself. Be sure to double check the dtype and device. Finally, yield each test case with
```py
yield opinfo_core.SampleInput(input, args=(...), kwargs={...})
```
`input` is the first arg. The rest of the args are in `args`.
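For instance, a minimal `sample_inputs_func` sketch (the shapes and cases below are illustrative, not the actual onnxscript implementation; it assumes the `functools`, `torch`, and `opinfo_core` imports used throughout `extra_opinfo.py`):
```py
def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        torch.testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    # (input_shape, src_shape, kwargs) cases exercising different dims and starts
    cases = (
        ((8,), (2,), {"dim": 0, "start": 6, "end": None, "step": 1}),
        ((8, 8), (8, 2), {"dim": 1, "start": 6, "end": None, "step": 1}),
    )
    for input_shape, src_shape, sample_kwargs in cases:
        yield opinfo_core.SampleInput(
            make_arg(input_shape), args=(make_arg(src_shape),), kwargs=sample_kwargs
        )
```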
3. Enable the test case in [`ops_test_data.py`](./ops_test_data.py)
1. Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list. (For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L2116)
```py
TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter)
```
You can additionally specify dtype tolerance (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L539) or conditional skips (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L586-L590).
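For example, a hypothetical entry combining a float16 tolerance with a conditional skip (the matcher and reason are illustrative):
```py
TorchLibOpInfo(
    "ops.aten.slice_scatter",
    core_ops.aten_slice_scatter,
    tolerance={torch.float16: (1e-3, 1e-3)},
).skip(
    matcher=lambda sample: sample.args[0].numel() == 0,
    reason="illustrative: skip empty src tensors",
)
```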
Now that the test is added, you may run it as described above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view the failing input combinations should any test case fail.

File: test/onnx/torchlib/ops_test_common.py

@ -0,0 +1,701 @@
# Owner(s): ["module: onnx"]
"""Common utils for testing operators."""
from __future__ import annotations
import contextlib
import copy
import dataclasses
import multiprocessing
import os
import pprint
import sys
import unittest
import warnings
from typing import (
Any,
Callable,
Collection,
Iterable,
Mapping,
Optional,
Sequence,
TypeVar,
)
import error_reproduction
import numpy as np
import onnx
import onnxruntime as ort
import onnxruntime.capi.onnxruntime_pybind11_state
import onnxscript
import onnxscript.evaluator
import pytest
from onnxscript import ir
import torch
from torch.onnx._internal.exporter import _building, _ir_passes, _tensors
from torch.testing._internal.opinfo import core as opinfo_core
T = TypeVar("T")
# Convenience tuples for creating dtype lists when skipping or xfailing tests
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)
TEST_OPSET_VERSION = 18
IS_MACOS = sys.platform.startswith("darwin")
IS_WINDOWS = os.name == "nt"
@dataclasses.dataclass
class DecorateMeta:
"""A dataclass for storing information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
"""
op_name: str
variant_name: str
decorator: Callable[..., Any]
dtypes: Optional[Collection[torch.dtype]]
device_type: Optional[str]
reason: str
test_behavior: str
matcher: Optional[Callable[[Any], bool]] = None
enabled_if: bool = True
# The test_class_name to apply the decorator to. If None, the decorator is
# applied to all test classes.
test_class_name: Optional[str] = None
def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
device_type: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
test_class_name: Optional[str] = None,
) -> DecorateMeta:
"""Expects an OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
reason: The reason for the failure.
dtypes: The dtypes to expect the failure.
device_type: Device type. E.g. "cpu", "cuda".
matcher: A function that matches the test sample input. It is used only when
the xfail is in the SKIP_XFAIL_SUBTESTS list.
enabled_if: Whether the xfail is enabled.
test_class_name: The test class name to apply the xfail to. If None, the
xfail is applied to all test classes.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
dtypes=dtypes,
device_type=device_type,
matcher=matcher,
reason=reason,
enabled_if=enabled_if,
test_class_name=test_class_name,
test_behavior="xfail",
)
def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
device_type: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
test_class_name: Optional[str] = None,
) -> DecorateMeta:
"""Skips an OpInfo test.
Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
reason: The reason for skipping.
dtypes: The dtypes to skip.
device_type: Device type. E.g. "cpu", "cuda".
matcher: A function that matches the test sample input. It is used only when
the skip is in the SKIP_XFAIL_SUBTESTS list.
enabled_if: Whether the skip is enabled.
test_class_name: The test class name to apply the skip to. If None, the skip
is applied to all test classes.
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
dtypes=dtypes,
device_type=device_type,
reason=reason,
matcher=matcher,
enabled_if=enabled_if,
test_class_name=test_class_name,
test_behavior="skip",
)
def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
skip_or_xfails: Iterable[DecorateMeta],
) -> Callable[[T], T]:
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
if opinfo is None and not decorate_meta.enabled_if:
# If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
# because it could be an OpInfo that is in torch-nightly but not older versions.
continue
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
decorate_meta.test_class_name or test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
device_type=decorate_meta.device_type,
active_if=decorate_meta.enabled_if,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
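# Illustrative usage (mirrors test_ops.py in this PR); the call mutates the OpInfo
# decorators in the op database and returns a no-op decorator:
#
#   @add_decorate_info(
#       OPS_DB,
#       "TestOutputConsistencyFullGraph",
#       "test_output_match_opinfo_",
#       skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
#   )
#   def test_output_match_opinfo_(self, device, dtype, op): ...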
def duplicate_opinfo(
opinfos: list[opinfo_core.OpInfo], name: str, new_names: tuple[str, ...]
):
"""Duplicate an opinfo in the opinfo database and give it a new name."""
duplicated = []
all_info_names = {opinfo.name for opinfo in opinfos}
for opinfo in opinfos:
if opinfo.name == name:
for new_name in new_names:
if new_name in all_info_names:
# NOTE: Avoid duplicating an opinfo that already exists in the database.
# New opinfos are expected to be added in torch-nightly.
warnings.warn(
f"OpInfo {new_name} already exists in the database.",
stacklevel=1,
)
continue
new_opinfo = copy.deepcopy(opinfo)
new_opinfo.name = new_name
duplicated.append(new_opinfo)
opinfos.extend(duplicated)
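# Illustrative: duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) adds two
# copies of the "cat" OpInfo named "concat" and "concatenate" (see ops_test_data.py).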
def duplicate_opinfo_for_prims(
opinfos: list[opinfo_core.OpInfo], name: str, prims_name: str | None = None
):
"""Duplicate an opinfo in the opinfo database for a prims op.
The function sets the new OpInfo to use the torch.ops.prims variant of the op.
The new OpInfo will have the name "prims_{prims_name}" where `prims_name` is the
name of the prims op. If `prims_name` is None, it will be set to "prims_{name}".
Args:
opinfos: The list of opinfo_core.OpInfo to add the new opinfo to.
name: The name of the opinfo to duplicate.
prims_name: The name of the prims op. If None, it will be set to `name`.
"""
if prims_name is None:
prims_name = name
# The name of the new OpInfo
new_name = f"prims_{prims_name}"
all_info_names = {opinfo.name for opinfo in opinfos}
for opinfo in opinfos:
if opinfo.name == name:
if new_name in all_info_names:
# NOTE: Avoid duplicating an opinfo that already exists in the database.
warnings.warn(
f"OpInfo {new_name} already exists in the database.", stacklevel=1
)
continue
new_opinfo = copy.deepcopy(opinfo)
new_opinfo.name = new_name
new_opinfo.op = getattr(torch.ops.prims, prims_name)
opinfos.append(new_opinfo)
return
raise RuntimeError(f"OpInfo '{name}' not found in the database.")
TORCH_TYPE_TO_ONNX = {
torch.bool: onnx.TensorProto.BOOL,
torch.uint8: onnx.TensorProto.UINT8,
torch.int8: onnx.TensorProto.INT8,
torch.int16: onnx.TensorProto.INT16,
torch.int32: onnx.TensorProto.INT32,
torch.int64: onnx.TensorProto.INT64,
torch.float16: onnx.TensorProto.FLOAT16,
torch.float32: onnx.TensorProto.FLOAT,
torch.float64: onnx.TensorProto.DOUBLE,
torch.complex64: onnx.TensorProto.COMPLEX64,
torch.complex128: onnx.TensorProto.COMPLEX128,
torch.bfloat16: onnx.TensorProto.BFLOAT16,
}
def convert_tensor_to_numpy(input: Any) -> Any:
if isinstance(input, torch.Tensor):
if torch.is_complex(input):
# from complex to real representation
input = torch.view_as_real(input)
return input.detach().cpu().numpy()
if isinstance(input, complex):
return torch.view_as_real(torch.tensor(input)).detach().cpu().numpy()
if isinstance(input, list):
if len(input) == 0:
return np.array((), dtype=np.int64)
if any(isinstance(x, torch.Tensor) for x in input):
# The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
return [convert_tensor_to_numpy(x) for x in input]
if isinstance(input[0], bool):
return np.array(input, dtype=np.bool_)
# Just a sequence of numbers
if isinstance(input[0], int):
return np.array(input, dtype=np.int64)
if isinstance(input[0], float):
return np.array(input)
return input
def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Converts kwargs to be compatible with ONNX Runtime."""
new_kwargs = {}
for key, value in kwargs.items():
if key == "device":
continue
if key == "dtype":
value = TORCH_TYPE_TO_ONNX[value]
if isinstance(value, torch.Tensor):
value = np.array(value.cpu())
new_kwargs[key] = value
return new_kwargs
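# Illustrative: {"dtype": torch.float32, "device": "cpu"} becomes
# {"dtype": onnx.TensorProto.FLOAT}; the "device" entry is dropped, and tensor
# values are converted to numpy arrays.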
class OrtAbortedError(RuntimeError):
"""ONNX Runtime Aborted."""
def _ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]):
"""Run a model with ONNX Runtime."""
# Disable all ORT optimizations
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
)
session = ort.InferenceSession(
serialized_model, session_options, providers=("CPUExecutionProvider",)
)
return session.run(None, ort_inputs)
def _ort_session_run_return_dict(
serialized_model: bytes, ort_inputs: Mapping[str, Any], return_dict
) -> None:
"""Run a model with ONNX Runtime and store the results in return_dict."""
try:
return_dict["results"] = _ort_session_run(serialized_model, ort_inputs)
return_dict["error"] = None
except Exception as e: # pylint: disable=broad-except
return_dict["results"] = None
return_dict["error"] = e
def _safe_ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]):
"""Run a model with ONNX Runtime in a separate process.
Args:
serialized_model: Serialized ONNX model proto.
ort_inputs: Inputs to the model.
Returns:
The inference result.
Raises:
OrtAbortedError if the process did not execute successfully.
"""
manager = multiprocessing.Manager()
return_dict = manager.dict()
process = multiprocessing.Process(
target=_ort_session_run_return_dict,
args=(serialized_model, ort_inputs, return_dict),
)
process.start()
process.join()
process.close()
if not return_dict:
raise OrtAbortedError
if return_dict["error"] is not None:
raise return_dict["error"]
return return_dict["results"]
def _format_model_and_input_information(onnx_model, inputs):
return (
f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}"
)
_TORCH_DTYPE_TO_ONNX_STRING = {
torch.bool: "tensor(bool)",
torch.uint8: "tensor(uint8)",
torch.int8: "tensor(int8)",
torch.int16: "tensor(int16)",
torch.int32: "tensor(int32)",
torch.int64: "tensor(int64)",
torch.float16: "tensor(float16)",
torch.float32: "tensor(float)",
torch.float64: "tensor(double)",
torch.complex64: "tensor(complex64)",
torch.complex128: "tensor(complex128)",
torch.bfloat16: "tensor(bfloat16)",
}
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
torch.bfloat16: ir.DataType.BFLOAT16,
torch.bool: ir.DataType.BOOL,
torch.complex128: ir.DataType.COMPLEX128,
torch.complex64: ir.DataType.COMPLEX64,
torch.float16: ir.DataType.FLOAT16,
torch.float32: ir.DataType.FLOAT,
torch.float64: ir.DataType.DOUBLE,
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
torch.int8: ir.DataType.INT8,
torch.uint8: ir.DataType.UINT8,
torch.uint16: ir.DataType.UINT16,
torch.uint32: ir.DataType.UINT32,
torch.uint64: ir.DataType.UINT64,
}
def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
"""Checks if the dtype is compatible with the schema.
When a dtype is "compatible" with the schema, it means we can use the dtype
to create sample inputs by OpInfo to test the ONNX function and expect outputs to match.
Args:
dtype: The torch dtype used to create sample inputs by OpInfo.
schema: The ONNX schema of the function.
Returns:
True if the dtype is compatible with the schema.
"""
if not schema.inputs:
# If there are no inputs, we can't check compatibility. Assume it is compatible.
# e.g. aten_randn has only attributes.
return True
if schema.inputs[0].name not in {"self", "input"}:
# If the name of the first input is not "self" or "input",
# it is usually an input that is not of the same type as the output.
# We assume support in this case.
#
# For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
# has the first input as `size`, which is an integer, but it can support
# any dtype.
return True
# Otherwise we check the type constraints of the first input.
# For example, when dtype=torch.float32, and the op being tested has the schema
# ```
# OpSchema(
# name='aten_abs',
# domain='pkg.onnxscript.torch_lib',
# since_version=1,
# doc='abs(Tensor self) -> Tensor',
# type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal',
# allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)',
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
# 'tensor(bfloat16)'], description='')],
# inputs=[OpSchema.FormalParameter(name='self', type_str='TReal',
# description='', param_option=<FormalParameterOption.Single: 0>,
# is_homogeneous=True, min_arity=1,
# differentiation_category=<DifferentiationCategory.Unknown: 0>)],
# outputs=[OpSchema.FormalParameter(name='return_val',
# type_str='TReal', description='',
# param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
# min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
# attributes={}
# )
# ```
# we see the first input type is "TReal", corresponding to the type constraint
# with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)',
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
# 'tensor(bfloat16)'].
# Since torch.float32 (tensor(float)) is in the allowed types, we return True.
first_input_type_name = schema.inputs[0].type_str
# Find the type constraint for the first input by matching the parameter name
first_input_type_constraint = next(
(
x
for x in schema.type_constraints
if first_input_type_name in x.type_param_str
),
None,
)
assert first_input_type_constraint is not None
allowed_type_strs = first_input_type_constraint.allowed_type_strs
# Here we consider seq(tensor(float)) compatible with tensor(float) as well
return any(
_TORCH_DTYPE_TO_ONNX_STRING[dtype] in type_str for type_str in allowed_type_strs
)
def graph_executor(
test_name: str,
outputs: Sequence[Any],
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
"""Eagerly executes a function."""
def _capture_graph_and_evaluate_torch_script_evaluator(
function: Callable, args, kwargs
) -> tuple[Any, onnx.ModelProto]:
"""Captures the graph of a function and evaluates it using TorchScriptEvaluator."""
# Initialize the ONNX graph
graph = ir.Graph(
(),
(),
nodes=(),
opset_imports={"": 18, "pkg.torch.onnx": 1},
name="main_graph",
)
opset = onnxscript.opset18
tracer = _building.OpRecorder(opset, {})
ort_inputs = {}
onnxscript_args: list[Any] = []
onnxscript_kwargs = {}
for i, arg in enumerate(args):
if isinstance(arg, np.ndarray):
input_name = f"input_{i}"
input = _tensors.SymbolicTensor(
opset=opset,
name=input_name,
shape=ir.Shape(arg.shape),
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(arg).dtype]),
)
graph.inputs.append(input)
onnxscript_args.append(input)
ort_inputs[input_name] = arg
elif isinstance(arg, (list, tuple)):
# str is also a sequence but we do not want to treat it as a tensor
sequence_input = []
for j, subarg in enumerate(arg):
if isinstance(subarg, np.ndarray):
input_name = f"input_{i}_{j}"
tensor = torch.tensor(subarg)
input = _tensors.SymbolicTensor(
opset=opset,
name=input_name,
shape=ir.Shape(tensor.shape),
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[tensor.dtype]),
)
graph.inputs.append(input)
sequence_input.append(input)
ort_inputs[input_name] = subarg
else:
# Include non-numpy inputs as-is
# For example, it could be a None value that we want to keep
sequence_input.append(subarg)
onnxscript_args.append(sequence_input)
else:
onnxscript_args.append(arg)
for key, value in kwargs.items():
if isinstance(value, np.ndarray):
input = _tensors.SymbolicTensor(
opset=opset,
name=key,
shape=ir.Shape(torch.tensor(value).shape),
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(value).dtype]),
)
graph.inputs.append(input)
ort_inputs[key] = value
onnxscript_kwargs[key] = input
else:
onnxscript_kwargs[key] = value
with onnxscript.evaluator.default_as(tracer):
symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs)
if not isinstance(symbolic_outputs, Sequence):
symbolic_outputs = (symbolic_outputs,)
# We need to set the size of the output tensors for the ONNX model to be valid
for output, symbolic_output in zip(outputs, symbolic_outputs):
if isinstance(output, Sequence):
# Output is a sequence
elem_dtype = _TORCH_DTYPE_TO_ONNX[output[0].dtype]
symbolic_output.type = ir.SequenceType(ir.TensorType(elem_dtype))
continue
output = (
output
if isinstance(output, torch.Tensor)
else torch.tensor(output, device="cpu")
)
symbolic_output.shape = ir.Shape(output.shape)
symbolic_output.dtype = _TORCH_DTYPE_TO_ONNX[output.dtype]
graph.outputs.extend(symbolic_outputs)
graph.extend(tracer.nodes)
onnx_model = ir.Model(graph, ir_version=10, producer_name="torch_test")
for identifier, onnxscript_function in tracer.functions.items():
if identifier in onnx_model.functions:
continue
if isinstance(onnxscript_function, ir.Function):
ir_function = onnxscript_function
else:
# TODO: Get IR function directly when onnxscript is updated
proto = onnxscript_function.to_function_proto()
ir_function = ir.serde.deserialize_function(proto)
onnx_model.functions[identifier] = ir_function
_ir_passes.add_torchlib_common_imports(onnx_model)
_ir_passes.add_opset_imports(onnx_model)
# Make sure the model is valid
model_proto = ir.to_proto(onnx_model)
try:
onnx.checker.check_model(model_proto, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
model_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True)
try:
if (
os.environ.get("CATCH_ORT_SEGFAULT") == "1"
or os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
):
# Use an individual process to run ONNX Runtime to catch segfaults
return _safe_ort_session_run(
model_proto.SerializeToString(), ort_inputs
), model_proto
return _ort_session_run(
model_proto.SerializeToString(), ort_inputs
), model_proto
except (
# pylint: disable=c-extension-no-member
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented,
# pylint: enable=c-extension-no-member
) as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
error_reproduction.create_reproduction_report(
test_name,
model_proto,
ort_inputs,
e,
"test/onnx/torchlib/test_ops.py",
)
raise RuntimeError(
"ONNX Runtime failed to evaluate:\n"
+ _format_model_and_input_information(model_proto, ort_inputs)
) from e
except OrtAbortedError as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
# Save the model and inputs to a file for reproduction
error_reproduction.create_reproduction_report(
test_name,
model_proto,
ort_inputs,
e,
"test/onnx/torchlib/test_ops.py",
)
raise OrtAbortedError(
"ONNX Runtime aborted:\n"
+ _format_model_and_input_information(model_proto, ort_inputs)
) from e
except Exception as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
error_reproduction.create_reproduction_report(
test_name,
model_proto,
ort_inputs,
e,
"test/onnx/torchlib/test_ops.py",
)
raise
return _capture_graph_and_evaluate_torch_script_evaluator
@contextlib.contextmanager
def normal_xfail_skip_test_behaviors(
test_behavior: Optional[str] = None, reason: Optional[str] = None
):
"""This context manager is used to handle the different behaviors of xfail and skip.
Args:
test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None.
reason (optional[str]): The reason for the failure or skip.
Raises:
e: Any exception raised by the test case if it's not an expected failure.
"""
# We need to skip as early as possible, since the failure may be a segfault that
# would crash the process.
if test_behavior == "skip":
pytest.skip(reason=reason)
try:
yield
# We could use `except (AssertionError, RuntimeError, ...) as e:`, but that would
# require going over all test cases to find the right exception types.
except Exception: # pylint: disable=broad-exception-caught
if test_behavior is None:
raise
if test_behavior == "xfail":
pytest.xfail(reason=reason)
else:
if test_behavior == "xfail":
pytest.fail("Test unexpectedly passed")

File: test/onnx/torchlib/ops_test_data.py

@ -0,0 +1,691 @@
# Owner(s): ["module: onnx"]
"""Test op correctness by comparing with PyTorch results.
## Usage
1. Set the env var CATCH_ORT_SEGFAULT to catch segfaults from ONNX Runtime.
## How to add a new operator test
This test suite uses PyTorch's OpInfo mechanism to generate test cases for each operator.
You may find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804
1. To enable test cases for an operator
Add a `TorchLibOpInfo` entry to `TESTED_TORCHLIB_OPS` in `ops_test_data.py`.
Specify `complex` if the function is designed for complex inputs.
The `op_info_name` in `TorchLibOpInfo` needs to be unique in the TESTED_TORCHLIB_OPS
list, but complex=True ops can share the same name with non-complex ops
because they are tested separately.
2. Add `.skip` and/or `.xfail` to skip or xfail tests.
Prefer xfail over skip when possible because that allows us to monitor the behavior
and update the test when it passes.
2a. If a test now fails because it unexpectedly passes (xpass) after previous errors
were fixed, remove the corresponding xfail.
3. If the sample inputs of the OpInfo need to be adjusted to fit the aten signature, create an input
wrangler function. See `_mean_input_wrangler` for an example.
4. To test different ONNX functions that are registered as overloads of the same
op, use `ops_test_common.duplicate_opinfo` to create new OpInfo with new names and map each
to one overload.
"""
# flake8: noqa
from __future__ import annotations
import copy
import dataclasses
import functools
from typing import Any, Callable, Collection, Optional
from typing_extensions import Self
import numpy as np
import ops_test_common
import torch
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
from torch.testing._internal import common_methods_invocations
from torch.testing._internal.opinfo import definitions as opinfo_definitions
# Create a copy of the op_db to modify
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
# Append extra op_db into the op database for testing
OPS_DB.extend(opinfo_definitions.signal.op_db)
@dataclasses.dataclass
class TorchLibOpInfo:
"""A dataclass to store the information to test an torchlib op."""
# The name of the op_info, e.g. "add"
op_info_name: str
# The torchlib ONNX Function to test
op: Callable[..., Any]
# The input wrangler function to adjust the input to fit the aten signature
input_wrangler: Optional[
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]]
] = None
# Whether the op is non-deterministic
nondeterministic: bool = False
# Whether to compare the shape only for the output[index]
# For example: (1,2) means compare value for output[0] and shape for output[1] and [2]
# We may be able to combine this with the nondeterministic option
compare_shape_only_for_output: tuple[int, ...] = ()
# Whether the function is designed for complex inputs
complex: bool = False
# The acceptable tolerance of the inference result difference between PyTorch and ORT.
# Format: {dtype: (rtol, atol)}.
# For example: {torch.float16: (1e-3, 1e-3)}
tolerance: dict[torch.dtype, tuple[float, float]] = dataclasses.field(
default_factory=dict
)
# Expected skips or fails for the test and/or subtests
skips_or_fails: list[ops_test_common.DecorateMeta] = dataclasses.field(
default_factory=list
)
def get_tolerance(self, dtype: torch.dtype) -> tuple[float | None, float | None]:
"""Returns the (rtol, atol) tolerance for the given dtype."""
if (tolerance := self.tolerance.get(dtype)) is not None:
return tolerance
# Use the PyTorch default if not specified
# https://pytorch.org/docs/stable/testing.html
return (None, None)
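# Illustrative: with tolerance={torch.float16: (1e-3, 1e-3)},
# get_tolerance(torch.float16) returns (1e-3, 1e-3), while
# get_tolerance(torch.float32) falls back to (None, None), the PyTorch defaults.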
def skip(
self,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
device_type: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
test_class_name: Optional[str] = None,
) -> Self:
"""Skips an OpInfo test.
Args:
variant_name: Optional OpInfo variant_test_name.
reason: The reason for skipping.
dtypes: The dtypes to skip.
device_type: Device type. E.g. "cpu", "cuda".
matcher: A function that matches the test sample input. It is used only when
the skip is in the SKIP_XFAIL_SUBTESTS list.
enabled_if: Whether the skip is enabled.
test_class_name: The test class name to apply the skip to. If None, the skip
is applied to all test classes.
"""
self.skips_or_fails.append(
ops_test_common.skip(
self.op_info_name,
variant_name,
reason=reason,
dtypes=dtypes,
device_type=device_type,
matcher=matcher,
enabled_if=enabled_if,
test_class_name=test_class_name,
)
)
return self
def xfail(
self,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
device_type: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
test_class_name: Optional[str] = None,
) -> Self:
"""Expects an OpInfo test to fail.
Args:
variant_name: Optional OpInfo variant_test_name.
reason: The reason for the failure.
dtypes: The dtypes to expect the failure.
device_type: Device type. E.g. "cpu", "cuda".
matcher: A function that matches the test sample input. It is used only when
the xfail is in the SKIP_XFAIL_SUBTESTS list.
enabled_if: Whether the xfail is enabled.
test_class_name: The test class name to apply the xfail to. If None, the
xfail is applied to all test classes.
"""
self.skips_or_fails.append(
ops_test_common.xfail(
self.op_info_name,
variant_name,
reason=reason,
dtypes=dtypes,
device_type=device_type,
matcher=matcher,
enabled_if=enabled_if,
test_class_name=test_class_name,
)
)
return self
# Modify this section ##########################################################
def _amin_amax_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "dim" not in kwargs:
# Supply an empty dim to match the aten signature
kwargs["dim"] = np.array([], dtype=np.int64)
else:
# Convert dim to a numpy array
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64).reshape((-1,))
return args, kwargs
def _avg_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "dim" not in kwargs:
if len(args) > 6:
kwargs["divisor_override"] = args.pop(6)
if len(args) > 5:
kwargs["count_include_pad"] = args.pop(5)
if len(args) > 4:
kwargs["ceil_mode"] = args.pop(4)
if len(args) > 3:
padding = args.pop(3)
if isinstance(padding, np.ndarray):
# Cannot use list(padding) here, because the elements would be numpy.int64 instead of int
padding = padding.tolist()
kwargs["padding"] = padding
if len(args) > 2:
stride = args.pop(2)
if isinstance(stride, np.ndarray):
stride = stride.tolist()
kwargs["stride"] = stride
kernel_size = args.pop(1)
if isinstance(kernel_size, np.ndarray):
kernel_size = kernel_size.tolist()
kwargs["kernel_size"] = kernel_size
return args, kwargs
def _cross_entropy_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
reduction_vals = ["none", "mean", "sum"]
value = kwargs["reduction"]
idx = reduction_vals.index(value)
kwargs["reduction"] = idx
return args, kwargs
def _dropout_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "training" in kwargs:
kwargs["train"] = kwargs["training"]
kwargs.pop("training")
return args, kwargs
def _einsum_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Swap the equation and tensors to revert the special handling in the OpInfo
return [args[1], args[0]], kwargs
def _embedding_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
kwargs.pop("max_norm", None)
kwargs.pop("norm_type", None)
return args, kwargs
def _empty_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
kwargs.pop("requires_grad", None)
return args, kwargs
def _grid_sample_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Convert string attributes to ints and pass them as inputs
inter_mode_options = {"bilinear": 0, "nearest": 1, "bicubic": 2}
padding_mode_options = {"zeros": 0, "border": 1, "reflection": 2}
args.append(inter_mode_options[kwargs["mode"]])
args.append(padding_mode_options[kwargs["padding_mode"]])
args.append(kwargs["align_corners"])
kwargs.clear()
return args, kwargs
def _im2col_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Move kernel_size, dilation, padding and stride from args to kwargs
if len(args) == 5:
# Handle stride
stride = args.pop()
if isinstance(stride, np.ndarray): # convert stride to list[int]
stride = stride.tolist()
kwargs["stride"] = stride
# Handle padding
padding = args.pop()
if isinstance(padding, np.ndarray): # convert padding to list[int]
padding = padding.tolist()
kwargs["padding"] = padding
# Handle dilation
dilation = args.pop()
if isinstance(dilation, np.ndarray): # convert dilation to list[int]
dilation = dilation.tolist()
kwargs["dilation"] = dilation
# Handle kernel_size
kernel_size = args.pop()
if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int]
kernel_size = kernel_size.tolist()
kwargs["kernel_size"] = kernel_size
return args, kwargs
def _index_put_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = [np.array(elem) for elem in args[1]]
return args, kwargs
def _max_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Remove return_indices argument because this op doesn't accept it
kwargs.pop("return_indices", None)
return args, kwargs
def _mean_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Make the dims as tensor
if "dim" in kwargs:
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
return args, kwargs
def _mse_loss_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
reduction_vals = ["none", "mean", "sum"] # [0,1,2], default=1
value = kwargs["reduction"]
idx = reduction_vals.index(value)
kwargs["reduction"] = idx
return args, kwargs
def _nll_loss_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
# aten_nll_loss can only accept integer argument instead of string
reduction_vals = ["none", "mean", "sum"]
value = kwargs["reduction"]
kwargs["reduction"] = reduction_vals.index(value)
return args, kwargs
def _nonzero_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
kwargs.pop("as_tuple", None)
return args, kwargs
def _reflection_pad2d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args.pop(2) # remove 'reflect' arg
return args, kwargs
def _replication_pad2d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args.pop(2) # remove 'replicate' arg
return args, kwargs
def _replication_pad3d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args.pop(2) # remove 'replicate' arg
return args, kwargs
def _roll_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if len(args) >= 3:
if isinstance(args[2], np.ndarray): # convert dims to list[int]
# Change dims from args to kwargs to keep tuple/list type
dims = args.pop(2)
kwargs["dims"] = dims.tolist()
elif isinstance(args[2], int): # convert dims to list[int]
dims = args.pop(2)
kwargs["dims"] = []
kwargs["dims"].append(dims)
if len(args) >= 2:
if isinstance(args[1], np.ndarray): # convert shift to list[int]
shifts = args.pop(1)
kwargs["shifts"] = shifts.tolist()
elif isinstance(args[1], int):
shifts = args.pop(1)
kwargs["shifts"] = []
kwargs["shifts"].append(shifts)
return args, kwargs
def _scalar_tensor_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
kwargs.pop("requires_grad", None)
return args, kwargs
def _scatter_reduce_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Put the string into kwargs, otherwise FullGraph mode could not find the 'reduce' argument
kwargs["reduce"] = args.pop(4)
return args, kwargs
def _sum_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if kwargs.get("dim") is not None:
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
return args, kwargs
def _unflatten_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = np.array(args[1], dtype=np.int64)
return args, kwargs
def _where_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# The aten::where op takes condition, x, y as inputs
# Swap the first two inputs
args[0], args[1] = args[1], args[0]
return args, kwargs
# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
TorchLibOpInfo("abs", core_ops.aten_abs),
TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
)
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
ops_test_common.duplicate_opinfo(
OPS_DB, "arange", ("arange_start", "arange_start_step")
)
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",))
ops_test_common.duplicate_opinfo(
OPS_DB,
"bitwise_left_shift",
(
"bitwise_left_shift_int8",
"bitwise_left_shift_int16",
"bitwise_left_shift_int32",
"bitwise_left_shift_int64",
),
)
ops_test_common.duplicate_opinfo(
OPS_DB,
"bitwise_right_shift",
(
"bitwise_right_shift_int8",
"bitwise_right_shift_int16",
"bitwise_right_shift_int32",
"bitwise_right_shift_int64",
),
)
ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int"))
ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",))
ops_test_common.duplicate_opinfo(
OPS_DB,
"nn.functional.pad",
(
"nn.functional.reflection_pad2d",
"nn.functional.replication_pad2d",
"nn.functional.replication_pad3d",
),
)
ops_test_common.duplicate_opinfo(
OPS_DB,
"nn.functional.scaled_dot_product_attention",
("nn.functional.scaled_dot_product_attention_bool_mask",),
)
ops_test_common.duplicate_opinfo(
OPS_DB,
"nn.functional.celu",
("nn.functional.celu_type_promoted",),
)
ops_test_common.duplicate_opinfo(
OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
)
ops_test_common.duplicate_opinfo(
OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)
)
ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",))
ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",))
# MARK: End edits here
# These ops are not deterministic, so we check shape and dtype only
NONDETERMINISTIC_OPS: frozenset[str] = frozenset(
info.op_info_name for info in TESTED_TORCHLIB_OPS if info.nondeterministic
)
COMPARE_SHAPE_ONLY_OPS: dict[
str,
set,
] = {
info.op_info_name: set(info.compare_shape_only_for_output)
for info in TESTED_TORCHLIB_OPS
}
TORCHLIB_OPINFO_MAPPING: dict[
str,
TorchLibOpInfo,
] = {info.op_info_name: info for info in TESTED_TORCHLIB_OPS if not info.complex}
TESTED_OPS = frozenset(TORCHLIB_OPINFO_MAPPING)
EXPECTED_SKIPS_OR_FAILS: tuple[ops_test_common.DecorateMeta, ...] = tuple(
functools.reduce(
# Flatten the list
lambda a, b: [*a, *b],
[
[meta for meta in info.skips_or_fails if meta.matcher is None]
for info in TESTED_TORCHLIB_OPS
],
)
)
SKIP_XFAIL_SUBTESTS: tuple[ops_test_common.DecorateMeta, ...] = tuple(
functools.reduce(
# Flatten the list
lambda a, b: [*a, *b],
[
[meta for meta in info.skips_or_fails if meta.matcher is not None]
for info in TESTED_TORCHLIB_OPS
],
)
)
# MARK: Complex supported functions
COMPLEX_FUNCTION_MAPPING: dict[
str,
TorchLibOpInfo,
] = {info.op_info_name: info for info in TESTED_TORCHLIB_OPS if info.complex}
# Call dir(torch.ops.prims) and compare with entries in OPS_DB to create OpInfo for newly added prims ops
PRIMS_OPS_WITH_OP_INFO = (
"abs",
"acos",
"acosh",
"add",
"amax",
"amin",
"as_strided",
"as_strided_scatter",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"bitwise_and",
"bitwise_not",
"bitwise_or",
"bitwise_xor",
"cat",
"ceil",
"clone",
"conj",
"conj_physical",
"cos",
"cosh",
"digamma",
"div",
"empty",
"eq",
"erf",
"erfc",
"exp",
"exp2",
"expm1",
"fill",
"floor",
"fmax",
"fmin",
"fmod",
"full",
"full_like",
"gcd",
"ge",
"gt",
"hypot",
"igamma",
"igammac",
"imag",
"isfinite",
"le",
"lgamma",
"log",
"log10",
"log1p",
"log2",
"lt",
"maximum",
"minimum",
"mul",
"ne",
"neg",
"nextafter",
"normal",
"pow",
"prod",
"real",
"reciprocal",
"remainder",
"reshape",
"round",
"rsqrt",
"scalar_tensor",
"sign",
"signbit",
"sin",
"sinh",
"sqrt",
"squeeze",
"sub",
"sum",
"svd",
"tan",
"tanh",
"transpose",
"trunc",
"uniform",
"where",
)
for op in PRIMS_OPS_WITH_OP_INFO:
# Duplicate opinfo for prim ops. The new names all start with "prims_". E.g. "abs" -> "prims_abs".
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, op)
# Duplicate cases where the prims op name is different from the torch op name
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "i0", "bessel_i0")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.bessel_j0", "bessel_j0")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.bessel_j1", "bessel_j1")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.erfcx", "erfcx")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i0e", "bessel_i0e")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i1", "bessel_i1")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i1e", "bessel_i1e")
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.ndtri", "ndtri")
ops_test_common.duplicate_opinfo_for_prims(
OPS_DB, "special.spherical_bessel_j0", "spherical_bessel_j0"
)
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.zeta", "zeta")
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all tested ops are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
assert NONDETERMINISTIC_OPS.issubset(
TESTED_OPS
), f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS"

File: test/onnx/torchlib/test_ops.py

@ -0,0 +1,354 @@
# Owner(s): ["module: onnx"]
"""Test op correctness by comparing with PyTorch results.
Usage:
pytest test_ops.py
To run tests on a specific operator (e.g. torch.ceil):
pytest test_ops.py -k ceil
To run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):
pytest test_ops.py -k nn_functional_scaled_dot_product_attention
## Environment variables
1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
in onnxruntime by running the inference sessions in a separate process.
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
errors.
"""
from __future__ import annotations
import os
from typing import Callable, Optional, Sequence, Tuple, TYPE_CHECKING
import error_reproduction
import numpy as np
import onnx
import onnxruntime as ort
import onnxscript
import ops_test_common
import ops_test_data
import parameterized
import torch
from torch.testing._internal import common_device_type, common_utils
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
import unittest
from torch.testing._internal.opinfo import core as opinfo_core
# All dtypes will be tested on the generated symbolic functions.
# complex64 will be flattened to float32.
TESTED_DTYPES = (
torch.float16,
torch.float32,
# Uncomment the items below when we really need to test them
# torch.bfloat16,
# torch.float64,
torch.bool,
# torch.int8,
# torch.int16,
torch.int32,
torch.int64,
# torch.uint8,
)
# NOTE: torch.complex32 is experimental in torch
COMPLEX_TYPES = (torch.complex64,)
def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
"""Returns all dtypes except the ones specified."""
return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes)
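# Illustrative: dtypes_except(torch.bool, torch.float16) returns every entry of
# TESTED_DTYPES except torch.bool and torch.float16.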
def _should_skip_xfail_test_sample(
op_name: str, sample, dtype: torch.dtype, device_type: str
) -> Tuple[Optional[str], Optional[str]]:
"""Returns a reason if a test sample should be skipped."""
if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
return None, None
for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
# Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
if decorator_meta.op_name == op_name:
assert decorator_meta.matcher is not None, "Matcher must be defined"
if not decorator_meta.enabled_if:
# Do not skip the test if the decorator meta is not enabled
continue
if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
# Not applicable for this dtype
continue
if (
decorator_meta.device_type is not None
and decorator_meta.device_type != device_type
):
# Not applicable for this device_type
continue
if decorator_meta.matcher(sample):
return decorator_meta.test_behavior, decorator_meta.reason
return None, None
class TestFunctionValidity(common_utils.TestCase):
@parameterized.parameterized.expand(
[
(info.op.name, info)
for info in ops_test_data.TESTED_TORCHLIB_OPS
if isinstance(info.op, onnxscript.OnnxFunction)
],
skip_on_empty=True,
)
def test_script_function_passes_checker(
self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
):
function_proto = torchlib_op_info.op.to_function_proto()
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]
def run_test_output_match(
test_suite: unittest.TestCase,
device: str,
dtype: torch.dtype,
op: opinfo_core.OpInfo,
function_executor: Callable,
tested_op_mapping: dict[
str,
ops_test_data.TorchLibOpInfo,
],
):
"""Base test method for testing each opset, used by instantiate_device_type_tests.
Args:
test_suite: The test class instance.
device: The PyTorch device. instantiate_device_type_tests provides this.
dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
op: The OpInfo instance. instantiate_device_type_tests provides this.
function_executor: The function executor. This is a function that takes
a function and its arguments and returns the output of the function.
tested_op_mapping: The mapping of op name to the tested op.
"""
samples = op.sample_inputs(
device,
dtype,
requires_grad=False,
)
torchlib_op_info = tested_op_mapping[op.name]
# Obtain the input_wrangler that manipulates the OpInfo inputs
# to match the aten operator signature
# An example is nn.functional.upsample_nearest2d, which has a different signature
# than the aten operator upsample_nearest2d
onnx_function = torchlib_op_info.op
input_wrangler = torchlib_op_info.input_wrangler
if (
not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
and dtype not in COMPLEX_TYPES
):
test_suite.skipTest(
f"dtype '{dtype}' is not supported by the op '{op.name}'. "
f"Type constraints: {onnx_function.op_schema.type_constraints}"
)
# Obtain the tolerance for the op
rtol, atol = torchlib_op_info.get_tolerance(dtype)
for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with test_suite.subTest(
sample_num=i,
inputs=repr(
[
f"Tensor<{inp.shape}, dtype={inp.dtype}>"
if isinstance(inp, torch.Tensor)
else inp
for inp in inputs
]
),
kwargs=repr(cpu_sample.kwargs),
):
try:
device_type = cpu_sample.args[0].device.type
except (AttributeError, IndexError):
device_type = "cpu"
test_behavior, reason = _should_skip_xfail_test_sample(
op.name, cpu_sample, dtype, device_type
)
with ops_test_common.normal_xfail_skip_test_behaviors(
test_behavior, reason
):
input_onnx = [
ops_test_common.convert_tensor_to_numpy(x) for x in inputs
]
kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
if input_wrangler:
input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
torch_output = op(*inputs, **cpu_sample.kwargs)
if isinstance(torch_output, torch.Tensor) and torch.is_complex(
torch_output
):
torch_output = torch.view_as_real(torch_output.resolve_conj())
reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
if (
op.name.startswith("split")
or op.name.startswith("chunk")
or op.name.startswith("unbind")
or op.name
in {
"atleast_1d_Sequence",
"atleast_2d_Sequence",
"atleast_3d_Sequence",
}
):
# Hack for handling split, chunk and unbind, which rely on the SplitToSequence op.
# Split returns a Sequence that should be treated as a single
# value, so we wrap it into a tuple.
# TODO(justinchuby): Find a more general solution
reference_torch_outputs = [reference_torch_outputs]
test_name = test_suite.id()
function_output, model_proto = function_executor(
test_name, reference_torch_outputs
)(onnx_function, input_onnx, kwargs_onnx)
# Finally we re-flatten everything
# TODO: add pytree structure comparison.
flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
flattened_function_outputs, _ = pytree.tree_flatten(function_output)
assert flattened_torch_outputs
assert len(flattened_torch_outputs) == len(flattened_function_outputs)
for j, (torch_output, function_output) in enumerate(
zip(flattened_torch_outputs, flattened_function_outputs)
):
actual = torch.tensor(function_output)
expected = (
torch_output
if isinstance(torch_output, torch.Tensor)
else torch.tensor(torch_output)
)
if (
op.name in ops_test_data.NONDETERMINISTIC_OPS
or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
):
# Check shape and dtype only for ops that are known to be
# nondeterministic
test_suite.assertEqual(actual.shape, expected.shape)
test_suite.assertEqual(actual.dtype, expected.dtype)
continue
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
try:
torch.testing.assert_close(
actual,
expected,
rtol=rtol,
atol=atol,
equal_nan=True,
check_device=False,
)
except AssertionError as e:
if (
os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
and test_behavior is None
):
error_reproduction.create_mismatch_report(
test_name,
i,
model_proto,
inputs,
cpu_sample.kwargs,
actual,
expected,
e,
__file__,
)
if len(flattened_torch_outputs) > 1:
raise AssertionError(f"Output {j} mismatch") from e
raise
class TestOutputConsistencyFullGraph(common_utils.TestCase):
"""Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.
This is a parameterized test suite.
"""
def setUp(self) -> None:
torch.manual_seed(42)
np.random.seed(42)
ort.set_seed(42)
@ops_test_common.add_decorate_info(
ops_test_data.OPS_DB,
"TestOutputConsistencyFullGraph",
"test_output_match_opinfo_",
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
)
@common_device_type.ops( # type: ignore[misc]
[
info
for info in ops_test_data.OPS_DB
if info.name in ops_test_data.TESTED_OPS
],
allowed_dtypes=TESTED_DTYPES,
)
def test_output_match_opinfo_(
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
):
# Base test method for testing each op by running the full ONNX graph.
run_test_output_match(
self,
device,
dtype,
op,
ops_test_common.graph_executor,
ops_test_data.TORCHLIB_OPINFO_MAPPING,
)
@ops_test_common.add_decorate_info(
ops_test_data.OPS_DB,
"TestOutputConsistencyFullGraph",
"test_complex_output_match_opinfo_",
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
)
@common_device_type.ops( # type: ignore[misc]
[
info
for info in ops_test_data.OPS_DB
if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
],
allowed_dtypes=COMPLEX_TYPES,
)
def test_complex_output_match_opinfo_(
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
):
"""Base test method for testing each op by running the full ONNX graph."""
run_test_output_match(
self,
device,
dtype,
op,
ops_test_common.graph_executor,
ops_test_data.COMPLEX_FUNCTION_MAPPING,
)
common_device_type.instantiate_device_type_tests(
TestOutputConsistencyFullGraph, globals(), only_for=["cpu"]
)
if __name__ == "__main__":
common_utils.run_tests()

File: torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py

@ -0,0 +1,78 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Typings for function definitions."""
from __future__ import annotations
from typing import TypeVar, Union
from onnxscript import (
BFLOAT16,
BOOL,
COMPLEX128,
COMPLEX64,
DOUBLE,
FLOAT,
FLOAT16,
INT16,
INT32,
INT64,
INT8,
STRING,
UINT8,
)
# NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not use them.
# More detail can be found at https://pytorch.org/docs/stable/tensors.html
_TensorType = Union[
BFLOAT16,
BOOL,
COMPLEX64,
COMPLEX128,
DOUBLE,
FLOAT,
FLOAT16,
INT8,
INT16,
INT32,
INT64,
UINT8,
]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
BFLOAT16,
FLOAT16,
FLOAT,
DOUBLE,
INT8,
INT16,
INT32,
INT64,
]
TTensor = TypeVar("TTensor", bound=_TensorType)
# Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
# but do not constrain the type to be the same as the other inputs/outputs
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
TFloatOrUInt8 = TypeVar(
"TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]
)
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
TRealUnlessInt16OrInt8 = TypeVar(
"TRealUnlessInt16OrInt8",
bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64],
)
TRealUnlessFloat16OrInt8 = TypeVar(
"TRealUnlessFloat16OrInt8", bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64]
)
TRealOrUInt8 = TypeVar("TRealOrUInt8", bound=Union[RealType, UINT8])
TFloatHighPrecision = TypeVar("TFloatHighPrecision", bound=Union[FLOAT, DOUBLE])
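# Example usage (see core.py in this PR): the TypeVar ties an input and its output
# to the same ONNX type constraint:
#   def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8: ...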

File: torch/onnx/_internal/exporter/_torchlib/ops/__init__.py

@ -1,6 +1,6 @@
from __future__ import annotations
__all__ = ["hop"]
__all__ = ["core", "hop"]
from torch.onnx._internal.exporter._torchlib.ops import hop
from torch.onnx._internal.exporter._torchlib.ops import core, hop

File: torch/onnx/_internal/exporter/_torchlib/ops/core.py

@ -0,0 +1,47 @@
"""torch.ops.aten operators under the `core` module."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa
from __future__ import annotations
import operator
from onnxscript.onnx_opset import opset18 as op
import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal, TRealOrUInt8
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
aten = torch.ops.aten
@onnx_impl((aten.abs.default, operator.abs), trace_only=True)
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""
return op.Abs(self)
@onnx_impl(aten.abs.default, complex=True, trace_only=True)
def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""
return op.ReduceL2(self, [-1], keepdims=False)
@onnx_impl((aten.add.Tensor, aten.add.Scalar, operator.add), trace_only=True)
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
return op.Add(self, other)
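# Illustrative trace: aten_add(x, y, alpha=2.0) records CastLike -> Mul -> Add nodes;
# with the default alpha=1.0, only a single Add node is recorded.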
@onnx_impl((aten.add.Tensor, aten.add.Scalar), trace_only=True, complex=True)
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
return aten_add(self, other, alpha=alpha)