# Owner(s): ["module: onnx"]
"""Common utils for testing operators."""
from __future__ import annotations
import contextlib
import copy
import dataclasses
import multiprocessing
import os
import pprint
import sys
import unittest
import warnings
from typing import (
Any,
Callable,
Collection,
Iterable,
Mapping,
Optional,
Sequence,
TypeVar,
)

import error_reproduction
import numpy as np
import onnx
import onnxruntime as ort
import onnxruntime.capi.onnxruntime_pybind11_state
import onnxscript
import onnxscript.evaluator
import pytest
from onnxscript import ir
import torch
from torch.onnx._internal.exporter import _building, _ir_passes, _tensors
from torch.testing._internal.opinfo import core as opinfo_core

T = TypeVar("T")

# Convenience tuples for creating dtype lists when skipping or xfailing tests
BOOL_TYPES = (torch.bool,)
INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
FLOAT_TYPES = (
torch.float16,
torch.float32,
torch.float64,
)

TEST_OPSET_VERSION = 18
IS_MACOS = sys.platform.startswith("darwin")
IS_WINDOWS = os.name == "nt"


@dataclasses.dataclass
class DecorateMeta:
"""A dataclass for storing information about a test case to skip or xfail.
Adapted from functorch: functorch/test/common_utils.py
"""
op_name: str
variant_name: str
decorator: Callable[..., Any]
dtypes: Optional[Collection[torch.dtype]]
device_type: Optional[str]
reason: str
test_behavior: str
matcher: Optional[Callable[[Any], bool]] = None
enabled_if: bool = True
# The test_class_name to apply the decorator to. If None, the decorator is
# applied to all test classes.
    test_class_name: Optional[str] = None


def xfail(
op_name: str,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
device_type: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
test_class_name: Optional[str] = None,
) -> DecorateMeta:
"""Expects an OpInfo test to fail.
Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
reason: The reason for the failure.
dtypes: The dtypes to expect the failure.
device_type: Device type. E.g. "cpu", "cuda".
matcher: A function that matches the test sample input. It is used only when
the xfail is in the SKIP_XFAIL_SUBTESTS list.
enabled_if: Whether the xfail is enabled.
test_class_name: The test class name to apply the xfail to. If None, the
xfail is applied to all test classes.
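
    Example (hypothetical op name and reason, for illustration only)::

        xfail("ceil", dtypes=BOOL_TYPES, reason="ceil is not defined for bool")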
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.expectedFailure,
dtypes=dtypes,
device_type=device_type,
matcher=matcher,
reason=reason,
enabled_if=enabled_if,
test_class_name=test_class_name,
test_behavior="xfail",
    )


def skip(
op_name: str,
variant_name: str = "",
*,
reason: str,
dtypes: Optional[Collection[torch.dtype]] = None,
device_type: Optional[str] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
test_class_name: Optional[str] = None,
) -> DecorateMeta:
"""Skips an OpInfo test.
Args:
op_name: The name of the operator.
variant_name: Optional OpInfo variant_test_name.
reason: The reason for skipping.
dtypes: The dtypes to skip.
device_type: Device type. E.g. "cpu", "cuda".
matcher: A function that matches the test sample input. It is used only when
the skip is in the SKIP_XFAIL_SUBTESTS list.
enabled_if: Whether the skip is enabled.
test_class_name: The test class name to apply the skip to. If None, the skip
is applied to all test classes.
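
    Example (hypothetical op and matcher, for illustration only)::

        skip(
            "add",
            matcher=lambda sample: sample.kwargs.get("alpha") != 1,
            reason="alpha != 1 is not supported by the decomp under test",
        )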
"""
return DecorateMeta(
op_name=op_name,
variant_name=variant_name,
decorator=unittest.skip(f"Skip: {reason}"),
dtypes=dtypes,
device_type=device_type,
reason=reason,
matcher=matcher,
enabled_if=enabled_if,
test_class_name=test_class_name,
test_behavior="skip",
    )


def add_decorate_info(
all_opinfos: Sequence[opinfo_core.OpInfo],
test_class_name: str,
base_test_name: str,
skip_or_xfails: Iterable[DecorateMeta],
) -> Callable[[T], T]:
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
for decorate_meta in skip_or_xfails:
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
if opinfo is None and not decorate_meta.enabled_if:
# If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
# because it could be an OpInfo that is in torch-nightly but not older versions.
continue
assert (
opinfo is not None
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
decorators = list(opinfo.decorators)
new_decorator = opinfo_core.DecorateInfo(
decorate_meta.decorator,
decorate_meta.test_class_name or test_class_name,
base_test_name,
dtypes=decorate_meta.dtypes,
device_type=decorate_meta.device_type,
active_if=decorate_meta.enabled_if,
)
decorators.append(new_decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
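

# A minimal usage sketch for add_decorate_info. The OpInfo database and test
# names below are hypothetical placeholders, not the real torchlib test data:
#
#   @add_decorate_info(
#       all_opinfos=op_db,
#       test_class_name="TestOutputConsistency",
#       base_test_name="test_output_match",
#       skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
#   )
#   def test_output_match(self, device, dtype, op): ...
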
def duplicate_opinfo(
opinfos: list[opinfo_core.OpInfo], name: str, new_names: tuple[str, ...]
):
"""Duplicate an opinfo in the opinfo database and give it a new name."""
duplicated = []
all_info_names = {opinfo.name for opinfo in opinfos}
for opinfo in opinfos:
if opinfo.name == name:
for new_name in new_names:
if new_name in all_info_names:
# NOTE: Avoid duplicating an opinfo that already exists in the database.
# New opinfos are expected to be added in torch-nightly.
warnings.warn(
f"OpInfo {new_name} already exists in the database.",
stacklevel=1,
)
continue
new_opinfo = copy.deepcopy(opinfo)
new_opinfo.name = new_name
duplicated.append(new_opinfo)
opinfos.extend(duplicated)
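

# For example (illustrative names): duplicate_opinfo(op_db, "all", ("all_dim",))
# registers a copy of the "all" OpInfo under the name "all_dim" so that a
# separate ONNX decomp can be tested against it.
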
def duplicate_opinfo_for_prims(
opinfos: list[opinfo_core.OpInfo], name: str, prims_name: str | None = None
):
"""Duplicate an opinfo in the opinfo database for a prims op.
The function sets the new OpInfo to use the variation torch.ops.prims.
The new OpInfo will have the name "prims_{prims_name}" where `prims_name` is the
name of the prims op. If `prims_name` is None, it will be set to "prims_{name}".
Args:
opinfos: The list of opinfo_core.OpInfo to add the new opinfo to.
name: The name of the opinfo to duplicate.
prims_name: The name of the prims op. If None, it will be set to `name`.
"""
if prims_name is None:
prims_name = name
# The name of the new OpInfo
new_name = f"prims_{prims_name}"
all_info_names = {opinfo.name for opinfo in opinfos}
for opinfo in opinfos:
if opinfo.name == name:
if new_name in all_info_names:
# NOTE: Avoid duplicating an opinfo that already exists in the database.
warnings.warn(
f"OpInfo {new_name} already exists in the database.", stacklevel=1
)
continue
new_opinfo = copy.deepcopy(opinfo)
new_opinfo.name = new_name
new_opinfo.op = getattr(torch.ops.prims, prims_name)
opinfos.append(new_opinfo)
return
raise RuntimeError(f"OpInfo '{name}' not found in the database.")
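

# For example (illustrative name): duplicate_opinfo_for_prims(op_db, "var_mean")
# adds an OpInfo named "prims_var_mean" whose op is torch.ops.prims.var_mean.
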
TORCH_TYPE_TO_ONNX = {
torch.bool: onnx.TensorProto.BOOL,
torch.uint8: onnx.TensorProto.UINT8,
torch.int8: onnx.TensorProto.INT8,
torch.int16: onnx.TensorProto.INT16,
torch.int32: onnx.TensorProto.INT32,
torch.int64: onnx.TensorProto.INT64,
torch.float16: onnx.TensorProto.FLOAT16,
torch.float32: onnx.TensorProto.FLOAT,
torch.float64: onnx.TensorProto.DOUBLE,
torch.complex64: onnx.TensorProto.COMPLEX64,
torch.complex128: onnx.TensorProto.COMPLEX128,
torch.bfloat16: onnx.TensorProto.BFLOAT16,
}


def convert_tensor_to_numpy(input: Any) -> Any:
    """Converts a torch input (tensor, complex number, or list) to a numpy equivalent for ORT."""
    if isinstance(input, torch.Tensor):
if torch.is_complex(input):
# from complex to real representation
input = torch.view_as_real(input)
return input.detach().cpu().numpy()
if isinstance(input, complex):
return torch.view_as_real(torch.tensor(input)).detach().cpu().numpy()
if isinstance(input, list):
if len(input) == 0:
return np.array((), dtype=np.int64)
if any(isinstance(x, torch.Tensor) for x in input):
# The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
return [convert_tensor_to_numpy(x) for x in input]
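        # Check bool before int: bool is a subclass of int in Python, so an
        # isinstance(input[0], int) check would also match True/False.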
if isinstance(input[0], bool):
return np.array(input, dtype=np.bool_)
# Just a sequence of numbers
if isinstance(input[0], int):
return np.array(input, dtype=np.int64)
if isinstance(input[0], float):
return np.array(input)
return input
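

# A sketch of the complex handling above (shapes assume the default complex64):
#   convert_tensor_to_numpy(torch.tensor([1 + 2j]))
#   -> array([[1., 2.]], dtype=float32)  # (real, imag) pairs on a new last axis
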
def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Converts kwargs to be compatible with ONNX Runtime."""
new_kwargs = {}
for key, value in kwargs.items():
if key == "device":
continue
if key == "dtype":
value = TORCH_TYPE_TO_ONNX[value]
if isinstance(value, torch.Tensor):
value = np.array(value.cpu())
new_kwargs[key] = value
return new_kwargs
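

# For example, {"dtype": torch.float32, "device": "cpu", "alpha": 2} becomes
# {"dtype": onnx.TensorProto.FLOAT, "alpha": 2}: "device" is dropped and the
# torch dtype is mapped to its ONNX TensorProto counterpart.
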
class OrtAbortedError(RuntimeError):
"""ONNX Runtime Aborted."""
def _ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]):
"""Run a model with ONNX Runtime."""
# Disable all ORT optimizations
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
)
session = ort.InferenceSession(
serialized_model, session_options, providers=("CPUExecutionProvider",)
)
    return session.run(None, ort_inputs)


def _ort_session_run_return_dict(
serialized_model: bytes, ort_inputs: Mapping[str, Any], return_dict
) -> None:
"""Run a model with ONNX Runtime and store the results in return_dict."""
try:
return_dict["results"] = _ort_session_run(serialized_model, ort_inputs)
return_dict["error"] = None
except Exception as e: # pylint: disable=broad-except
return_dict["results"] = None
return_dict["error"] = e
def _safe_ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]):
"""Run a model with ONNX Runtime in a separate process.
Args:
serialized_model: Serialized ONNX model proto.
ort_inputs: Inputs to the model.
Returns:
The inference result.
Raises:
OrtAbortedError if the process did not execute successfully.
"""
manager = multiprocessing.Manager()
return_dict = manager.dict()
process = multiprocessing.Process(
target=_ort_session_run_return_dict,
args=(serialized_model, ort_inputs, return_dict),
)
process.start()
process.join()
process.close()
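    # A crashed child process (e.g. a segfault inside ONNX Runtime) exits
    # without populating return_dict, so an empty dict signals an abort.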
if not return_dict:
raise OrtAbortedError
if return_dict["error"] is not None:
raise return_dict["error"]
    return return_dict["results"]


def _format_model_and_input_information(onnx_model, inputs):
return (
f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}"
    )


_TORCH_DTYPE_TO_ONNX_STRING = {
torch.bool: "tensor(bool)",
torch.uint8: "tensor(uint8)",
torch.int8: "tensor(int8)",
torch.int16: "tensor(int16)",
torch.int32: "tensor(int32)",
torch.int64: "tensor(int64)",
torch.float16: "tensor(float16)",
torch.float32: "tensor(float)",
torch.float64: "tensor(double)",
torch.complex64: "tensor(complex64)",
torch.complex128: "tensor(complex128)",
torch.bfloat16: "tensor(bfloat16)",
}


_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
torch.bfloat16: ir.DataType.BFLOAT16,
torch.bool: ir.DataType.BOOL,
torch.complex128: ir.DataType.COMPLEX128,
torch.complex64: ir.DataType.COMPLEX64,
torch.float16: ir.DataType.FLOAT16,
torch.float32: ir.DataType.FLOAT,
torch.float64: ir.DataType.DOUBLE,
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
torch.int8: ir.DataType.INT8,
torch.uint8: ir.DataType.UINT8,
torch.uint16: ir.DataType.UINT16,
torch.uint32: ir.DataType.UINT32,
torch.uint64: ir.DataType.UINT64,
}


def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
"""Checks if the dtype is compatible with the schema.
When a dtype is "compatible" with the schema, it means we can use the dtype
to create sample inputs by OpInfo to test the ONNX function and expect outputs to match.
Args:
dtype: The torch dtype used to create sample inputs by OpInfo.
schema: The ONNX schema of the function.
Returns:
True if the dtype is compatible with the schema.
"""
if not schema.inputs:
# If there are no inputs, we can't check compatibility. Assume it is compatible.
# e.g. aten_randn has only attributes.
return True
if schema.inputs[0].name not in {"self", "input"}:
# If the name of the first input is not "self" or "input",
# it is usually an input that is not of the same type as the output.
# We assume support in this case.
#
# For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
# has the first input as `size`, which is an integer, but it can support
# any dtype.
return True
# Otherwise we check the type constraints of the first input.
# For example, when dtype=torch.float32, and the op being tested has the schema
# ```
# OpSchema(
# name='aten_abs',
# domain='pkg.onnxscript.torch_lib',
# since_version=1,
# doc='abs(Tensor self) -> Tensor',
# type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal',
# allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)',
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
# 'tensor(bfloat16)'], description='')],
# inputs=[OpSchema.FormalParameter(name='self', type_str='TReal',
# description='', param_option=<FormalParameterOption.Single: 0>,
# is_homogeneous=True, min_arity=1,
# differentiation_category=<DifferentiationCategory.Unknown: 0>)],
# outputs=[OpSchema.FormalParameter(name='return_val',
# type_str='TReal', description='',
# param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
# min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
# attributes={}
# )
# ```
# we see the first input type is "TReal", corresponding to the type constraint
# with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)',
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
# 'tensor(bfloat16)'].
# Since torch.float32 (tensor(float)) is in the allowed types, we return True.
first_input_type_name = schema.inputs[0].type_str
    # Find the type constraint for the first input by matching its type string
    # (e.g. "TReal") against the constraint names
first_input_type_constraint = next(
(
x
for x in schema.type_constraints
if first_input_type_name in x.type_param_str
),
None,
)
assert first_input_type_constraint is not None
allowed_type_strs = first_input_type_constraint.allowed_type_strs
# Here we consider seq(tensor(float)) compatible with tensor(float) as well
return any(
_TORCH_DTYPE_TO_ONNX_STRING[dtype] in type_str for type_str in allowed_type_strs
)
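

# For the aten_abs schema illustrated in dtype_op_schema_compatible, passing
# torch.float32 returns True ("tensor(float)" is in TReal's allowed types),
# while torch.bool returns False ("tensor(bool)" is not).
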
def graph_executor(
test_name: str,
outputs: Sequence[Any],
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
"""Eagerly executes a function."""
def _capture_graph_and_evaluate_torch_script_evaluator(
function: Callable, args, kwargs
) -> tuple[Any, onnx.ModelProto]:
"""Captures the graph of a function and evaluates it using TorchScriptEvaluator."""
# Initialize the ONNX graph
graph = ir.Graph(
(),
(),
nodes=(),
opset_imports={"": 18, "pkg.torch.onnx": 1},
name="main_graph",
)
opset = onnxscript.opset18
tracer = _building.OpRecorder(opset, {})
ort_inputs = {}
onnxscript_args: list[Any] = []
onnxscript_kwargs = {}
for i, arg in enumerate(args):
if isinstance(arg, np.ndarray):
input_name = f"input_{i}"
input = _tensors.SymbolicTensor(
opset=opset,
name=input_name,
shape=ir.Shape(arg.shape),
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(arg).dtype]),
)
graph.inputs.append(input)
onnxscript_args.append(input)
ort_inputs[input_name] = arg
elif isinstance(arg, (list, tuple)):
# str is also a sequence but we do not want to treat it as a tensor
sequence_input = []
for j, subarg in enumerate(arg):
if isinstance(subarg, np.ndarray):
input_name = f"input_{i}_{j}"
tensor = torch.tensor(subarg)
input = _tensors.SymbolicTensor(
opset=opset,
name=input_name,
shape=ir.Shape(tensor.shape),
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[tensor.dtype]),
)
graph.inputs.append(input)
sequence_input.append(input)
ort_inputs[input_name] = subarg
else:
# Include non-numpy inputs as-is
# For example, it could be a None value that we want to keep
sequence_input.append(subarg)
onnxscript_args.append(sequence_input)
else:
onnxscript_args.append(arg)
for key, value in kwargs.items():
if isinstance(value, np.ndarray):
input = _tensors.SymbolicTensor(
opset=opset,
name=key,
shape=ir.Shape(torch.tensor(value).shape),
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(value).dtype]),
)
graph.inputs.append(input)
ort_inputs[key] = value
onnxscript_kwargs[key] = input
else:
onnxscript_kwargs[key] = value
with onnxscript.evaluator.default_as(tracer):
symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs)
if not isinstance(symbolic_outputs, Sequence):
symbolic_outputs = (symbolic_outputs,)
# We need to set the size of the output tensors for the ONNX model to be valid
for output, symbolic_output in zip(outputs, symbolic_outputs):
if isinstance(output, Sequence):
# Output is a sequence
elem_dtype = _TORCH_DTYPE_TO_ONNX[output[0].dtype]
symbolic_output.type = ir.SequenceType(ir.TensorType(elem_dtype))
continue
output = (
output
if isinstance(output, torch.Tensor)
else torch.tensor(output, device="cpu")
)
symbolic_output.shape = ir.Shape(output.shape)
symbolic_output.dtype = _TORCH_DTYPE_TO_ONNX[output.dtype]
graph.outputs.extend(symbolic_outputs)
graph.extend(tracer.nodes)
onnx_model = ir.Model(graph, ir_version=10, producer_name="torch_test")
for identifier, onnxscript_function in tracer.functions.items():
if identifier in onnx_model.functions:
continue
if isinstance(onnxscript_function, ir.Function):
ir_function = onnxscript_function
else:
# TODO: Get IR function directly when onnxscript is updated
proto = onnxscript_function.to_function_proto()
ir_function = ir.serde.deserialize_function(proto)
onnx_model.functions[identifier] = ir_function
_ir_passes.add_torchlib_common_imports(onnx_model)
_ir_passes.add_opset_imports(onnx_model)
# Make sure the model is valid
model_proto = ir.to_proto(onnx_model)
try:
onnx.checker.check_model(model_proto, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
model_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True)
try:
if (
os.environ.get("CATCH_ORT_SEGFAULT") == "1"
or os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
):
# Use an individual process to run ONNX Runtime to catch segfaults
return _safe_ort_session_run(
model_proto.SerializeToString(), ort_inputs
), model_proto
return _ort_session_run(
model_proto.SerializeToString(), ort_inputs
), model_proto
except (
# pylint: disable=c-extension-no-member
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented,
# pylint: enable=c-extension-no-member
) as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
error_reproduction.create_reproduction_report(
test_name,
model_proto,
ort_inputs,
e,
"test/onnx/torchlib/test_ops.py",
)
raise RuntimeError(
"ONNX Runtime failed to evaluate:\n"
+ _format_model_and_input_information(model_proto, ort_inputs)
) from e
except OrtAbortedError as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
# Save the model and inputs to a file for reproduction
error_reproduction.create_reproduction_report(
test_name,
model_proto,
ort_inputs,
e,
"test/onnx/torchlib/test_ops.py",
)
raise OrtAbortedError(
"ONNX Runtime aborted:\n"
+ _format_model_and_input_information(model_proto, ort_inputs)
) from e
except Exception as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
error_reproduction.create_reproduction_report(
test_name,
model_proto,
ort_inputs,
e,
"test/onnx/torchlib/test_ops.py",
)
raise
return _capture_graph_and_evaluate_torch_script_evaluator
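

# A minimal usage sketch (the ONNX function, inputs, and expected outputs are
# hypothetical placeholders for what the op tests supply):
#
#   run = graph_executor(test_name, expected_outputs)
#   results, model_proto = run(onnx_function, numpy_args, kwargs)
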
@contextlib.contextmanager
def normal_xfail_skip_test_behaviors(
test_behavior: Optional[str] = None, reason: Optional[str] = None
):
"""This context manager is used to handle the different behaviors of xfail and skip.
Args:
test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None.
reason (optional[str]): The reason for the failure or skip.
Raises:
e: Any exception raised by the test case if it's not an expected failure.
"""
    # Skip as early as possible: the test body itself may segfault, so we must not enter it.
if test_behavior == "skip":
pytest.skip(reason=reason)
try:
yield
    # We could catch a narrower set such as (AssertionError, RuntimeError, ...),
    # but enumerating the right exception types would require auditing every test case.
except Exception: # pylint: disable=broad-exception-caught
if test_behavior is None:
raise
if test_behavior == "xfail":
pytest.xfail(reason=reason)
else:
if test_behavior == "xfail":
pytest.fail("Test unexpectedly passed")
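

# A minimal usage sketch inside a parameterized test (names are illustrative):
#
#   with normal_xfail_skip_test_behaviors(test_behavior, reason):
#       results, _ = run(onnx_function, numpy_args, kwargs)
#       assert_results_match(results)  # hypothetical assertion helper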