# Owner(s): ["module: onnx"]
"""Test op correctness by comparing with PyTorch results.

Usage:

    pytest test_ops.py

To run tests on a specific operator (e.g. torch.ceil):

    pytest test_ops.py -k ceil

To run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):

    pytest test_ops.py -k nn_functional_scaled_dot_product_attention

## Environment variables

1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
   in onnxruntime by running the inference sessions in a separate process.
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
   errors.
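
For example (illustrative invocation), a reproduction report can be requested while
testing torch.ceil by combining the variable with a filtered run:

    CREATE_REPRODUCTION_REPORT=1 pytest test_ops.py -k ceil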
"""

from __future__ import annotations

import os
from typing import Optional, TYPE_CHECKING

import error_reproduction
import numpy as np

import onnx
import onnxruntime as ort
import onnxscript
import ops_test_common
import ops_test_data
import parameterized

import torch
from torch.testing._internal import common_device_type, common_utils
from torch.utils import _pytree as pytree


if TYPE_CHECKING:
    import unittest
    from collections.abc import Callable, Sequence

    from torch.testing._internal.opinfo import core as opinfo_core


# All dtypes will be tested on the generated symbolic functions.
# complex64 will be flattened to float32.
TESTED_DTYPES = (
    torch.float16,
    torch.float32,
    # Uncomment the items below when we need to test them
    # torch.bfloat16,
    # torch.float64,
    torch.bool,
    # torch.int8,
    # torch.int16,
    torch.int32,
    torch.int64,
    # torch.uint8,
)
# NOTE: torch.complex32 is experimental in torch
COMPLEX_TYPES = (torch.complex64,)


def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
    """Returns all dtypes except the ones specified."""
    return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes)
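

# Example (illustrative only, given the TESTED_DTYPES defined above):
#   >>> dtypes_except(torch.float16, torch.bool)
#   (torch.float32, torch.int32, torch.int64)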


def _should_skip_xfail_test_sample(
    op_name: str, sample, dtype: torch.dtype, device_type: str
) -> tuple[Optional[str], Optional[str]]:
    """Returns (test_behavior, reason) if a test sample should be skipped or xfailed, else (None, None)."""
    if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
        # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if not decorator_meta.enabled_if:
                # Do not skip the test if the decorator meta is not enabled
                continue
            if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
                # Not applicable for this dtype
                continue
            if (
                decorator_meta.device_type is not None
                and decorator_meta.device_type != device_type
            ):
                # Not applicable for this device_type
                continue
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None
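

# Note: each SKIP_XFAIL_SUBTESTS entry (defined in ops_test_data) is matched above by
# op_name, optionally dtype and device_type, and a sample `matcher` callable; its
# test_behavior (e.g. skip vs. xfail) and reason are then applied to the subtest.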


class TestFunctionValidity(common_utils.TestCase):
    @parameterized.parameterized.expand(
        [
            (info.op.name, info)
            for info in ops_test_data.TESTED_TORCHLIB_OPS
            if isinstance(info.op, onnxscript.OnnxFunction)
        ],
        skip_on_empty=True,
    )
    def test_script_function_passes_checker(
        self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
    ):
        function_proto = torchlib_op_info.op.to_function_proto()
        onnx.checker.check_function(function_proto)  # type: ignore[attr-defined]
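
    # The expand decorator above generates one checker test per scripted OnnxFunction;
    # with parameterized's default naming, the op name is appended to the test name, so
    # a single function can typically be selected with something like
    # `pytest test_ops.py -k "test_script_function_passes_checker and ceil"`.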


def run_test_output_match(
    test_suite: unittest.TestCase,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
    function_executor: Callable,
    tested_op_mapping: dict[
        str,
        ops_test_data.TorchLibOpInfo,
    ],
):
    """Base test method for testing each opset, used by instantiate_device_type_tests.

    Args:
        test_suite: The test class instance.
        device: The PyTorch device. instantiate_device_type_tests provides this.
        dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
        op: The OpInfo instance. instantiate_device_type_tests provides this.
        function_executor: The function executor. This is a function that takes
            a function and its arguments and returns the output of the function.
        tested_op_mapping: The mapping of op name to the tested op.
    """
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )

    torchlib_op_info = tested_op_mapping[op.name]
    # Obtain the input_wrangler that manipulates the OpInfo inputs
    # to match the aten operator signature
    # An example is nn.functional.upsample_nearest2d, which has a different signature
    # than the aten operator upsample_nearest2d
    onnx_function = torchlib_op_info.op
    input_wrangler = torchlib_op_info.input_wrangler
    if (
        not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
        and dtype not in COMPLEX_TYPES
    ):
        test_suite.skipTest(
            f"dtype '{dtype}' is not supported by the op '{op.name}'. "
            f"Type constraints: {onnx_function.op_schema.type_constraints}"
        )

    # Obtain the tolerance for the op
    rtol, atol = torchlib_op_info.get_tolerance(dtype)
    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs
        with test_suite.subTest(
            sample_num=i,
            inputs=repr(
                [
                    f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                    if isinstance(inp, torch.Tensor)
                    else inp
                    for inp in inputs
                ]
            ),
            kwargs=repr(cpu_sample.kwargs),
        ):
            try:
                device_type = cpu_sample.args[0].device.type
            except (AttributeError, IndexError):
                device_type = "cpu"
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, cpu_sample, dtype, device_type
            )

            with ops_test_common.normal_xfail_skip_test_behaviors(
                test_behavior, reason
            ):
                input_onnx = [
                    ops_test_common.convert_tensor_to_numpy(x) for x in inputs
                ]
                kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
                if input_wrangler:
                    input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
                torch_output = op(*inputs, **cpu_sample.kwargs)

                if isinstance(torch_output, torch.Tensor) and torch.is_complex(
                    torch_output
                ):
                    torch_output = torch.view_as_real(torch_output.resolve_conj())

                reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
                if (
                    op.name.startswith("split")
                    or op.name.startswith("chunk")
                    or op.name.startswith("unbind")
                    or op.name
                    in {
                        "atleast_1d_Sequence",
                        "atleast_2d_Sequence",
                        "atleast_3d_Sequence",
                    }
                ):
                    # Hack for handling split, chunk and unbind, which rely on the SplitToSequence op.
                    # Split returns a Sequence that should be treated as a single
                    # value. So we wrap it into a list.
                    # TODO(justinchuby): Find a more general solution
                    reference_torch_outputs = [reference_torch_outputs]

                test_name = test_suite.id()
                function_output, model_proto = function_executor(
                    test_name,
                    reference_torch_outputs,
                    opset_version=torchlib_op_info.opset_introduced,
                )(onnx_function, input_onnx, kwargs_onnx)
                # Finally we re-flatten everything
                # TODO: add pytree structure comparison.
                flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
                flattened_function_outputs, _ = pytree.tree_flatten(function_output)

                assert flattened_torch_outputs
                assert len(flattened_torch_outputs) == len(flattened_function_outputs)

                for j, (torch_output, function_output) in enumerate(
                    zip(flattened_torch_outputs, flattened_function_outputs)
                ):
                    actual = torch.tensor(function_output)
                    expected = (
                        torch_output
                        if isinstance(torch_output, torch.Tensor)
                        else torch.tensor(torch_output)
                    )

                    if (
                        op.name in ops_test_data.NONDETERMINISTIC_OPS
                        or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
                    ):
                        # Check shape and dtype only for ops that are known to be
                        # nondeterministic
                        test_suite.assertEqual(actual.shape, expected.shape)
                        test_suite.assertEqual(actual.dtype, expected.dtype)
                        continue

                    # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
                    try:
                        torch.testing.assert_close(
                            actual,
                            expected,
                            rtol=rtol,
                            atol=atol,
                            equal_nan=True,
                            check_device=False,
                        )
                    except AssertionError as e:
                        if (
                            os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
                            and test_behavior is None
                        ):
                            error_reproduction.create_mismatch_report(
                                test_name,
                                i,
                                model_proto,
                                inputs,
                                cpu_sample.kwargs,
                                actual,
                                expected,
                                e,
                                __file__,
                            )
                        if len(flattened_torch_outputs) > 1:
                            raise AssertionError(f"Output {j} mismatch") from e
                        raise


class TestOutputConsistencyFullGraph(common_utils.TestCase):
    """Test output consistency between exported ONNX ops run as a graph and PyTorch eager mode.

    This is a parameterized test suite.
    """

    def setUp(self) -> None:
        torch.manual_seed(42)
        np.random.seed(42)
        ort.set_seed(42)

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            info
            for info in ops_test_data.OPS_DB
            if info.name in ops_test_data.TESTED_OPS
        ],
        allowed_dtypes=TESTED_DTYPES,
    )
    def test_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        # Base test method for testing each op by running the full ONNX graph.
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            ops_test_common.graph_executor,
            ops_test_data.TORCHLIB_OPINFO_MAPPING,
        )

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_complex_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            info
            for info in ops_test_data.OPS_DB
            if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
        ],
        allowed_dtypes=COMPLEX_TYPES,
    )
    def test_complex_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        """Base test method for testing each op by running the full ONNX graph."""
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            ops_test_common.graph_executor,
            ops_test_data.COMPLEX_FUNCTION_MAPPING,
        )


common_device_type.instantiate_device_type_tests(
    TestOutputConsistencyFullGraph, globals(), only_for=["cpu"]
)
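
# instantiate_device_type_tests registers device-specific variants of the class above
# (e.g. a CPU-suffixed class such as TestOutputConsistencyFullGraphCPU; the exact
# naming follows torch.testing conventions), which is what pytest collects and what
# the `-k` filters in the module docstring match against.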


if __name__ == "__main__":
    common_utils.run_tests()