# Owner(s): ["module: onnx"]
"""Test op correctness by comparing with PyTorch results.
Usage:
pytest test_ops.py
To run tests on a specific operator (e.g. torch.ceil):
pytest test_ops.py -k ceil
To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention):
pytest test_ops.py -k nn_functional_scaled_dot_product_attention
## Environment variables
1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
in onnxruntime by running the inference sessions in a separate process.
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
errors.
"""
from __future__ import annotations

import os
from typing import Callable, Optional, TYPE_CHECKING

import error_reproduction
import numpy as np
import onnx
import onnxruntime as ort
import onnxscript
import ops_test_common
import ops_test_data
import parameterized
import torch
from torch.testing._internal import common_device_type, common_utils
from torch.utils import _pytree as pytree

if TYPE_CHECKING:
    import unittest
    from collections.abc import Sequence

    from torch.testing._internal.opinfo import core as opinfo_core

# All dtypes will be tested on the generated symbolic functions.
# complex64 will be flattened to float32.
TESTED_DTYPES = (
torch.float16,
torch.float32,
    # Uncomment the items below when we really need to test them
# torch.bfloat16,
# torch.float64,
torch.bool,
# torch.int8,
# torch.int16,
torch.int32,
torch.int64,
# torch.uint8,
)
# NOTE: torch.complex32 is experimental in torch
COMPLEX_TYPES = (torch.complex64,)


def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
"""Returns all dtypes except the ones specified."""
return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes)
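

# A worked example of dtypes_except, given TESTED_DTYPES as defined above
# (illustrative only, not used by the tests):
#   dtypes_except(torch.bool) == (torch.float16, torch.float32, torch.int32, torch.int64)

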
def _should_skip_xfail_test_sample(
op_name: str, sample, dtype: torch.dtype, device_type: str
) -> tuple[Optional[str], Optional[str]]:
"""Returns a reason if a test sample should be skipped."""
if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
return None, None
for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
# Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
if decorator_meta.op_name == op_name:
assert decorator_meta.matcher is not None, "Matcher must be defined"
if not decorator_meta.enabled_if:
# Do not skip the test if the decorator meta is not enabled
continue
if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
# Not applicable for this dtype
continue
if (
decorator_meta.device_type is not None
and decorator_meta.device_type != device_type
):
# Not applicable for this device_type
continue
if decorator_meta.matcher(sample):
return decorator_meta.test_behavior, decorator_meta.reason
return None, None
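

# A sketch of how an entry in ops_test_data.SKIP_XFAIL_SUBTESTS drives the
# matching above (field values are hypothetical, not a real entry):
#
#   op_name="ceil", dtypes=(torch.float16,), device_type="cpu",
#   matcher=lambda sample: sample.input.numel() == 0,
#   test_behavior="skip", reason="fp16 ceil not tested on empty inputs"
#
# With that entry, _should_skip_xfail_test_sample("ceil", sample, torch.float16, "cpu")
# returns ("skip", "fp16 ceil not tested on empty inputs") for any sample whose
# input tensor is empty, and (None, None) otherwise.

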
class TestFunctionValidity(common_utils.TestCase):
@parameterized.parameterized.expand(
[
(info.op.name, info)
for info in ops_test_data.TESTED_TORCHLIB_OPS
if isinstance(info.op, onnxscript.OnnxFunction)
],
skip_on_empty=True,
)
def test_script_function_passes_checker(
self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
):
function_proto = torchlib_op_info.op.to_function_proto()
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]


def run_test_output_match(
test_suite: unittest.TestCase,
device: str,
dtype: torch.dtype,
op: opinfo_core.OpInfo,
function_executor: Callable,
tested_op_mapping: dict[
str,
ops_test_data.TorchLibOpInfo,
],
):
"""Base test method for testing each opset, used by instantiate_device_type_tests.
Args:
test_suite: The test class instance.
device: The PyTorch device. instantiate_device_type_tests provides this.
dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
op: The OpInfo instance. instantiate_device_type_tests provides this.
function_executor: The function executor. This is a function that takes
a function and its arguments and returns the output of the function.
tested_op_mapping: The mapping of op name to the tested op.
"""
samples = op.sample_inputs(
device,
dtype,
requires_grad=False,
)
torchlib_op_info = tested_op_mapping[op.name]
# Obtain the input_wrangler that manipulates the OpInfo inputs
# to match the aten operator signature
# An example is nn.functional.upsample_nearest2d, which has a different signature
# than the aten operator upsample_nearest2d
onnx_function = torchlib_op_info.op
input_wrangler = torchlib_op_info.input_wrangler
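    # A minimal sketch of a wrangler's shape (hypothetical, not a real entry
    # from ops_test_data): it takes the numpy inputs and the kwargs and returns
    # both, adjusted to the ONNX function's signature, e.g.:
    #
    #   def _example_input_wrangler(args, kwargs):
    #       kwargs.pop("rounding_mode", None)  # drop a kwarg the ONNX op lacks
    #       return args, kwargs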
if (
not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
and dtype not in COMPLEX_TYPES
):
test_suite.skipTest(
f"dtype '{dtype}' is not supported by the op '{op.name}'. "
f"Type constraints: {onnx_function.op_schema.type_constraints}"
)
# Obtain the tolerance for the op
rtol, atol = torchlib_op_info.get_tolerance(dtype)
for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with test_suite.subTest(
sample_num=i,
inputs=repr(
[
f"Tensor<{inp.shape}, dtype={inp.dtype}>"
if isinstance(inp, torch.Tensor)
else inp
for inp in inputs
]
),
kwargs=repr(cpu_sample.kwargs),
):
try:
device_type = cpu_sample.args[0].device.type
except (AttributeError, IndexError):
device_type = "cpu"
test_behavior, reason = _should_skip_xfail_test_sample(
op.name, cpu_sample, dtype, device_type
)
with ops_test_common.normal_xfail_skip_test_behaviors(
test_behavior, reason
):
input_onnx = [
ops_test_common.convert_tensor_to_numpy(x) for x in inputs
]
kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
if input_wrangler:
input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
torch_output = op(*inputs, **cpu_sample.kwargs)
if isinstance(torch_output, torch.Tensor) and torch.is_complex(
torch_output
):
torch_output = torch.view_as_real(torch_output.resolve_conj())
reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
if (
op.name.startswith("split")
or op.name.startswith("chunk")
or op.name.startswith("unbind")
or op.name
in {
"atleast_1d_Sequence",
"atleast_2d_Sequence",
"atleast_3d_Sequence",
}
):
                    # Hack for handling split, chunk and unbind, which rely on the
                    # SplitToSequence op. Split returns a Sequence that should be
                    # treated as a single value, so we wrap it into a tuple.
# TODO(justinchuby): Find a more general solution
reference_torch_outputs = [reference_torch_outputs]
test_name = test_suite.id()
function_output, model_proto = function_executor(
test_name,
reference_torch_outputs,
opset_version=torchlib_op_info.opset_introduced,
)(onnx_function, input_onnx, kwargs_onnx)
# Finally we re-flatten everything
# TODO: add pytree structure comparison.
flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
flattened_function_outputs, _ = pytree.tree_flatten(function_output)
assert flattened_torch_outputs
assert len(flattened_torch_outputs) == len(flattened_function_outputs)
for j, (torch_output, function_output) in enumerate(
zip(flattened_torch_outputs, flattened_function_outputs)
):
actual = torch.tensor(function_output)
expected = (
torch_output
if isinstance(torch_output, torch.Tensor)
else torch.tensor(torch_output)
)
if (
op.name in ops_test_data.NONDETERMINISTIC_OPS
or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
):
                        # Check only shape and dtype for ops known to be
                        # nondeterministic, or for outputs listed as
                        # shape-only comparisons
test_suite.assertEqual(actual.shape, expected.shape)
test_suite.assertEqual(actual.dtype, expected.dtype)
continue
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
try:
torch.testing.assert_close(
actual,
expected,
rtol=rtol,
atol=atol,
equal_nan=True,
check_device=False,
)
except AssertionError as e:
if (
os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
and test_behavior is None
):
error_reproduction.create_mismatch_report(
test_name,
i,
model_proto,
inputs,
cpu_sample.kwargs,
actual,
expected,
e,
__file__,
)
if len(flattened_torch_outputs) > 1:
raise AssertionError(f"Output {j} mismatch") from e
raise


class TestOutputConsistencyFullGraph(common_utils.TestCase):
    """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.

    This is a parameterized test suite.
    """

def setUp(self) -> None:
torch.manual_seed(42)
np.random.seed(42)
ort.set_seed(42)

    @ops_test_common.add_decorate_info(
ops_test_data.OPS_DB,
"TestOutputConsistencyFullGraph",
"test_output_match_opinfo_",
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
)
@common_device_type.ops( # type: ignore[misc]
[
info
for info in ops_test_data.OPS_DB
if info.name in ops_test_data.TESTED_OPS
],
allowed_dtypes=TESTED_DTYPES,
)
def test_output_match_opinfo_(
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
):
        """Base test method for testing each op by running the full ONNX graph."""
run_test_output_match(
self,
device,
dtype,
op,
ops_test_common.graph_executor,
ops_test_data.TORCHLIB_OPINFO_MAPPING,
)

    @ops_test_common.add_decorate_info(
ops_test_data.OPS_DB,
"TestOutputConsistencyFullGraph",
"test_complex_output_match_opinfo_",
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
)
@common_device_type.ops( # type: ignore[misc]
[
info
for info in ops_test_data.OPS_DB
if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
],
allowed_dtypes=COMPLEX_TYPES,
)
def test_complex_output_match_opinfo_(
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
):
"""Base test method for testing each op by running the full ONNX graph."""
run_test_output_match(
self,
device,
dtype,
op,
ops_test_common.graph_executor,
ops_test_data.COMPLEX_FUNCTION_MAPPING,
)


common_device_type.instantiate_device_type_tests(
TestOutputConsistencyFullGraph, globals(), only_for=["cpu"]
)
if __name__ == "__main__":
common_utils.run_tests()