mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Add scaffolding for onnx decomp and logic for op tests (#147392)
Create scaffold for onnx op test data and common logic. This PR creates the scaffolding for new onnx decomp functions described in https://github.com/pytorch/pytorch/issues/139301. It adds two ops: abs and add, and enables the related tests. https://github.com/pytorch/pytorch/issues/139301 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147392 Approved by: https://github.com/titaiwangms ghstack dependencies: #147396
This commit is contained in:
parent
24738768a8
commit
41ae15faa3
80
test/onnx/torchlib/README.md
Normal file
80
test/onnx/torchlib/README.md
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
# Test op correctness by comparing with PyTorch results using OpInfo
|
||||
|
||||
`OpInfo` is PyTorch's standard mechanism for composing test data for operators.
|
||||
Read more about them on https://github.com/pytorch/pytorch/blob/ce4a097bf769d753712a1fd969b446c59e29d8b9/torch/testing/_internal/opinfo/core.py#L362.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# All
|
||||
python -m pytest test_ops.py
|
||||
|
||||
# To run tests on a specific operator (e.g. torch.ceil):
|
||||
python -m pytest test_ops.py -k ceil
|
||||
|
||||
# To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention):
|
||||
python -m pytest test_ops.py -k nn_functional_scaled_dot_product_attention
|
||||
```
|
||||
|
||||
### Environment variables
|
||||
|
||||
1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
|
||||
in onnxruntime by running the inference sessions in a separate process.
|
||||
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g.
|
||||
|
||||
```bash
|
||||
CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/torchlib/test_ops.py -k div_mode_int
|
||||
```
|
||||
|
||||
## How to add a new operator test
|
||||
|
||||
See _usage_ in [`ops_test_data.py`](./ops_test_data.py)
|
||||
|
||||
## How to add custom OpInfo tests
|
||||
|
||||
Sometimes, there is no existing OpInfo that fits our need to test an operator. You want to create a custom OpInfo for it.
|
||||
|
||||
Follow the steps below to create new OpInfo tests:
|
||||
|
||||
1. Use the implementation for `ops.aten.slice_scatter` as a reference (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2412-L2418) to declare an OpInfo in [`extra_opinfo.py`](./extra_opinfo.py)
|
||||
|
||||
```py
|
||||
opinfo_core.OpInfo(
|
||||
"ops.aten.slice_scatter",
|
||||
aten_name="slice_scatter",
|
||||
dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
|
||||
sample_inputs_func=sample_inputs_slice_scatter,
|
||||
supports_out=False,
|
||||
),
|
||||
```
|
||||
|
||||
- The first argument should be the operator name under the `torch.ops` namespace. For example, if you want to test the `prims.var` op, then put `"ops.prims.var"`. It should almost always start with `ops.`.
|
||||
- Follow existing examples to specify the `dtypes` you want to test the op on.
|
||||
- Specify `op=` if the target operator is not the same as the OpInfo name (first arg). For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2065-L2068.
|
||||
|
||||
```py
|
||||
opinfo_core.OpInfo(
|
||||
"ops.aten.bernoulli.p_deterministic",
|
||||
op=torch.ops.aten.bernoulli.p,
|
||||
```
|
||||
|
||||
The op is `torch.ops.aten.bernoulli.p`, which is different from the name `ops.aten.bernoulli.p_deterministic`. OpInfo names need to be globally unique in a test suite. When `op` is not specified, it will look for the op in `torch.` using its name.
|
||||
|
||||
2. Implement the `sample_inputs_func`. (Ref: https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L1242-L1268)
|
||||
1. Copy the function and decide what the input shapes should be. Use `make_arg` to generate a torch.Tensor. Alternatively you could also use `torch.tensor` to generate the tensor yourself. Be sure to double check the dtype and device. Finally yield each test cases with
|
||||
|
||||
```py
|
||||
yield opinfo_core.SampleInput(input, args=(...), kwargs={...})
|
||||
```
|
||||
|
||||
`input` is the first arg. The rest of the args are in `args`.
|
||||
3. Enable the test case in [`ops_test_data.py`](./ops_test_data.py)
|
||||
1. Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list. (For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L2116)
|
||||
|
||||
```py
|
||||
TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter)
|
||||
```
|
||||
|
||||
You can additionally specify dtype tolerance (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L539) or conditional skips (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L586-L590).
|
||||
|
||||
Now that the test is added, you may run the test like mentioned above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view failing input combinations should any test case fails.
|
||||
701
test/onnx/torchlib/ops_test_common.py
Normal file
701
test/onnx/torchlib/ops_test_common.py
Normal file
|
|
@ -0,0 +1,701 @@
|
|||
# Owner(s): ["module: onnx"]
|
||||
"""Common utils for testing operators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import multiprocessing
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import error_reproduction
|
||||
import numpy as np
|
||||
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import onnxruntime.capi.onnxruntime_pybind11_state
|
||||
import onnxscript
|
||||
import onnxscript.evaluator
|
||||
import pytest
|
||||
from onnxscript import ir
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter import _building, _ir_passes, _tensors
|
||||
from torch.testing._internal.opinfo import core as opinfo_core
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Convenience tuples for creating dtype lists when skipping or xfailing tests
|
||||
|
||||
BOOL_TYPES = (torch.bool,)
|
||||
|
||||
INT_TYPES = (
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.uint8,
|
||||
)
|
||||
|
||||
FLOAT_TYPES = (
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
)
|
||||
|
||||
TEST_OPSET_VERSION = 18
|
||||
IS_MACOS = sys.platform.startswith("darwin")
|
||||
IS_WINDOWS = os.name == "nt"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DecorateMeta:
|
||||
"""A dataclass for storing information about a test case to skip or xfail.
|
||||
|
||||
Adapted from functorch: functorch/test/common_utils.py
|
||||
"""
|
||||
|
||||
op_name: str
|
||||
variant_name: str
|
||||
decorator: Callable[..., Any]
|
||||
dtypes: Optional[Collection[torch.dtype]]
|
||||
device_type: Optional[str]
|
||||
reason: str
|
||||
test_behavior: str
|
||||
matcher: Optional[Callable[[Any], bool]] = None
|
||||
enabled_if: bool = True
|
||||
# The test_class_name to apply the decorator to. If None, the decorator is
|
||||
# applied to all test classes.
|
||||
test_class_name: Optional[str] = None
|
||||
|
||||
|
||||
def xfail(
|
||||
op_name: str,
|
||||
variant_name: str = "",
|
||||
*,
|
||||
reason: str,
|
||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||
device_type: Optional[str] = None,
|
||||
matcher: Optional[Callable[[Any], Any]] = None,
|
||||
enabled_if: bool = True,
|
||||
test_class_name: Optional[str] = None,
|
||||
) -> DecorateMeta:
|
||||
"""Expects an OpInfo test to fail.
|
||||
|
||||
Args:
|
||||
op_name: The name of the operator.
|
||||
variant_name: Optional OpInfo variant_test_name.
|
||||
reason: The reason for the failure.
|
||||
dtypes: The dtypes to expect the failure.
|
||||
device_type: Device type. E.g. "cpu", "cuda".
|
||||
matcher: A function that matches the test sample input. It is used only when
|
||||
the xfail is in the SKIP_XFAIL_SUBTESTS list.
|
||||
enabled_if: Whether the xfail is enabled.
|
||||
test_class_name: The test class name to apply the xfail to. If None, the
|
||||
xfail is applied to all test classes.
|
||||
"""
|
||||
return DecorateMeta(
|
||||
op_name=op_name,
|
||||
variant_name=variant_name,
|
||||
decorator=unittest.expectedFailure,
|
||||
dtypes=dtypes,
|
||||
device_type=device_type,
|
||||
matcher=matcher,
|
||||
reason=reason,
|
||||
enabled_if=enabled_if,
|
||||
test_class_name=test_class_name,
|
||||
test_behavior="xfail",
|
||||
)
|
||||
|
||||
|
||||
def skip(
|
||||
op_name: str,
|
||||
variant_name: str = "",
|
||||
*,
|
||||
reason: str,
|
||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||
device_type: Optional[str] = None,
|
||||
matcher: Optional[Callable[[Any], Any]] = None,
|
||||
enabled_if: bool = True,
|
||||
test_class_name: Optional[str] = None,
|
||||
) -> DecorateMeta:
|
||||
"""Skips an OpInfo test.
|
||||
|
||||
Args:
|
||||
op_name: The name of the operator.
|
||||
variant_name: Optional OpInfo variant_test_name.
|
||||
reason: The reason for skipping.
|
||||
dtypes: The dtypes to skip.
|
||||
device_type: Device type. E.g. "cpu", "cuda".
|
||||
matcher: A function that matches the test sample input. It is used only when
|
||||
the skip is in the SKIP_XFAIL_SUBTESTS list.
|
||||
enabled_if: Whether the skip is enabled.
|
||||
test_class_name: The test class name to apply the skip to. If None, the skip
|
||||
is applied to all test classes.
|
||||
"""
|
||||
return DecorateMeta(
|
||||
op_name=op_name,
|
||||
variant_name=variant_name,
|
||||
decorator=unittest.skip(f"Skip: {reason}"),
|
||||
dtypes=dtypes,
|
||||
device_type=device_type,
|
||||
reason=reason,
|
||||
matcher=matcher,
|
||||
enabled_if=enabled_if,
|
||||
test_class_name=test_class_name,
|
||||
test_behavior="skip",
|
||||
)
|
||||
|
||||
|
||||
def add_decorate_info(
|
||||
all_opinfos: Sequence[opinfo_core.OpInfo],
|
||||
test_class_name: str,
|
||||
base_test_name: str,
|
||||
skip_or_xfails: Iterable[DecorateMeta],
|
||||
) -> Callable[[T], T]:
|
||||
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
|
||||
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
|
||||
for decorate_meta in skip_or_xfails:
|
||||
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
|
||||
if opinfo is None and not decorate_meta.enabled_if:
|
||||
# If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo
|
||||
# because it could be an OpInfo that is in torch-nightly but not older versions.
|
||||
continue
|
||||
assert (
|
||||
opinfo is not None
|
||||
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
|
||||
decorators = list(opinfo.decorators)
|
||||
new_decorator = opinfo_core.DecorateInfo(
|
||||
decorate_meta.decorator,
|
||||
decorate_meta.test_class_name or test_class_name,
|
||||
base_test_name,
|
||||
dtypes=decorate_meta.dtypes,
|
||||
device_type=decorate_meta.device_type,
|
||||
active_if=decorate_meta.enabled_if,
|
||||
)
|
||||
decorators.append(new_decorator)
|
||||
opinfo.decorators = tuple(decorators)
|
||||
|
||||
# This decorator doesn't modify fn in any way
|
||||
def wrapped(fn):
|
||||
return fn
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def duplicate_opinfo(
|
||||
opinfos: list[opinfo_core.OpInfo], name: str, new_names: tuple[str, ...]
|
||||
):
|
||||
"""Duplicate an opinfo in the opinfo database and give it a new name."""
|
||||
duplicated = []
|
||||
all_info_names = {opinfo.name for opinfo in opinfos}
|
||||
for opinfo in opinfos:
|
||||
if opinfo.name == name:
|
||||
for new_name in new_names:
|
||||
if new_name in all_info_names:
|
||||
# NOTE: Avoid duplicating an opinfo that already exists in the database.
|
||||
# New opinfos are expected to be added in torch-nightly.
|
||||
warnings.warn(
|
||||
f"OpInfo {new_name} already exists in the database.",
|
||||
stacklevel=1,
|
||||
)
|
||||
continue
|
||||
new_opinfo = copy.deepcopy(opinfo)
|
||||
new_opinfo.name = new_name
|
||||
duplicated.append(new_opinfo)
|
||||
opinfos.extend(duplicated)
|
||||
|
||||
|
||||
def duplicate_opinfo_for_prims(
|
||||
opinfos: list[opinfo_core.OpInfo], name: str, prims_name: str | None = None
|
||||
):
|
||||
"""Duplicate an opinfo in the opinfo database for a prims op.
|
||||
|
||||
The function sets the new OpInfo to use the variation torch.ops.prims.
|
||||
The new OpInfo will have the name "prims_{prims_name}" where `prims_name` is the
|
||||
name of the prims op. If `prims_name` is None, it will be set to "prims_{name}".
|
||||
|
||||
Args:
|
||||
opinfos: The list of opinfo_core.OpInfo to add the new opinfo to.
|
||||
name: The name of the opinfo to duplicate.
|
||||
prims_name: The name of the prims op. If None, it will be set to `name`.
|
||||
"""
|
||||
if prims_name is None:
|
||||
prims_name = name
|
||||
# The name of the new OpInfo
|
||||
new_name = f"prims_{prims_name}"
|
||||
all_info_names = {opinfo.name for opinfo in opinfos}
|
||||
for opinfo in opinfos:
|
||||
if opinfo.name == name:
|
||||
if new_name in all_info_names:
|
||||
# NOTE: Avoid duplicating an opinfo that already exists in the database.
|
||||
warnings.warn(
|
||||
f"OpInfo {new_name} already exists in the database.", stacklevel=1
|
||||
)
|
||||
continue
|
||||
new_opinfo = copy.deepcopy(opinfo)
|
||||
new_opinfo.name = new_name
|
||||
new_opinfo.op = getattr(torch.ops.prims, prims_name)
|
||||
opinfos.append(new_opinfo)
|
||||
return
|
||||
raise RuntimeError(f"OpInfo '{name}' not found in the database.")
|
||||
|
||||
|
||||
TORCH_TYPE_TO_ONNX = {
|
||||
torch.bool: onnx.TensorProto.BOOL,
|
||||
torch.uint8: onnx.TensorProto.UINT8,
|
||||
torch.int8: onnx.TensorProto.INT8,
|
||||
torch.int16: onnx.TensorProto.INT16,
|
||||
torch.int32: onnx.TensorProto.INT32,
|
||||
torch.int64: onnx.TensorProto.INT64,
|
||||
torch.float16: onnx.TensorProto.FLOAT16,
|
||||
torch.float32: onnx.TensorProto.FLOAT,
|
||||
torch.float64: onnx.TensorProto.DOUBLE,
|
||||
torch.complex64: onnx.TensorProto.COMPLEX64,
|
||||
torch.complex128: onnx.TensorProto.COMPLEX128,
|
||||
torch.bfloat16: onnx.TensorProto.BFLOAT16,
|
||||
}
|
||||
|
||||
|
||||
def convert_tensor_to_numpy(input: Any) -> Any:
|
||||
if isinstance(input, torch.Tensor):
|
||||
if torch.is_complex(input):
|
||||
# from complex to real representation
|
||||
input = torch.view_as_real(input)
|
||||
return input.detach().cpu().numpy()
|
||||
if isinstance(input, complex):
|
||||
return torch.view_as_real(torch.tensor(input)).detach().cpu().numpy()
|
||||
if isinstance(input, list):
|
||||
if len(input) == 0:
|
||||
return np.array((), dtype=np.int64)
|
||||
if any(isinstance(x, torch.Tensor) for x in input):
|
||||
# The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
|
||||
return [convert_tensor_to_numpy(x) for x in input]
|
||||
if isinstance(input[0], bool):
|
||||
return np.array(input, dtype=np.bool_)
|
||||
|
||||
# Just a sequence of numbers
|
||||
if isinstance(input[0], int):
|
||||
return np.array(input, dtype=np.int64)
|
||||
if isinstance(input[0], float):
|
||||
return np.array(input)
|
||||
|
||||
return input
|
||||
|
||||
|
||||
def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Converts kwargs to be compatible with ONNX Runtime."""
|
||||
new_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if key == "device":
|
||||
continue
|
||||
if key == "dtype":
|
||||
value = TORCH_TYPE_TO_ONNX[value]
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = np.array(value.cpu())
|
||||
new_kwargs[key] = value
|
||||
return new_kwargs
|
||||
|
||||
|
||||
class OrtAbortedError(RuntimeError):
|
||||
"""ONNX Runtime Aborted."""
|
||||
|
||||
|
||||
def _ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]):
|
||||
"""Run a model with ONNX Runtime."""
|
||||
|
||||
# Disable all ORT optimizations
|
||||
session_options = onnxruntime.SessionOptions()
|
||||
session_options.graph_optimization_level = (
|
||||
onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
)
|
||||
session = ort.InferenceSession(
|
||||
serialized_model, session_options, providers=("CPUExecutionProvider",)
|
||||
)
|
||||
return session.run(None, ort_inputs)
|
||||
|
||||
|
||||
def _ort_session_run_return_dict(
|
||||
serialized_model: bytes, ort_inputs: Mapping[str, Any], return_dict
|
||||
) -> None:
|
||||
"""Run a model with ONNX Runtime and store the results in return_dict."""
|
||||
|
||||
try:
|
||||
return_dict["results"] = _ort_session_run(serialized_model, ort_inputs)
|
||||
return_dict["error"] = None
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return_dict["results"] = None
|
||||
return_dict["error"] = e
|
||||
|
||||
|
||||
def _safe_ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]):
|
||||
"""Run a model with ONNX Runtime in a separate process.
|
||||
|
||||
Args:
|
||||
serialized_model: Serialized ONNX model proto.
|
||||
ort_inputs: Inputs to the model.
|
||||
|
||||
Returns:
|
||||
The inference result.
|
||||
|
||||
Raises:
|
||||
OrtAbortedError if the process did not execute successfully.
|
||||
"""
|
||||
manager = multiprocessing.Manager()
|
||||
return_dict = manager.dict()
|
||||
process = multiprocessing.Process(
|
||||
target=_ort_session_run_return_dict,
|
||||
args=(serialized_model, ort_inputs, return_dict),
|
||||
)
|
||||
process.start()
|
||||
process.join()
|
||||
process.close()
|
||||
if not return_dict:
|
||||
raise OrtAbortedError
|
||||
if return_dict["error"] is not None:
|
||||
raise return_dict["error"]
|
||||
return return_dict["results"]
|
||||
|
||||
|
||||
def _format_model_and_input_information(onnx_model, inputs):
|
||||
return (
|
||||
f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}"
|
||||
)
|
||||
|
||||
|
||||
_TORCH_DTYPE_TO_ONNX_STRING = {
|
||||
torch.bool: "tensor(bool)",
|
||||
torch.uint8: "tensor(uint8)",
|
||||
torch.int8: "tensor(int8)",
|
||||
torch.int16: "tensor(int16)",
|
||||
torch.int32: "tensor(int32)",
|
||||
torch.int64: "tensor(int64)",
|
||||
torch.float16: "tensor(float16)",
|
||||
torch.float32: "tensor(float)",
|
||||
torch.float64: "tensor(double)",
|
||||
torch.complex64: "tensor(complex64)",
|
||||
torch.complex128: "tensor(complex128)",
|
||||
torch.bfloat16: "tensor(bfloat16)",
|
||||
}
|
||||
|
||||
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
|
||||
torch.bfloat16: ir.DataType.BFLOAT16,
|
||||
torch.bool: ir.DataType.BOOL,
|
||||
torch.complex128: ir.DataType.COMPLEX128,
|
||||
torch.complex64: ir.DataType.COMPLEX64,
|
||||
torch.float16: ir.DataType.FLOAT16,
|
||||
torch.float32: ir.DataType.FLOAT,
|
||||
torch.float64: ir.DataType.DOUBLE,
|
||||
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
|
||||
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
|
||||
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
|
||||
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
|
||||
torch.int16: ir.DataType.INT16,
|
||||
torch.int32: ir.DataType.INT32,
|
||||
torch.int64: ir.DataType.INT64,
|
||||
torch.int8: ir.DataType.INT8,
|
||||
torch.uint8: ir.DataType.UINT8,
|
||||
torch.uint16: ir.DataType.UINT16,
|
||||
torch.uint32: ir.DataType.UINT32,
|
||||
torch.uint64: ir.DataType.UINT64,
|
||||
}
|
||||
|
||||
|
||||
def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
|
||||
"""Checks if the dtype is compatible with the schema.
|
||||
|
||||
When a dtype is "compatible" with the schema, it means we can use the dtype
|
||||
to create sample inputs by OpInfo to test the ONNX function and expect outputs to match.
|
||||
|
||||
Args:
|
||||
dtype: The torch dtype used to create sample inputs by OpInfo.
|
||||
schema: The ONNX schema of the function.
|
||||
|
||||
Returns:
|
||||
True if the dtype is compatible with the schema.
|
||||
"""
|
||||
if not schema.inputs:
|
||||
# If there are no inputs, we can't check compatibility. Assume it is compatible.
|
||||
# e.g. aten_randn has only attributes.
|
||||
return True
|
||||
if schema.inputs[0].name not in {"self", "input"}:
|
||||
# If the name of the first input is not "self" or "input",
|
||||
# it is usually an input that is not of the same type as the output.
|
||||
# We assume support in this case.
|
||||
#
|
||||
# For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
|
||||
# has the first input as `size`, which is an integer, but it can support
|
||||
# any dtype.
|
||||
return True
|
||||
|
||||
# Otherwise we check the type constraints of the first input.
|
||||
# For example, when dtype=torch.float32, and the op being tested has the schema
|
||||
# ```
|
||||
# OpSchema(
|
||||
# name='aten_abs',
|
||||
# domain='pkg.onnxscript.torch_lib',
|
||||
# since_version=1,
|
||||
# doc='abs(Tensor self) -> Tensor',
|
||||
# type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal',
|
||||
# allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)',
|
||||
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
|
||||
# 'tensor(bfloat16)'], description='')],
|
||||
# inputs=[OpSchema.FormalParameter(name='self', type_str='TReal',
|
||||
# description='', param_option=<FormalParameterOption.Single: 0>,
|
||||
# is_homogeneous=True, min_arity=1,
|
||||
# differentiation_category=<DifferentiationCategory.Unknown: 0>)],
|
||||
# outputs=[OpSchema.FormalParameter(name='return_val',
|
||||
# type_str='TReal', description='',
|
||||
# param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
|
||||
# min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
|
||||
# attributes={}
|
||||
# )
|
||||
# ```
|
||||
# we see the first input type is "TReal", corresponding to the type constraint
|
||||
# with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)',
|
||||
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
|
||||
# 'tensor(bfloat16)'].
|
||||
# Since torch.float32 (tensor(float)) is in the allowed types, we return True.
|
||||
|
||||
first_input_type_name = schema.inputs[0].type_str
|
||||
# Find the type constraint for the first input by matching the parameter name
|
||||
first_input_type_constraint = next(
|
||||
(
|
||||
x
|
||||
for x in schema.type_constraints
|
||||
if first_input_type_name in x.type_param_str
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert first_input_type_constraint is not None
|
||||
allowed_type_strs = first_input_type_constraint.allowed_type_strs
|
||||
# Here we consider seq(tensor(float)) compatible with tensor(float) as well
|
||||
return any(
|
||||
_TORCH_DTYPE_TO_ONNX_STRING[dtype] in type_str for type_str in allowed_type_strs
|
||||
)
|
||||
|
||||
|
||||
def graph_executor(
|
||||
test_name: str,
|
||||
outputs: Sequence[Any],
|
||||
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
|
||||
"""Eagerly executes a function."""
|
||||
|
||||
def _capture_graph_and_evaluate_torch_script_evaluator(
|
||||
function: Callable, args, kwargs
|
||||
) -> tuple[Any, onnx.ModelProto]:
|
||||
"""Captures the graph of a function and evaluates it using TorchScriptEvaluator."""
|
||||
|
||||
# Initialize the ONNX graph
|
||||
graph = ir.Graph(
|
||||
(),
|
||||
(),
|
||||
nodes=(),
|
||||
opset_imports={"": 18, "pkg.torch.onnx": 1},
|
||||
name="main_graph",
|
||||
)
|
||||
opset = onnxscript.opset18
|
||||
tracer = _building.OpRecorder(opset, {})
|
||||
ort_inputs = {}
|
||||
onnxscript_args: list[Any] = []
|
||||
onnxscript_kwargs = {}
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, np.ndarray):
|
||||
input_name = f"input_{i}"
|
||||
input = _tensors.SymbolicTensor(
|
||||
opset=opset,
|
||||
name=input_name,
|
||||
shape=ir.Shape(arg.shape),
|
||||
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(arg).dtype]),
|
||||
)
|
||||
graph.inputs.append(input)
|
||||
onnxscript_args.append(input)
|
||||
ort_inputs[input_name] = arg
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
# str is also a sequence but we do not want to treat it as a tensor
|
||||
sequence_input = []
|
||||
for j, subarg in enumerate(arg):
|
||||
if isinstance(subarg, np.ndarray):
|
||||
input_name = f"input_{i}_{j}"
|
||||
tensor = torch.tensor(subarg)
|
||||
input = _tensors.SymbolicTensor(
|
||||
opset=opset,
|
||||
name=input_name,
|
||||
shape=ir.Shape(tensor.shape),
|
||||
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[tensor.dtype]),
|
||||
)
|
||||
graph.inputs.append(input)
|
||||
sequence_input.append(input)
|
||||
ort_inputs[input_name] = subarg
|
||||
else:
|
||||
# Include non-numpy inputs as-is
|
||||
# For example, it could be a None value that we want to keep
|
||||
sequence_input.append(subarg)
|
||||
onnxscript_args.append(sequence_input)
|
||||
else:
|
||||
onnxscript_args.append(arg)
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
input = _tensors.SymbolicTensor(
|
||||
opset=opset,
|
||||
name=key,
|
||||
shape=ir.Shape(torch.tensor(value).shape),
|
||||
type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(value).dtype]),
|
||||
)
|
||||
graph.inputs.append(input)
|
||||
ort_inputs[key] = value
|
||||
onnxscript_kwargs[key] = input
|
||||
else:
|
||||
onnxscript_kwargs[key] = value
|
||||
|
||||
with onnxscript.evaluator.default_as(tracer):
|
||||
symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs)
|
||||
if not isinstance(symbolic_outputs, Sequence):
|
||||
symbolic_outputs = (symbolic_outputs,)
|
||||
|
||||
# We need to set the size of the output tensors for the ONNX model to be valid
|
||||
for output, symbolic_output in zip(outputs, symbolic_outputs):
|
||||
if isinstance(output, Sequence):
|
||||
# Output is a sequence
|
||||
elem_dtype = _TORCH_DTYPE_TO_ONNX[output[0].dtype]
|
||||
symbolic_output.type = ir.SequenceType(ir.TensorType(elem_dtype))
|
||||
continue
|
||||
output = (
|
||||
output
|
||||
if isinstance(output, torch.Tensor)
|
||||
else torch.tensor(output, device="cpu")
|
||||
)
|
||||
symbolic_output.shape = ir.Shape(output.shape)
|
||||
symbolic_output.dtype = _TORCH_DTYPE_TO_ONNX[output.dtype]
|
||||
|
||||
graph.outputs.extend(symbolic_outputs)
|
||||
graph.extend(tracer.nodes)
|
||||
onnx_model = ir.Model(graph, ir_version=10, producer_name="torch_test")
|
||||
for identifier, onnxscript_function in tracer.functions.items():
|
||||
if identifier in onnx_model.functions:
|
||||
continue
|
||||
if isinstance(onnxscript_function, ir.Function):
|
||||
ir_function = onnxscript_function
|
||||
else:
|
||||
# TODO: Get IR function directly when onnxscript is updated
|
||||
proto = onnxscript_function.to_function_proto()
|
||||
ir_function = ir.serde.deserialize_function(proto)
|
||||
onnx_model.functions[identifier] = ir_function
|
||||
_ir_passes.add_torchlib_common_imports(onnx_model)
|
||||
_ir_passes.add_opset_imports(onnx_model)
|
||||
# Make sure the model is valid
|
||||
model_proto = ir.to_proto(onnx_model)
|
||||
try:
|
||||
onnx.checker.check_model(model_proto, full_check=True)
|
||||
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
|
||||
raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
|
||||
model_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True)
|
||||
try:
|
||||
if (
|
||||
os.environ.get("CATCH_ORT_SEGFAULT") == "1"
|
||||
or os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
|
||||
):
|
||||
# Use an individual process to run ONNX Runtime to catch segfaults
|
||||
return _safe_ort_session_run(
|
||||
model_proto.SerializeToString(), ort_inputs
|
||||
), model_proto
|
||||
|
||||
return _ort_session_run(
|
||||
model_proto.SerializeToString(), ort_inputs
|
||||
), model_proto
|
||||
except (
|
||||
# pylint: disable=c-extension-no-member
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException,
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
|
||||
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented,
|
||||
# pylint: enable=c-extension-no-member
|
||||
) as e:
|
||||
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
|
||||
error_reproduction.create_reproduction_report(
|
||||
test_name,
|
||||
model_proto,
|
||||
ort_inputs,
|
||||
e,
|
||||
"test/onnx/torchlib/test_ops.py",
|
||||
)
|
||||
raise RuntimeError(
|
||||
"ONNX Runtime failed to evaluate:\n"
|
||||
+ _format_model_and_input_information(model_proto, ort_inputs)
|
||||
) from e
|
||||
except OrtAbortedError as e:
|
||||
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
|
||||
# Save the model and inputs to a file for reproduction
|
||||
error_reproduction.create_reproduction_report(
|
||||
test_name,
|
||||
model_proto,
|
||||
ort_inputs,
|
||||
e,
|
||||
"test/onnx/torchlib/test_ops.py",
|
||||
)
|
||||
raise OrtAbortedError(
|
||||
"ONNX Runtime aborted:\n"
|
||||
+ _format_model_and_input_information(model_proto, ort_inputs)
|
||||
) from e
|
||||
except Exception as e:
|
||||
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
|
||||
error_reproduction.create_reproduction_report(
|
||||
test_name,
|
||||
model_proto,
|
||||
ort_inputs,
|
||||
e,
|
||||
"test/onnx/torchlib/test_ops.py",
|
||||
)
|
||||
raise
|
||||
|
||||
return _capture_graph_and_evaluate_torch_script_evaluator
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def normal_xfail_skip_test_behaviors(
|
||||
test_behavior: Optional[str] = None, reason: Optional[str] = None
|
||||
):
|
||||
"""This context manager is used to handle the different behaviors of xfail and skip.
|
||||
|
||||
Args:
|
||||
test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None.
|
||||
reason (optional[str]): The reason for the failure or skip.
|
||||
|
||||
Raises:
|
||||
e: Any exception raised by the test case if it's not an expected failure.
|
||||
"""
|
||||
|
||||
# We need to skip as soon as possible, as SegFault might also be a case.
|
||||
if test_behavior == "skip":
|
||||
pytest.skip(reason=reason)
|
||||
|
||||
try:
|
||||
yield
|
||||
# We could use `except (AssertionError, RuntimeError, ...) as e:`, but it needs
|
||||
# to go over all test cases to find the right exception type.
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
if test_behavior is None:
|
||||
raise
|
||||
if test_behavior == "xfail":
|
||||
pytest.xfail(reason=reason)
|
||||
else:
|
||||
if test_behavior == "xfail":
|
||||
pytest.fail("Test unexpectedly passed")
|
||||
691
test/onnx/torchlib/ops_test_data.py
Normal file
691
test/onnx/torchlib/ops_test_data.py
Normal file
|
|
@ -0,0 +1,691 @@
|
|||
# Owner(s): ["module: onnx"]
|
||||
"""Test op correctness by comparing with PyTorch results.
|
||||
|
||||
## Usage
|
||||
|
||||
1. Set the env var CATCH_ORT_SEGFAULT to catch segfaults from ONNX Runtime.
|
||||
|
||||
## How to add a new operator test
|
||||
|
||||
This test use PyTorch's OpInfo mechanism to generate test cases for each operator.
|
||||
You may find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804
|
||||
|
||||
1. To enable test cases for an operator
|
||||
Add a `TorchLibOpInfo` entry to `TORCH_LIB_OPINFO` in `ops_test_data.py`.
|
||||
Specify `complex` if the function is designed for complex inputs.
|
||||
|
||||
The `op_info_name` in `TorchLibOpInfo` needs to be unique in the TORCH_LIB_OPINFO
|
||||
list, but complex=True ops can share the same name with non-complex ops
|
||||
because they are tested separately.
|
||||
|
||||
2. Add `.skip` and/or `.xfail` to skip or xfail tests.
|
||||
Prefer xfail over skip when possible because that allows us to monitor the behavior
|
||||
and update the test will it passes.
|
||||
|
||||
2a. If a test is now failing because of xpass, because some previous errors
|
||||
are now fixed, removed the corresponding xfail.
|
||||
|
||||
3. If sample inputs of the OpInfo needs to be adjusted to fit the aten signature, create an input
|
||||
wrangler function. See `_mean_input_wrangler` for an example.
|
||||
|
||||
4. To test different ONNX functions that are registered as overloads of the same
|
||||
op, use `ops_test_common.duplicate_opinfo` to create new OpInfo with new names and map each
|
||||
to one overload.
|
||||
"""
|
||||
# flake8: noqa
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Callable, Collection, Optional
|
||||
from typing_extensions import Self
|
||||
|
||||
import numpy as np
|
||||
import ops_test_common
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
|
||||
from torch.testing._internal import common_methods_invocations
|
||||
from torch.testing._internal.opinfo import definitions as opinfo_definitions
|
||||
|
||||
|
||||
# Create a copy of the op_db to modify
|
||||
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
|
||||
|
||||
# Append extra op_db into the op database for testing
|
||||
OPS_DB.extend(opinfo_definitions.signal.op_db)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TorchLibOpInfo:
|
||||
"""A dataclass to store the information to test an torchlib op."""
|
||||
|
||||
# The name of the op_info, e.g. "add"
|
||||
op_info_name: str
|
||||
# The torchlib ONNX Function to test
|
||||
op: Callable[..., Any]
|
||||
# The input wrangler function to adjust the input to fit the aten signature
|
||||
input_wrangler: Optional[
|
||||
Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]]
|
||||
] = None
|
||||
# Whether the op is non-deterministic
|
||||
nondeterministic: bool = False
|
||||
# Whether to compare the shape only for the output[index]
|
||||
# For example: (1,2) means compare value for output[0] and shape for output[1] and [2]
|
||||
# We may be able to combine this with the nondeterministic option
|
||||
compare_shape_only_for_output: tuple[int, ...] = ()
|
||||
# Whether the function is designed for complex inputs
|
||||
complex: bool = False
|
||||
# The acceptable tolerance of the inference result difference between PyTorch and ORT.
|
||||
# Format: {dtype: (rtol, atol)}.
|
||||
# For example: {torch.float16: (1e-3, 1e-3)}
|
||||
tolerance: dict[torch.dtype, tuple[float, float]] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
# Expected skips or fails for the test and/or subtests
|
||||
skips_or_fails: list[ops_test_common.DecorateMeta] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
def get_tolerance(self, dtype: torch.dtype) -> tuple[float | None, float | None]:
|
||||
"""Returns the (rtol, atol) tolerance for the given dtype."""
|
||||
if (tolerance := self.tolerance.get(dtype)) is not None:
|
||||
return tolerance
|
||||
|
||||
# Use the PyTorch default if not specified
|
||||
# https://pytorch.org/docs/stable/testing.html
|
||||
return (None, None)
|
||||
|
||||
def skip(
|
||||
self,
|
||||
variant_name: str = "",
|
||||
*,
|
||||
reason: str,
|
||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||
device_type: Optional[str] = None,
|
||||
matcher: Optional[Callable[[Any], Any]] = None,
|
||||
enabled_if: bool = True,
|
||||
test_class_name: Optional[str] = None,
|
||||
) -> Self:
|
||||
"""Skips an OpInfo test.
|
||||
|
||||
Args:
|
||||
variant_name: Optional OpInfo variant_test_name.
|
||||
reason: The reason for skipping.
|
||||
dtypes: The dtypes to skip.
|
||||
device_type: Device type. E.g. "cpu", "cuda".
|
||||
matcher: A function that matches the test sample input. It is used only when
|
||||
the skip is in the SKIP_XFAIL_SUBTESTS list.
|
||||
enabled_if: Whether the skip is enabled.
|
||||
test_class_name: The test class name to apply the skip to. If None, the skip
|
||||
is applied to all test classes.
|
||||
"""
|
||||
self.skips_or_fails.append(
|
||||
ops_test_common.skip(
|
||||
self.op_info_name,
|
||||
variant_name,
|
||||
reason=reason,
|
||||
dtypes=dtypes,
|
||||
device_type=device_type,
|
||||
matcher=matcher,
|
||||
enabled_if=enabled_if,
|
||||
test_class_name=test_class_name,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def xfail(
|
||||
self,
|
||||
variant_name: str = "",
|
||||
*,
|
||||
reason: str,
|
||||
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||
device_type: Optional[str] = None,
|
||||
matcher: Optional[Callable[[Any], Any]] = None,
|
||||
enabled_if: bool = True,
|
||||
test_class_name: Optional[str] = None,
|
||||
) -> Self:
|
||||
"""Expects an OpInfo test to fail.
|
||||
|
||||
Args:
|
||||
variant_name: Optional OpInfo variant_test_name.
|
||||
reason: The reason for the failure.
|
||||
dtypes: The dtypes to expect the failure
|
||||
device_type: Device type. E.g. "cpu", "cuda"..
|
||||
matcher: A function that matches the test sample input. It is used only when
|
||||
the xfail is in the SKIP_XFAIL_SUBTESTS list.
|
||||
enabled_if: Whether the xfail is enabled.
|
||||
test_class_name: The test class name to apply the xfail to. If None, the
|
||||
xfail is applied to all test classes.
|
||||
"""
|
||||
self.skips_or_fails.append(
|
||||
ops_test_common.xfail(
|
||||
self.op_info_name,
|
||||
variant_name,
|
||||
reason=reason,
|
||||
dtypes=dtypes,
|
||||
device_type=device_type,
|
||||
matcher=matcher,
|
||||
enabled_if=enabled_if,
|
||||
test_class_name=test_class_name,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
# Modify this section ##########################################################
|
||||
|
||||
|
||||
def _amin_amax_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if "dim" not in kwargs:
|
||||
# Supply an empty dim to match the aten signature
|
||||
kwargs["dim"] = np.array([], dtype=np.int64)
|
||||
else:
|
||||
# Convert dim to a numpy array
|
||||
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64).reshape((-1,))
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _avg_pool_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if "dim" not in kwargs:
|
||||
if len(args) > 6:
|
||||
kwargs["divisor_override"] = args.pop(6)
|
||||
if len(args) > 5:
|
||||
kwargs["count_include_pad"] = args.pop(5)
|
||||
if len(args) > 4:
|
||||
kwargs["ceil_mode"] = args.pop(4)
|
||||
if len(args) > 3:
|
||||
padding = args.pop(3)
|
||||
if isinstance(padding, np.ndarray):
|
||||
# Cannot using list(padding) here, because the element will be numpy.int64 instead of int
|
||||
padding = padding.tolist()
|
||||
kwargs["padding"] = padding
|
||||
if len(args) > 2:
|
||||
stride = args.pop(2)
|
||||
if isinstance(stride, np.ndarray):
|
||||
stride = stride.tolist()
|
||||
kwargs["stride"] = stride
|
||||
kernel_size = args.pop(1)
|
||||
if isinstance(kernel_size, np.ndarray):
|
||||
kernel_size = kernel_size.tolist()
|
||||
kwargs["kernel_size"] = kernel_size
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _cross_entropy_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if "reduction" in kwargs:
|
||||
reduction_vals = ["none", "mean", "sum"]
|
||||
value = kwargs["reduction"]
|
||||
idx = reduction_vals.index(value)
|
||||
kwargs["reduction"] = idx
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _dropout_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if "training" in kwargs:
|
||||
kwargs["train"] = kwargs["training"]
|
||||
kwargs.pop("training")
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _einsum_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# Swap the equation and tensors to revert the special handling in the OpInfo
|
||||
return [args[1], args[0]], kwargs
|
||||
|
||||
|
||||
def _embedding_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""Remove arguments not present in the aten op signature."""
|
||||
kwargs.pop("max_norm", None)
|
||||
kwargs.pop("norm_type", None)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _empty_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""Remove arguments not present in the aten op signature."""
|
||||
kwargs.pop("requires_grad", None)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _grid_sample_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# Convert string attriute to int as input
|
||||
inter_mode_options = {"bilinear": 0, "nearest": 1, "bicubic": 2}
|
||||
padding_mode_options = {"zeros": 0, "border": 1, "reflection": 2}
|
||||
args.append(inter_mode_options[kwargs["mode"]])
|
||||
args.append(padding_mode_options[kwargs["padding_mode"]])
|
||||
args.append(kwargs["align_corners"])
|
||||
kwargs.clear()
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _im2col_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# Move kernel_size, dilation, padding and stride from args to kwargs
|
||||
if len(args) == 5:
|
||||
# Handle stride
|
||||
stride = args.pop()
|
||||
if isinstance(stride, np.ndarray): # convert stride to list[int]
|
||||
stride = stride.tolist()
|
||||
kwargs["stride"] = stride
|
||||
# Handle padding
|
||||
padding = args.pop()
|
||||
if isinstance(padding, np.ndarray): # convert padding to list[int]
|
||||
padding = padding.tolist()
|
||||
kwargs["padding"] = padding
|
||||
# Handle dilation
|
||||
dilation = args.pop()
|
||||
if isinstance(dilation, np.ndarray): # convert dilation to list[int]
|
||||
dilation = dilation.tolist()
|
||||
kwargs["dilation"] = dilation
|
||||
# Handle kernel_size
|
||||
kernel_size = args.pop()
|
||||
if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int]
|
||||
kernel_size = kernel_size.tolist()
|
||||
kwargs["kernel_size"] = kernel_size
|
||||
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _index_put_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
args[1] = [np.array(elem) for elem in args[1]]
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _max_pool_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# Remove return_indices argument because this op doesn't accept it
|
||||
kwargs.pop("return_indices", None)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _mean_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# Make the dims as tensor
|
||||
if "dim" in kwargs:
|
||||
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _mse_loss_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if "reduction" in kwargs:
|
||||
reduction_vals = ["none", "mean", "sum"] # [0,1,2], default=1
|
||||
value = kwargs["reduction"]
|
||||
idx = reduction_vals.index(value)
|
||||
kwargs["reduction"] = idx
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _nll_loss_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if "reduction" in kwargs:
|
||||
# aten_nll_loss can only accept integer argument instead of string
|
||||
reduction_vals = ["none", "mean", "sum"]
|
||||
value = kwargs["reduction"]
|
||||
kwargs["reduction"] = reduction_vals.index(value)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _nonzero_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
kwargs.pop("as_tuple", None)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _reflection_pad2d_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
args.pop(2) # remove 'reflect' arg
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _replication_pad2d_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
args.pop(2) # remove 'replicate' arg
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _replication_pad3d_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
args.pop(2) # remove 'replicate' arg
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _roll_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if len(args) >= 3:
|
||||
if isinstance(args[2], np.ndarray): # convert dims to list[int]
|
||||
# Change dims from args to kwargs to keep tuple/list type
|
||||
dims = args.pop(2)
|
||||
kwargs["dims"] = dims.tolist()
|
||||
elif isinstance(args[2], int): # convert dims to list[int]
|
||||
dims = args.pop(2)
|
||||
kwargs["dims"] = []
|
||||
kwargs["dims"].append(dims)
|
||||
if len(args) >= 2:
|
||||
if isinstance(args[1], np.ndarray): # convert shift to list[int]
|
||||
shifts = args.pop(1)
|
||||
kwargs["shifts"] = shifts.tolist()
|
||||
elif isinstance(args[1], int):
|
||||
shifts = args.pop(1)
|
||||
kwargs["shifts"] = []
|
||||
kwargs["shifts"].append(shifts)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _scalar_tensor_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
kwargs.pop("requires_grad", None)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _scatter_reduce_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# Put the string into kwargs, otherwise FullGraph mode could not find get 'reduce' argument
|
||||
kwargs["reduce"] = args.pop(4)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _sum_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
if kwargs.get("dim") is not None:
|
||||
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _unflatten_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
args[1] = np.array(args[1], dtype=np.int64)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _where_input_wrangler(
|
||||
args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
# The aten::where op takes condition, x, y as inputs
|
||||
# Swap the first two inputs
|
||||
args[0], args[1] = args[1], args[0]
|
||||
return args, kwargs
|
||||
|
||||
|
||||
# Ops to be tested for numerical consistency between onnx and pytorch
|
||||
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
|
||||
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
|
||||
TorchLibOpInfo("abs", core_ops.aten_abs),
|
||||
TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
|
||||
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
|
||||
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
|
||||
)
|
||||
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB, "arange", ("arange_start", "arange_start_step")
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",))
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB,
|
||||
"bitwise_left_shift",
|
||||
(
|
||||
"bitwise_left_shift_int8",
|
||||
"bitwise_left_shift_int16",
|
||||
"bitwise_left_shift_int32",
|
||||
"bitwise_left_shift_int64",
|
||||
),
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB,
|
||||
"bitwise_right_shift",
|
||||
(
|
||||
"bitwise_right_shift_int8",
|
||||
"bitwise_right_shift_int16",
|
||||
"bitwise_right_shift_int32",
|
||||
"bitwise_right_shift_int64",
|
||||
),
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int"))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",))
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB,
|
||||
"nn.functional.pad",
|
||||
(
|
||||
"nn.functional.reflection_pad2d",
|
||||
"nn.functional.replication_pad2d",
|
||||
"nn.functional.replication_pad3d",
|
||||
),
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB,
|
||||
"nn.functional.scaled_dot_product_attention",
|
||||
("nn.functional.scaled_dot_product_attention_bool_mask",),
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB,
|
||||
"nn.functional.celu",
|
||||
("nn.functional.celu_type_promoted",),
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(
|
||||
OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)
|
||||
)
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
|
||||
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",))
|
||||
|
||||
# MARK: End edits here
|
||||
|
||||
|
||||
# These ops are not deterministic, so we check shape and dtype only
|
||||
NONDETERMINISTIC_OPS: frozenset[str] = frozenset(
|
||||
info.op_info_name for info in TESTED_TORCHLIB_OPS if info.nondeterministic
|
||||
)
|
||||
|
||||
COMPARE_SHAPE_ONLY_OPS: dict[
|
||||
str,
|
||||
set,
|
||||
] = {
|
||||
info.op_info_name: set(info.compare_shape_only_for_output)
|
||||
for info in TESTED_TORCHLIB_OPS
|
||||
}
|
||||
|
||||
TORCHLIB_OPINFO_MAPPING: dict[
|
||||
str,
|
||||
TorchLibOpInfo,
|
||||
] = {info.op_info_name: info for info in TESTED_TORCHLIB_OPS if not info.complex}
|
||||
|
||||
TESTED_OPS = frozenset(TORCHLIB_OPINFO_MAPPING)
|
||||
|
||||
EXPECTED_SKIPS_OR_FAILS: tuple[ops_test_common.DecorateMeta, ...] = tuple(
|
||||
functools.reduce(
|
||||
# Flatten the list
|
||||
lambda a, b: [*a, *b],
|
||||
[
|
||||
[meta for meta in info.skips_or_fails if meta.matcher is None]
|
||||
for info in TESTED_TORCHLIB_OPS
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
SKIP_XFAIL_SUBTESTS: tuple[ops_test_common.DecorateMeta, ...] = tuple(
|
||||
functools.reduce(
|
||||
# Flatten the list
|
||||
lambda a, b: [*a, *b],
|
||||
[
|
||||
[meta for meta in info.skips_or_fails if meta.matcher is not None]
|
||||
for info in TESTED_TORCHLIB_OPS
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# MARK: Complex supported functions
|
||||
COMPLEX_FUNCTION_MAPPING: dict[
|
||||
str,
|
||||
TorchLibOpInfo,
|
||||
] = {info.op_info_name: info for info in TESTED_TORCHLIB_OPS if info.complex}
|
||||
|
||||
|
||||
# Call dir(torch.ops.prims) and compare with entries in OPS_DB to create OpInfo for newly added prims ops
|
||||
PRIMS_OPS_WITH_OP_INFO = (
|
||||
"abs",
|
||||
"acos",
|
||||
"acosh",
|
||||
"add",
|
||||
"amax",
|
||||
"amin",
|
||||
"as_strided",
|
||||
"as_strided_scatter",
|
||||
"asin",
|
||||
"asinh",
|
||||
"atan",
|
||||
"atan2",
|
||||
"atanh",
|
||||
"bitwise_and",
|
||||
"bitwise_not",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"cat",
|
||||
"ceil",
|
||||
"clone",
|
||||
"conj",
|
||||
"conj_physical",
|
||||
"cos",
|
||||
"cosh",
|
||||
"digamma",
|
||||
"div",
|
||||
"empty",
|
||||
"eq",
|
||||
"erf",
|
||||
"erfc",
|
||||
"exp",
|
||||
"exp2",
|
||||
"expm1",
|
||||
"fill",
|
||||
"floor",
|
||||
"fmax",
|
||||
"fmin",
|
||||
"fmod",
|
||||
"full",
|
||||
"full_like",
|
||||
"gcd",
|
||||
"ge",
|
||||
"gt",
|
||||
"hypot",
|
||||
"igamma",
|
||||
"igammac",
|
||||
"imag",
|
||||
"isfinite",
|
||||
"le",
|
||||
"lgamma",
|
||||
"log",
|
||||
"log10",
|
||||
"log1p",
|
||||
"log2",
|
||||
"lt",
|
||||
"maximum",
|
||||
"minimum",
|
||||
"mul",
|
||||
"ne",
|
||||
"neg",
|
||||
"nextafter",
|
||||
"normal",
|
||||
"pow",
|
||||
"prod",
|
||||
"real",
|
||||
"reciprocal",
|
||||
"remainder",
|
||||
"reshape",
|
||||
"round",
|
||||
"rsqrt",
|
||||
"scalar_tensor",
|
||||
"sign",
|
||||
"signbit",
|
||||
"sin",
|
||||
"sinh",
|
||||
"sqrt",
|
||||
"squeeze",
|
||||
"sub",
|
||||
"sum",
|
||||
"svd",
|
||||
"tan",
|
||||
"tanh",
|
||||
"transpose",
|
||||
"trunc",
|
||||
"uniform",
|
||||
"where",
|
||||
)
|
||||
|
||||
for op in PRIMS_OPS_WITH_OP_INFO:
|
||||
# Duplicate opinfo for prim ops. The new names all start with "prims_". E.g. "abs" -> "prims_abs".
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, op)
|
||||
|
||||
# Duplicate cases where the prims op name is different from the torch op name
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "i0", "bessel_i0")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.bessel_j0", "bessel_j0")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.bessel_j1", "bessel_j1")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.erfcx", "erfcx")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i0e", "bessel_i0e")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i1", "bessel_i1")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i1e", "bessel_i1e")
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.ndtri", "ndtri")
|
||||
ops_test_common.duplicate_opinfo_for_prims(
|
||||
OPS_DB, "special.spherical_bessel_j0", "spherical_bessel_j0"
|
||||
)
|
||||
ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.zeta", "zeta")
|
||||
|
||||
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
|
||||
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
|
||||
# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
|
||||
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
|
||||
assert NONDETERMINISTIC_OPS.issubset(
|
||||
TESTED_OPS
|
||||
), f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS"
|
||||
354
test/onnx/torchlib/test_ops.py
Normal file
354
test/onnx/torchlib/test_ops.py
Normal file
|
|
@ -0,0 +1,354 @@
|
|||
# Owner(s): ["module: onnx"]
|
||||
"""Test op correctness by comparing with PyTorch results.
|
||||
|
||||
Usage:
|
||||
|
||||
pytest test_ops.py
|
||||
|
||||
To run tests on a specific operator (e.g. torch.ceil):
|
||||
|
||||
pytest test_ops.py -k ceil
|
||||
|
||||
To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention):
|
||||
|
||||
pytest test_ops.py -k nn_functional_scaled_dot_product_attention
|
||||
|
||||
## Environment variables
|
||||
|
||||
1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
|
||||
in onnxruntime by running the inference sessions in a separate process.
|
||||
|
||||
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
|
||||
errors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, Sequence, Tuple, TYPE_CHECKING
|
||||
|
||||
import error_reproduction
|
||||
import numpy as np
|
||||
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import onnxscript
|
||||
import ops_test_common
|
||||
import ops_test_data
|
||||
import parameterized
|
||||
|
||||
import torch
|
||||
from torch.testing._internal import common_device_type, common_utils
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.opinfo import core as opinfo_core
|
||||
|
||||
# All dtypes will be tested on the generated symbolic functions.
|
||||
# complex64 will be flattened to float32.
|
||||
TESTED_DTYPES = (
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
# Uncomment below item when we really need testing it
|
||||
# torch.bfloat16,
|
||||
# torch.float64,
|
||||
torch.bool,
|
||||
# torch.int8,
|
||||
# torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
# torch.uint8,
|
||||
)
|
||||
# NOTE: torch.complex32 is experimental in torch
|
||||
COMPLEX_TYPES = (torch.complex64,)
|
||||
|
||||
|
||||
def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
|
||||
"""Returns all dtypes except the ones specified."""
|
||||
return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes)
|
||||
|
||||
|
||||
def _should_skip_xfail_test_sample(
|
||||
op_name: str, sample, dtype: torch.dtype, device_type: str
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Returns a reason if a test sample should be skipped."""
|
||||
if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
|
||||
return None, None
|
||||
for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
|
||||
# Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
|
||||
if decorator_meta.op_name == op_name:
|
||||
assert decorator_meta.matcher is not None, "Matcher must be defined"
|
||||
if not decorator_meta.enabled_if:
|
||||
# Do not skip the test if the decorator meta is not enabled
|
||||
continue
|
||||
if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
|
||||
# Not applicable for this dtype
|
||||
continue
|
||||
if (
|
||||
decorator_meta.device_type is not None
|
||||
and decorator_meta.device_type != device_type
|
||||
):
|
||||
# Not applicable for this device_type
|
||||
continue
|
||||
if decorator_meta.matcher(sample):
|
||||
return decorator_meta.test_behavior, decorator_meta.reason
|
||||
return None, None
|
||||
|
||||
|
||||
class TestFunctionValidity(common_utils.TestCase):
|
||||
@parameterized.parameterized.expand(
|
||||
[
|
||||
(info.op.name, info)
|
||||
for info in ops_test_data.TESTED_TORCHLIB_OPS
|
||||
if isinstance(info.op, onnxscript.OnnxFunction)
|
||||
],
|
||||
skip_on_empty=True,
|
||||
)
|
||||
def test_script_function_passes_checker(
|
||||
self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
|
||||
):
|
||||
function_proto = torchlib_op_info.op.to_function_proto()
|
||||
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def run_test_output_match(
|
||||
test_suite: unittest.TestCase,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
op: opinfo_core.OpInfo,
|
||||
function_executor: Callable,
|
||||
tested_op_mapping: dict[
|
||||
str,
|
||||
ops_test_data.TorchLibOpInfo,
|
||||
],
|
||||
):
|
||||
"""Base test method for testing each opset, used by instantiate_device_type_tests.
|
||||
|
||||
Args:
|
||||
test_suite: The test class instance.
|
||||
device: The PyTorch device. instantiate_device_type_tests provides this.
|
||||
dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
|
||||
op: The OpInfo instance. instantiate_device_type_tests provides this.
|
||||
function_executor: The function executor. This is a function that takes
|
||||
a function and its arguments and returns the output of the function.
|
||||
tested_op_mapping: The mapping of op name to the tested op.
|
||||
"""
|
||||
samples = op.sample_inputs(
|
||||
device,
|
||||
dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
torchlib_op_info = tested_op_mapping[op.name]
|
||||
# Obtain the input_wrangler that manipulates the OpInfo inputs
|
||||
# to match the aten operator signature
|
||||
# An example is nn.functional.upsample_nearest2d, which has a different signature
|
||||
# than the aten operator upsample_nearest2d
|
||||
onnx_function = torchlib_op_info.op
|
||||
input_wrangler = torchlib_op_info.input_wrangler
|
||||
if (
|
||||
not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
|
||||
and dtype not in COMPLEX_TYPES
|
||||
):
|
||||
test_suite.skipTest(
|
||||
f"dtype '{dtype}' is not supported by the op '{op.name}'. "
|
||||
f"Type constraints: {onnx_function.op_schema.type_constraints}"
|
||||
)
|
||||
|
||||
# Obtain the tolerance for the op
|
||||
rtol, atol = torchlib_op_info.get_tolerance(dtype)
|
||||
for i, cpu_sample in enumerate(samples):
|
||||
inputs = (cpu_sample.input, *cpu_sample.args)
|
||||
# Provide the repr to subtest because tensors are not serializable in parallel test runs
|
||||
with test_suite.subTest(
|
||||
sample_num=i,
|
||||
inputs=repr(
|
||||
[
|
||||
f"Tensor<{inp.shape}, dtype={inp.dtype}>"
|
||||
if isinstance(inp, torch.Tensor)
|
||||
else inp
|
||||
for inp in inputs
|
||||
]
|
||||
),
|
||||
kwargs=repr(cpu_sample.kwargs),
|
||||
):
|
||||
try:
|
||||
device_type = cpu_sample.args[0].device.type
|
||||
except (AttributeError, IndexError):
|
||||
device_type = "cpu"
|
||||
test_behavior, reason = _should_skip_xfail_test_sample(
|
||||
op.name, cpu_sample, dtype, device_type
|
||||
)
|
||||
|
||||
with ops_test_common.normal_xfail_skip_test_behaviors(
|
||||
test_behavior, reason
|
||||
):
|
||||
input_onnx = [
|
||||
ops_test_common.convert_tensor_to_numpy(x) for x in inputs
|
||||
]
|
||||
kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
|
||||
if input_wrangler:
|
||||
input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
|
||||
torch_output = op(*inputs, **cpu_sample.kwargs)
|
||||
|
||||
if isinstance(torch_output, torch.Tensor) and torch.is_complex(
|
||||
torch_output
|
||||
):
|
||||
torch_output = torch.view_as_real(torch_output.resolve_conj())
|
||||
|
||||
reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
|
||||
if (
|
||||
op.name.startswith("split")
|
||||
or op.name.startswith("chunk")
|
||||
or op.name.startswith("unbind")
|
||||
or op.name
|
||||
in {
|
||||
"atleast_1d_Sequence",
|
||||
"atleast_2d_Sequence",
|
||||
"atleast_3d_Sequence",
|
||||
}
|
||||
):
|
||||
# Hack for handling split, chunk and unbind which relies on SplitToSequence op.
|
||||
# Split returns a Sequence that should be treats as a single
|
||||
# value. So we wrap it into a tuple.
|
||||
# TODO(justinchuby): Find a more general solution
|
||||
reference_torch_outputs = [reference_torch_outputs]
|
||||
|
||||
test_name = test_suite.id()
|
||||
function_output, model_proto = function_executor(
|
||||
test_name, reference_torch_outputs
|
||||
)(onnx_function, input_onnx, kwargs_onnx)
|
||||
# Finally we re-flatten everything
|
||||
# TODO: add pytree structure comparison.
|
||||
flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
|
||||
flattened_function_outputs, _ = pytree.tree_flatten(function_output)
|
||||
|
||||
assert flattened_torch_outputs
|
||||
assert len(flattened_torch_outputs) == len(flattened_function_outputs)
|
||||
|
||||
for j, (torch_output, function_output) in enumerate(
|
||||
zip(flattened_torch_outputs, flattened_function_outputs)
|
||||
):
|
||||
actual = torch.tensor(function_output)
|
||||
expected = (
|
||||
torch_output
|
||||
if isinstance(torch_output, torch.Tensor)
|
||||
else torch.tensor(torch_output)
|
||||
)
|
||||
|
||||
if (
|
||||
op.name in ops_test_data.NONDETERMINISTIC_OPS
|
||||
or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
|
||||
):
|
||||
# Check shape and dtype only for ops that are known to be
|
||||
# nondeterministic
|
||||
test_suite.assertEqual(actual.shape, expected.shape)
|
||||
test_suite.assertEqual(actual.dtype, expected.dtype)
|
||||
continue
|
||||
|
||||
# Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
|
||||
try:
|
||||
torch.testing.assert_close(
|
||||
actual,
|
||||
expected,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=True,
|
||||
check_device=False,
|
||||
)
|
||||
except AssertionError as e:
|
||||
if (
|
||||
os.environ.get("CREATE_REPRODUCTION_REPORT") == "1"
|
||||
and test_behavior is None
|
||||
):
|
||||
error_reproduction.create_mismatch_report(
|
||||
test_name,
|
||||
i,
|
||||
model_proto,
|
||||
inputs,
|
||||
cpu_sample.kwargs,
|
||||
actual,
|
||||
expected,
|
||||
e,
|
||||
__file__,
|
||||
)
|
||||
if len(flattened_torch_outputs) > 1:
|
||||
raise AssertionError(f"Output {j} mismatch") from e
|
||||
raise
|
||||
|
||||
|
||||
class TestOutputConsistencyFullGraph(common_utils.TestCase):
|
||||
"""Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.
|
||||
|
||||
This is a parameterized test suite.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
ort.set_seed(42)
|
||||
|
||||
@ops_test_common.add_decorate_info(
|
||||
ops_test_data.OPS_DB,
|
||||
"TestOutputConsistencyFullGraph",
|
||||
"test_output_match_opinfo_",
|
||||
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
|
||||
)
|
||||
@common_device_type.ops( # type: ignore[misc]
|
||||
[
|
||||
info
|
||||
for info in ops_test_data.OPS_DB
|
||||
if info.name in ops_test_data.TESTED_OPS
|
||||
],
|
||||
allowed_dtypes=TESTED_DTYPES,
|
||||
)
|
||||
def test_output_match_opinfo_(
|
||||
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
|
||||
):
|
||||
# Base test method for testing each op by running the full ONNX graph.
|
||||
run_test_output_match(
|
||||
self,
|
||||
device,
|
||||
dtype,
|
||||
op,
|
||||
ops_test_common.graph_executor,
|
||||
ops_test_data.TORCHLIB_OPINFO_MAPPING,
|
||||
)
|
||||
|
||||
@ops_test_common.add_decorate_info(
|
||||
ops_test_data.OPS_DB,
|
||||
"TestOutputConsistencyFullGraph",
|
||||
"test_complex_output_match_opinfo_",
|
||||
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
|
||||
)
|
||||
@common_device_type.ops( # type: ignore[misc]
|
||||
[
|
||||
info
|
||||
for info in ops_test_data.OPS_DB
|
||||
if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
|
||||
],
|
||||
allowed_dtypes=COMPLEX_TYPES,
|
||||
)
|
||||
def test_complex_output_match_opinfo_(
|
||||
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
|
||||
):
|
||||
"""Base test method for testing each op by running the full ONNX graph."""
|
||||
run_test_output_match(
|
||||
self,
|
||||
device,
|
||||
dtype,
|
||||
op,
|
||||
ops_test_common.graph_executor,
|
||||
ops_test_data.COMPLEX_FUNCTION_MAPPING,
|
||||
)
|
||||
|
||||
|
||||
common_device_type.instantiate_device_type_tests(
|
||||
TestOutputConsistencyFullGraph, globals(), only_for=["cpu"]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
78
torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py
Normal file
78
torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# --------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""Typings for function definitions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from onnxscript import (
|
||||
BFLOAT16,
|
||||
BOOL,
|
||||
COMPLEX128,
|
||||
COMPLEX64,
|
||||
DOUBLE,
|
||||
FLOAT,
|
||||
FLOAT16,
|
||||
INT16,
|
||||
INT32,
|
||||
INT64,
|
||||
INT8,
|
||||
STRING,
|
||||
UINT8,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not us them.
|
||||
# More detail can be found: https://pytorch.org/docs/stable/tensors.html
|
||||
|
||||
_TensorType = Union[
|
||||
BFLOAT16,
|
||||
BOOL,
|
||||
COMPLEX64,
|
||||
COMPLEX128,
|
||||
DOUBLE,
|
||||
FLOAT,
|
||||
FLOAT16,
|
||||
INT8,
|
||||
INT16,
|
||||
INT32,
|
||||
INT64,
|
||||
UINT8,
|
||||
]
|
||||
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
|
||||
IntType = Union[INT8, INT16, INT32, INT64]
|
||||
RealType = Union[
|
||||
BFLOAT16,
|
||||
FLOAT16,
|
||||
FLOAT,
|
||||
DOUBLE,
|
||||
INT8,
|
||||
INT16,
|
||||
INT32,
|
||||
INT64,
|
||||
]
|
||||
|
||||
TTensor = TypeVar("TTensor", bound=_TensorType)
|
||||
# Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
|
||||
# but do not constrain the type to be the same as the other inputs/outputs
|
||||
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
|
||||
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
|
||||
TFloat = TypeVar("TFloat", bound=_FloatType)
|
||||
TFloatOrUInt8 = TypeVar(
|
||||
"TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]
|
||||
)
|
||||
TInt = TypeVar("TInt", bound=IntType)
|
||||
TReal = TypeVar("TReal", bound=RealType)
|
||||
TRealUnlessInt16OrInt8 = TypeVar(
|
||||
"TRealUnlessInt16OrInt8",
|
||||
bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64],
|
||||
)
|
||||
TRealUnlessFloat16OrInt8 = TypeVar(
|
||||
"TRealUnlessFloat16OrInt8", bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64]
|
||||
)
|
||||
TRealOrUInt8 = TypeVar("TRealOrUInt8", bound=Union[RealType, UINT8])
|
||||
TFloatHighPrecision = TypeVar("TFloatHighPrecision", bound=Union[FLOAT, DOUBLE])
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = ["hop"]
|
||||
__all__ = ["core", "hop"]
|
||||
|
||||
from torch.onnx._internal.exporter._torchlib.ops import hop
|
||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop
|
||||
|
|
|
|||
47
torch/onnx/_internal/exporter/_torchlib/ops/core.py
Normal file
47
torch/onnx/_internal/exporter/_torchlib/ops/core.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""torch.ops.aten operators under the `core` module."""
|
||||
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
|
||||
# ruff: noqa: TCH001,TCH002
|
||||
# flake8: noqa
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
|
||||
from onnxscript.onnx_opset import opset18 as op
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal, TRealOrUInt8
|
||||
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
@onnx_impl((aten.abs.default, operator.abs), trace_only=True)
|
||||
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
|
||||
"""abs(Tensor self) -> Tensor"""
|
||||
|
||||
return op.Abs(self)
|
||||
|
||||
|
||||
@onnx_impl(aten.abs.default, complex=True, trace_only=True)
|
||||
def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
|
||||
"""abs(Tensor self) -> Tensor"""
|
||||
|
||||
return op.ReduceL2(self, [-1], keepdims=False)
|
||||
|
||||
|
||||
@onnx_impl((aten.add.Tensor, aten.add.Scalar, operator.add), trace_only=True)
|
||||
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
|
||||
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
|
||||
if alpha != 1.0:
|
||||
alpha = op.CastLike(alpha, other)
|
||||
other = op.Mul(other, alpha)
|
||||
return op.Add(self, other)
|
||||
|
||||
|
||||
@onnx_impl((aten.add.Tensor, aten.add.Scalar), trace_only=True, complex=True)
|
||||
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
|
||||
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
|
||||
|
||||
return aten_add(self, other, alpha=alpha)
|
||||
Loading…
Reference in New Issue
Block a user