mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Creates xfail_onnxscript/skip_onnxscript so that it is clear torchlib needs to support it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110536 Approved by: https://github.com/BowenBao
801 lines
26 KiB
Python
801 lines
26 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
"""Test consistency between the output values of torch.onnx FX exported operators
|
|
and torch operators given the same inputs.
|
|
|
|
Usage:
|
|
|
|
pytest test/onnx/test_op_consistency.py
|
|
|
|
To run tests on a specific operator (e.g. torch.ceil):
|
|
|
|
pytest test/onnx/test_op_consistency.py -k ceil
|
|
pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention
|
|
|
|
Read more on Running and writing tests:
|
|
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
|
|
|
|
Note:

When new ops are supported, please scroll down to modify the TESTED_OPS,
EXPECTED_SKIPS_OR_FAILS and SKIP_XFAIL_SUBTESTS lists. See "Modify this section".
|
|
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
from typing import Any, Callable, Collection, Optional, Tuple, Union
|
|
|
|
import onnx_test_common
|
|
|
|
import parameterized
|
|
|
|
import torch
|
|
from onnx_test_common import skip, xfail
|
|
from torch.testing._internal import (
|
|
common_device_type,
|
|
common_methods_invocations,
|
|
common_utils,
|
|
)
|
|
from torch.testing._internal.opinfo import core as opinfo_core
|
|
|
|
# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1.  Add "ceil" to TESTED_OPS then run pytest.
# 2.  If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS
#     (whole op / dtype combinations) or SKIP_XFAIL_SUBTESTS (individual samples).

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch.
# Commented-out entries record ops that were considered and why they are not tested.
TESTED_OPS: frozenset[str] = frozenset(
    [
        "abs",
        "acos",
        "acosh",
        "add",
        "addmm",
        "all",
        "allclose",
        "amax",
        "amin",
        "any",
        "arange",
        "argmax",
        "argmin",
        "as_strided",
        "asin",
        "asinh",
        "atan",
        "atanh",
        "atleast_1d",
        "atleast_2d",
        "atleast_3d",
        "baddbmm",
        "bmm",
        "broadcast_to",
        "cat",
        "ceil",
        "chunk",
        "clamp",
        "clamp_max",
        "clamp_min",
        "clone",
        # "col2im", extra opinfo needed
        "constant_pad_nd",
        "contiguous",
        # "copy", copy is not in OPS_DB
        "cos",
        "cosh",
        "cross",
        "cumsum",
        # "detach", detach is not in OPS_DB
        "div",
        "dot",
        # "empty", non-deterministic
        # "empty_like", non-deterministic
        # "empty_strided", empty_strided is not in OPS_DB
        "eq",
        "equal",
        "erf",
        "exp",
        "exp2",
        "expand",
        "expand_as",
        "fill",
        "flip",
        "floor",
        "fmod",
        "full",
        "full_like",
        "gather",
        "hstack",  # aten::cat is invoked instead
        "index_put",
        "logit",
        "mean",
        "native_batch_norm",
        # "new_empty", non-deterministic
        # "new_empty_strided", non-deterministic
        "new_full",
        "new_ones",
        "new_zeros",
        "nn.functional.adaptive_avg_pool1d",
        "nn.functional.adaptive_avg_pool2d",
        "nn.functional.adaptive_avg_pool3d",
        "nn.functional.avg_pool1d",
        "nn.functional.avg_pool2d",
        "nn.functional.avg_pool3d",
        "nn.functional.batch_norm",
        "nn.functional.celu",
        "nn.functional.conv1d",
        # "nn.functional.conv2d", AssertionError: The values for attribute 'shape' do not match in float32
        # "nn.functional.conv3d", extra opinfo needed
        # "nn.functional.convolution", extra opinfo needed
        "nn.functional.cross_entropy",
        "nn.functional.dropout",
        "nn.functional.elu",
        "nn.functional.embedding",
        "nn.functional.embedding_bag",
        "nn.functional.max_pool1d",
        "nn.functional.max_pool2d",
        "nn.functional.max_pool3d",
        "nn.functional.nll_loss",
        # "nn.functional.scaled_dot_product_attention", non-deterministic
        "nonzero",
        "scatter_add",
        "scatter_reduce",
        "square",
        "stft",
        "sum",
        "unflatten",
        "var_mean",
        "vstack",  # aten::cat is invoked instead
    ]
)
|
|
|
|
# Ops that are additionally exercised with complex dtypes by
# test_output_match_complex. Keep sorted alphabetically.
COMPLEX_TESTED_OPS: frozenset[str] = frozenset(
    [
        "abs",
        "stft",
    ]
)
|
|
|
|
|
|
# NOTE: For ATen signature modifications that will break ONNX export,
# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility**
# instead of xfail or skip to make the signal apparent for maintainers.
def xfail_torchlib_forward_compatibility(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    github_issue: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
):
    """Mark an op as an expected failure caused by an ATen signature change.

    Prefer this (xfail) over skip_torchlib_forward_compatibility when possible;
    only skip when the test is not failing consistently.

    Args:
        op_name: Name of the OpInfo to mark.
        variant_name: OpInfo variant name (default "").
        reason: Why the test is expected to fail.
        github_issue: Link to the tracking issue; appended to ``reason``.
        opsets, dtypes, matcher, enabled_if: Forwarded unchanged to ``xfail``.

    Returns:
        The object returned by ``xfail`` for the same arguments.
    """
    full_reason = f"{reason}. GitHub Issue: {github_issue}"
    return xfail(
        op_name,
        variant_name=variant_name,
        reason=full_reason,
        opsets=opsets,
        dtypes=dtypes,
        matcher=matcher,
        enabled_if=enabled_if,
    )
|
|
|
|
|
|
def skip_torchlib_forward_compatibility(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    github_issue: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    # Matchers are used as predicates on samples, same as in the xfail variant.
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
):
    """Skip an op whose test is broken by an ATen signature change.

    Prefer using xfail_torchlib_forward_compatibility over this (skip) when
    possible. Only skip when the test is not failing consistently.

    Args:
        op_name: Name of the OpInfo to skip.
        variant_name: OpInfo variant name (default "").
        reason: Why the test is skipped.
        github_issue: Link to the tracking issue; appended to ``reason``.
        opsets, dtypes, matcher, enabled_if: Forwarded unchanged to ``skip``.

    Returns:
        The object returned by ``skip`` for the same arguments.
    """
    return skip(
        op_name,
        variant_name=variant_name,
        reason=f"{reason}. GitHub Issue: {github_issue}",
        opsets=opsets,
        dtypes=dtypes,
        matcher=matcher,
        enabled_if=enabled_if,
    )
|
|
|
|
|
|
# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use skip vs xfail?
# A: Prefer xfail over skip when possible.
#     1. If a test is now failing because of xpass, because some previous errors
#        are now fixed, remove the corresponding xfail.
#     2. If a test is not failing consistently, use skip.
# Failures that only affect individual samples belong in SKIP_XFAIL_SUBTESTS.
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
    xfail(
        "add", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Add")
    ),
    xfail(
        "add",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "Add", "int8, int16, uint8 have type issue."
        ),
    ),
    xfail(
        "addmm", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
    ),
    xfail(
        "allclose", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("Allclose")
    ),
    xfail(
        "amax",
        dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMax", "bool, int16"),
    ),
    xfail(
        "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16")
    ),
    xfail(
        "arange",
        dtypes=(torch.uint8,),
        reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"),
    ),
    xfail(
        "arange",
        dtypes=(torch.int16, torch.int32),
        reason="AssertionError: The values for attribute 'shape' do not match",
    ),
    xfail(
        "argmax",
        dtypes=(
            torch.int16,
            torch.int64,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "ArgMax", "int16, int64"
        ),
    ),
    xfail(
        "argmin",
        dtypes=(
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int64,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "ArgMin", "uint8, int8, int16, int64"
        ),
    ),
    skip(
        "as_strided",
        variant_name="partial_views",
        reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults",
    ),
    xfail(
        "baddbmm",
        dtypes=(
            torch.uint8,
            torch.int8,
            torch.int16,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Matmul", "uint8, int8, int16"
        ),
    ),
    xfail(
        "bmm",
        dtypes=(
            torch.uint8,
            torch.int8,
            torch.int16,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Matmul", "uint8, int8, int16"
        ),
    ),
    skip(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int")
    ),
    xfail(
        "chunk", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool")
    ),
    xfail(
        "chunk",
        dtypes=(torch.uint8, torch.int8, torch.int16, torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Chunk", "uint8, int8, int16, float16"
        ),
    ),
    xfail(
        "clamp",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Max", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp_max", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool")
    ),
    xfail(
        "clamp_max",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        # clamp_max(x, m) == min(x, m), so the failing ONNX op is presumably Min
        # (clamp_min uses Max) -- verify against the torchlib implementation.
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Min", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp_min",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Max", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp_min", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool")
    ),
    xfail(
        "constant_pad_nd",
        dtypes=(torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Constant_pad_nd", "int16"
        ),
    ),
    xfail(
        "cross",
        reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
    ),
    xfail(
        "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16")
    ),
    xfail(
        "dot", dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16")
    ),
    xfail(
        "eq",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"),
    ),
    xfail(
        "equal",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
    ),
    xfail(
        "floor",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"),
    ),
    xfail(
        "index_put",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"),
    ),
    xfail(
        "index_put",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "uint8, int8, int16"),
    ),
    xfail(
        "nn.functional.adaptive_avg_pool2d",
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "RecursionError: maximum recursion depth exceeded while calling a Python object"
        ),
    ),
    xfail(
        "nn.functional.adaptive_avg_pool3d",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"),
    ),
    xfail(
        "nn.functional.avg_pool1d",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
    ),
    xfail(
        "nn.functional.avg_pool2d",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
    ),
    xfail(
        "nn.functional.avg_pool3d",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
    ),
    xfail(
        "nn.functional.conv1d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"),
    ),
    xfail(
        "nn.functional.conv2d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"),
    ),
    xfail(
        "nn.functional.dropout",
        reason=onnx_test_common.reason_dynamo_does_not_support("Dropout"),
    ),
    xfail(
        "nn.functional.embedding",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten.embedding_renorm.default"),
    ),
    xfail(
        "nn.functional.max_pool2d",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"),
    ),
    xfail(
        "nn.functional.max_pool3d",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"),
    ),
    xfail(
        "nonzero",
        dtypes=(torch.int8, torch.int16),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"),
    ),
    xfail(
        "scatter_add",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "bool, float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "bool, float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="mean",
        reason="ONNX doesn't support reduce='mean' option",
    ),
    xfail(
        "square",
        dtypes=(torch.int8, torch.uint8, torch.int16),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"),
    ),
    xfail(
        "stft",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"),
    ),
    xfail(
        "unflatten", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Unflatten")
    ),
)
# fmt: on
|
|
|
|
# Sample-level expectations: every entry carries a matcher that selects the
# individual OpInfo samples it applies to (see _should_skip_xfail_test_sample).
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    xfail(
        "addmm",  # xfail can't only use dtypes to catch all cases
        matcher=lambda sample: sample.input.dtype
        in (torch.uint8, torch.int8, torch.int16),
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "Add", "int8, int16, uint8"
        ),
    ),
    skip(
        "amax",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    skip(
        "amin",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    skip(
        "cat",
        matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
        reason="core dump - cat does not support zero-dim tensors yet",
    ),
    xfail(
        "index_put",
        matcher=lambda sample: (sample.args[0][0].dtype == torch.bool)
        and (sample.kwargs.get("accumulate") is False),
        reason=onnx_test_common.reason_dynamo_does_not_support(
            "https://github.com/pytorch/pytorch/issues/101150"
        ),
    ),
    xfail(
        "native_batch_norm",
        matcher=lambda sample: sample.args[4]
        and (
            isinstance(sample.args[0], torch.Tensor) and sample.args[0].shape == (1,)
        ),  # Edge case with training=True and mean being 1d tensor of single element.
        reason="AssertionError: The values for attribute 'shape' do not match: torch.Size([1]) != torch.Size([]).",
    ),
    xfail(
        "nn.functional.avg_pool1d",
        matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True)
        and (
            sample.kwargs.get("count_include_pad") is True
            or sample.input.shape[2]
            % (
                sample.args[0][0]
                if isinstance(sample.args[0], tuple)
                else sample.args[0]
            )
            != 0
        ),
        reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
    ),
    xfail(
        "nn.functional.avg_pool2d",
        # divisor_override may be passed positionally (args[5]) or as a kwarg.
        matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
        or (sample.kwargs.get("divisor_override") is not None),
        reason="ONNX doesn't support divisor_override argument",
    ),
    xfail(
        "nn.functional.avg_pool3d",
        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True,
        reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
    ),
    xfail(
        "nn.functional.avg_pool3d",
        matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
        or (sample.kwargs.get("divisor_override") is not None),
        reason="ONNX doesn't support divisor_override argument",
    ),
    skip(
        "nn.functional.conv1d",
        matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
        reason="String padding is not accepted by aten::conv1d",
    ),
    skip(
        "nn.functional.conv2d",
        matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
        reason="String padding is not accepted by aten::conv2d",
    ),
    skip(
        "nn.functional.cross_entropy",
        matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int),
        reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type",
    ),
    skip_torchlib_forward_compatibility(
        "nn.functional.embedding_bag",
        # Deliberately matches every sample: 'padding_idx=-1' is emitted even when
        # 'padding_idx' is not provided, so no sample can pass until torchlib adds
        # the overload (see the linked issue).
        matcher=lambda _: True,
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. "
            "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided"
        ),
        github_issue="https://github.com/microsoft/onnxscript/issues/1056",
    ),
    skip(
        "nn.functional.max_pool3d",
        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
        and sample.kwargs.get("padding") == 1,
        reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
    ),
    xfail(
        "nonzero",
        matcher=lambda sample: len(sample.input.shape) == 0
        and sample.kwargs.get("as_tuple", False) is False,
        reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
    ),
    xfail(
        "scatter_add",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch",
    ),
    skip(
        "scatter_reduce",
        # ONNX has not include_self parameter and default is include_self=True mode
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support include_self=False option",
    ),
    xfail(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
)
|
|
|
|
# END OF SECTION TO MODIFY #####################################################


# add_decorate_info is later applied to OPS_DB; the deep copy presumably keeps
# those decorations out of the shared common_methods_invocations.op_db.
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS and COMPLEX_TESTED_OPS are in the OPS_DB, so a
# misspelled op name fails loudly instead of silently testing nothing.
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
assert COMPLEX_TESTED_OPS.issubset(
    ALL_OPS_IN_DB
), f"{COMPLEX_TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
|
|
|
|
|
|
class SingleOpModel(torch.nn.Module):
    """Wrap one callable in a ``torch.nn.Module`` so it can be exported.

    Positional inputs arrive through ``forward``; the keyword arguments given at
    construction time are reused unchanged on every call.
    """

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        fn, fixed_kwargs = self.operator, self.kwargs
        return fn(*args, **fixed_kwargs)
|
|
|
|
|
|
def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Look up the sample-level skip/xfail rule for one sample of ``op_name``.

    Entries of SKIP_XFAIL_SUBTESTS are checked in order; the first entry for
    ``op_name`` whose matcher accepts ``sample`` wins.

    Returns:
        ``(test_behavior, reason)`` of the first matching entry, or
        ``(None, None)`` when no entry matches.

    Raises:
        AssertionError: If an entry for ``op_name`` has no matcher.
    """
    if op_name in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        # A linear scan is fine because SKIP_XFAIL_SUBTESTS is small.
        candidates = (meta for meta in SKIP_XFAIL_SUBTESTS if meta.op_name == op_name)
        for meta in candidates:
            assert meta.matcher is not None, "Matcher must be defined"
            if meta.matcher(sample):
                return meta.test_behavior, meta.reason
    return None, None
|
|
|
|
|
|
def _run_test_output_match(
    test_suite: onnx_test_common._TestONNXRuntime,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
):
    """Export ``op`` for each of its OpInfo samples and compare against eager.

    Every sample runs inside its own ``subTest``; any sample-level skip/xfail
    found by ``_should_skip_xfail_test_sample`` wraps the comparison.

    Args:
        test_suite: The running test case; provides ``opset_version``,
            ``fp16_low_precision_list`` and the export-and-compare helper.
        device: Device string from ``instantiate_device_type_tests``; only
            ``"cpu"`` is accepted.
        dtype: Dtype the OpInfo samples are generated in.
        op: The OpInfo under test.
    """
    # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
    assert device == "cpu"

    # Relax atol and rtol for float32 based on empirical results, and for float16
    # only on ops listed as low precision. None is passed through unchanged
    # (presumably meaning the helper's default tolerances).
    if dtype == torch.float32:
        rtol, atol = 1e-5, 2e-5
    elif dtype == torch.float16 and op.name in test_suite.fp16_low_precision_list:
        rtol, atol = 1e-2, 1e-3
    else:
        rtol, atol = None, None

    samples = op.sample_inputs(device, dtype, requires_grad=False)
    for sample_num, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs
        with test_suite.subTest(
            opset=test_suite.opset_version,
            sample_num=sample_num,
            inputs=repr(inputs),
            kwargs=repr(cpu_sample.kwargs),
        ):
            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample)
            with onnx_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                model = SingleOpModel(op.op, cpu_sample.kwargs)
                model.eval()
                test_suite.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
                    model, inputs, rtol=rtol, atol=atol
                )
|
|
|
|
|
|
def _get_test_class_name(cls, num, params_dict) -> str:
|
|
del cls # unused
|
|
del num # unused
|
|
return params_dict["name"]
|
|
|
|
|
|
@parameterized.parameterized_class(
    [
        dict(
            name=f"TestOnnxModelOutputConsistency_opset{opset}",
            opset_version=opset,
        )
        for opset in onnx_test_common.FX_TESTED_OPSETS
    ],
    class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Check that exported ONNX models reproduce PyTorch eager-mode outputs.

    This suite is parameterized over ``onnx_test_common.FX_TESTED_OPSETS``: one
    class named ``TestOnnxModelOutputConsistency_opset{N}`` with
    ``opset_version = N`` is generated per opset.
    """

    # Placeholder; each parameterized class supplies the real opset.
    opset_version = -1
    op_level_debug: bool = False
    dynamic_shapes: bool = False

    # Ops compared with relaxed float16 tolerances in _run_test_output_match.
    fp16_low_precision_list = [
        "nn.functional.batch_norm",
        "native_batch_norm",
        "dot",
    ]

    @common_device_type.ops(
        list(filter(lambda op_info: op_info.name in TESTED_OPS, OPS_DB)),
        allowed_dtypes=onnx_test_common.TESTED_DTYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter."""
        _run_test_output_match(self, device, dtype, op)

    @common_device_type.ops(
        list(filter(lambda op_info: op_info.name in COMPLEX_TESTED_OPS, OPS_DB)),
        allowed_dtypes=onnx_test_common.COMPLEX_TYPES,
    )
    def test_output_match_complex(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter with complex dtype."""
        _run_test_output_match(self, device, dtype, op)
|
|
|
|
|
|
for opset in onnx_test_common.FX_TESTED_OPSETS:
    # The name needs to match the parameterized_class name.
    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
    # The real-dtype and complex-dtype tests share the same expectations.
    for method_name in ("test_output_match", "test_output_match_complex"):
        onnx_test_common.add_decorate_info(
            OPS_DB,
            test_class_name,
            method_name,
            opset=opset,
            skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
        )
    # Instantiate the CPU-only device variants of this opset's class into the
    # module namespace.
    common_device_type.instantiate_device_type_tests(
        globals()[test_class_name], globals(), only_for="cpu"
    )
|
|
|
|
if __name__ == "__main__":
    # Entry point when the file is executed directly: hand off to PyTorch's
    # common test runner.
    common_utils.run_tests()
|