Move ONNX's TorchModelType to pytorch_test_common to fix circular dependency (#115353)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115353
Approved by: https://github.com/BowenBao
ghstack dependencies: #114407, #115281, #114762
Thiago Crepaldi 2023-12-07 18:54:10 +00:00 committed by PyTorch MergeBot
parent 13d2e3eba7
commit 960ad9d94e
4 changed files with 51 additions and 53 deletions
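
Why this fixes the circular dependency: onnx_test_common imports pytorch_test_common at module level, yet the decorators in pytorch_test_common needed TorchModelType and had to defer-import onnx_test_common inside their function bodies. Moving the enum into pytorch_test_common makes the dependency one-way. A minimal sketch of the layering after this change (illustrative, simplified from the diff below):

# pytorch_test_common.py -- now the bottom layer, owns the enum:
from enum import auto, Enum

class TorchModelType(Enum):
    TORCH_NN_MODULE = auto()
    TORCH_EXPORT_EXPORTEDPROGRAM = auto()

# onnx_test_common.py and the test files reference it one way only, e.g.:
#   import pytorch_test_common
#   if model_type == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
#       ...
# which lets the function-local `import onnx_test_common` workarounds in the
# xfail/skip decorators below be deleted.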

diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py

@@ -11,7 +11,6 @@ import logging
 import os
 import unittest
 import warnings
-from enum import auto, Enum
 from typing import (
     Any,
     Callable,
@@ -64,11 +63,6 @@
 pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
 pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")
 
 
-class TorchModelType(Enum):
-    TORCH_NN_MODULE = auto()
-    TORCH_EXPORT_EXPORTEDPROGRAM = auto()
-
-
 def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
     options = verification.VerificationOptions()
@@ -260,7 +254,8 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
         if (
             has_mutation
-            and self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+            and self.model_type
+            != pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
         ):
             ref_model = _try_clone_model(model)
             ref_input_args, ref_input_kwargs = _try_clone_inputs(
@@ -274,7 +269,10 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
         assert isinstance(ref_model, torch.nn.Module) or callable(
             ref_model
         ), "Model must be a torch.nn.Module or callable"
-        if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
+        if (
+            self.model_type
+            == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+        ):
             ref_model = torch.export.export(ref_model, args=ref_input_args)
         if (
             self.dynamic_shapes
@@ -533,7 +531,7 @@ class DecorateMeta:
     test_behavior: str
     matcher: Optional[Callable[[Any], bool]] = None
     enabled_if: bool = True
-    model_type: Optional[TorchModelType] = None
+    model_type: Optional[pytorch_test_common.TorchModelType] = None
 
     def contains_opset(self, opset: int) -> bool:
         if self.opsets is None:
@@ -553,7 +551,7 @@ def xfail(
     dtypes: Optional[Collection[torch.dtype]] = None,
     matcher: Optional[Callable[[Any], bool]] = None,
     enabled_if: bool = True,
-    model_type: Optional[TorchModelType] = None,
+    model_type: Optional[pytorch_test_common.TorchModelType] = None,
 ):
     """Expects a OpInfo test to fail.
@@ -591,7 +589,7 @@ def skip(
     dtypes: Optional[Collection[torch.dtype]] = None,
     matcher: Optional[Callable[[Any], Any]] = None,
     enabled_if: bool = True,
-    model_type: Optional[TorchModelType] = None,
+    model_type: Optional[pytorch_test_common.TorchModelType] = None,
 ):
     """Skips a test case in OpInfo that we don't care about.

diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py

@@ -6,6 +6,7 @@ import os
 import random
 import sys
 import unittest
+from enum import auto, Enum
 from typing import Optional
 
 import numpy as np
@@ -30,6 +31,11 @@ RNN_INPUT_SIZE = 5
 RNN_HIDDEN_SIZE = 3
 
 
+class TorchModelType(Enum):
+    TORCH_NN_MODULE = auto()
+    TORCH_EXPORT_EXPORTEDPROGRAM = auto()
+
+
 def _skipper(condition, reason):
     def decorator(f):
         @functools.wraps(f)
@@ -188,12 +194,12 @@ def skip_min_ort_version(reason: str, version: str, dynamic_only: bool = False):
     return skip_dec
 
 
-def skip_dynamic_fx_test(reason: str, skip_model_type=None):
+def skip_dynamic_fx_test(reason: str, skip_model_type: TorchModelType = None):
     """Skip dynamic exporting test.
 
     Args:
         reason: The reason for skipping dynamic exporting test.
-        skip_model_type (onnx_test_common.TorchModelType): The model type to skip dynamic exporting test for.
+        skip_model_type (TorchModelType): The model type to skip dynamic exporting test for.
             When None, model type is not used to skip dynamic tests.
 
     Returns:
@@ -344,15 +350,10 @@ def xfail_if_model_type_is_exportedprogram(reason: str):
     A decorator for xfail tests.
     """
-    import onnx_test_common
-
     def xfail_dec(func):
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
-            if (
-                self.model_type
-                == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
-            ):
+            if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                 pytest.xfail(
                     reason=f"Xfail model_type==torch.export.ExportedProgram. {reason}"
                 )
@@ -373,15 +374,10 @@ def xfail_if_model_type_is_not_exportedprogram(reason: str):
     A decorator for xfail tests.
     """
-    import onnx_test_common
-
     def xfail_dec(func):
         @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
-            if (
-                self.model_type
-                != onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
-            ):
+            if self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                 pytest.xfail(
                     reason=f"Xfail model_type!=torch.export.ExportedProgram. {reason}"
                 )

diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py

@@ -35,6 +35,7 @@ from typing import Any, Callable, Collection, Mapping, Optional, Tuple, Type, Union
 
 import onnx_test_common
 import parameterized
+import pytorch_test_common
 import torch
 from onnx_test_common import skip, xfail
@@ -561,7 +562,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
         "arange",
         matcher=lambda sample: not isinstance(sample.input, torch.Tensor),
         reason="torch.export.export does not support non-tensor input (https://github.com/pytorch/pytorch/issues/115110)",
-        model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     ),
     skip(
         "cat",
@@ -572,7 +573,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
         "full",
         matcher=lambda sample: not isinstance(sample.input, torch.Tensor),
         reason="torch.export.export does not support non-tensor input (https://github.com/pytorch/pytorch/issues/115110)",
-        model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     ),
     xfail(
         "index_put",
@@ -586,7 +587,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
         "native_batch_norm",
         matcher=lambda sample: sample.args[-3] is True
         and any(arg is not None for arg in sample.args[2:4]),
-        model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
         reason="https://github.com/pytorch/pytorch/issues/115106",
     ),
     xfail(
@@ -625,7 +626,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
         "nn.functional.batch_norm",
         matcher=lambda sample: sample.kwargs.get("training") is True
         and any(arg is not None for arg in sample.args[2:4]),
-        model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
         reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106",
     ),
     skip(
@@ -646,7 +647,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
     xfail(
         "nn.functional.embedding",
         matcher=lambda sample: sample.kwargs.get("max_norm") is not None,
-        model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
         reason="https://github.com/pytorch/pytorch/issues/115106",
     ),
     skip_torchlib_forward_compatibility(
@@ -669,11 +670,11 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
         matcher=lambda sample: len(sample.input.shape) == 0
         and sample.kwargs.get("as_tuple", False) is False,
         reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
-        model_type=onnx_test_common.TorchModelType.TORCH_NN_MODULE,
+        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
     ),
     xfail(
         "nonzero",
-        model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
         reason=onnx_test_common.reason_onnx_script_does_not_support(
             "aten::_assert_async.msg",
             "https://github.com/pytorch/pytorch/issues/112443",
@@ -720,7 +721,7 @@ class SingleOpModel(torch.nn.Module):
 
 
 def _should_skip_xfail_test_sample(
-    op_name: str, sample, model_type: onnx_test_common.TorchModelType
+    op_name: str, sample, model_type: pytorch_test_common.TorchModelType
 ) -> Tuple[Optional[str], Optional[str]]:
     """Returns a reason if a test sample should be skipped."""
     if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
@@ -802,8 +803,8 @@ def _parameterized_class_attrs_and_values():
         itertools.product(
             (opset for opset in onnx_test_common.FX_TESTED_OPSETS),
             (
-                onnx_test_common.TorchModelType.TORCH_NN_MODULE,
-                onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+                pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
+                pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
             ),
         )
     )
@@ -838,8 +839,8 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
     opset_version = -1
     op_level_debug: bool = False
     dynamic_shapes: bool = False
-    model_type: onnx_test_common.TorchModelType = (
-        onnx_test_common.TorchModelType.TORCH_NN_MODULE
+    model_type: pytorch_test_common.TorchModelType = (
+        pytorch_test_common.TorchModelType.TORCH_NN_MODULE
     )
 
     fp16_low_precision_list = [
@@ -872,7 +873,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
 # TODO(titaiwang): refactor this
 # https://github.com/pytorch/pytorch/issues/105338
 for opset in onnx_test_common.FX_TESTED_OPSETS:
-    for model_type in onnx_test_common.TorchModelType:
+    for model_type in pytorch_test_common.TorchModelType:
         # The name needs to match the parameterized_class name.
         test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}"
         onnx_test_common.add_decorate_info(

diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py

@@ -46,8 +46,8 @@ def _parameterized_class_attrs_and_values():
             (True, False),
             (True, False),
             (
-                onnx_test_common.TorchModelType.TORCH_NN_MODULE,
-                onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+                pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
+                pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
             ),
         )
     )
@@ -76,7 +76,7 @@ def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
 class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     op_level_debug: bool
     dynamic_shapes: bool
-    model_type: onnx_test_common.TorchModelType
+    model_type: pytorch_test_common.TorchModelType
 
     def setUp(self):
         super().setUp()
@@ -919,8 +919,8 @@ def _parameterized_class_attrs_and_values_with_fake_options():
             (True, False),
             (True, False),
             (
-                onnx_test_common.TorchModelType.TORCH_NN_MODULE,
-                onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+                pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
+                pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
             ),
         )
     )
@@ -950,7 +950,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     dynamic_shapes: bool
     load_checkpoint_during_init: bool
     export_within_fake_mode: bool
-    model_type: onnx_test_common.TorchModelType
+    model_type: pytorch_test_common.TorchModelType
 
     def setUp(self):
         super().setUp()
@@ -965,7 +965,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         create_kwargs: Callable,
         load_checkpoint_during_init: bool,
         export_within_fake_mode: bool,
-        model_type: onnx_test_common.TorchModelType,
+        model_type: pytorch_test_common.TorchModelType,
     ):
         """Test helper for FakeTensorMode-enabled exporter.
@@ -995,7 +995,10 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         # Create the toy model with real weight.
         real_model = create_model()
         state_dict = real_model.state_dict()  # concrete (non-fake) state_dict
-        if model_type == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
+        if (
+            model_type
+            == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+        ):
             real_model = torch.export.export(
                 real_model, args=create_args(), kwargs=create_kwargs()
             )
@@ -1024,7 +1027,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         if export_within_fake_mode:
             if (
                 model_type
-                == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+                == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
             ):
                 fake_model = torch.export.export(
                     fake_model, args=fake_args, kwargs=fake_kwargs
@@ -1039,7 +1042,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         if not export_within_fake_mode:
             if (
                 model_type
-                == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
+                == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
             ):
                 fake_model = torch.export.export(
                     fake_model, args=fake_args, kwargs=fake_kwargs
@@ -1081,7 +1084,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
 
     @pytorch_test_common.skip_dynamic_fx_test(
         "AssertionError: Dynamic shape check failed for graph inputs",
-        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     )
     def test_fake_tensor_mode_simple(self):
         def create_model() -> nn.Module:
@@ -1120,7 +1123,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         )
 
     @pytorch_test_common.skip_dynamic_fx_test(
         "AssertionError: Dynamic shape check failed for graph inputs",
-        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     )
     def test_large_scale_exporter_with_tiny_gpt2(self):
         model_name = "sshleifer/tiny-gpt2"
@@ -1151,7 +1154,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
 
     @pytorch_test_common.skip_dynamic_fx_test(
         "AssertionError: Dynamic shape check failed for graph inputs",
-        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     )
     def test_large_scale_exporter_with_toy_mlp(self):
         class MLPModel(nn.Module):
@@ -1193,7 +1196,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
 
     @pytorch_test_common.skip_dynamic_fx_test(
         "AssertionError: Dynamic shape check failed for graph inputs",
-        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     )
     def test_fake_tensor_mode_huggingface_google_t5(self):
         config = transformers.T5Config(
@@ -1229,7 +1232,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
 
     @pytorch_test_common.skip_dynamic_fx_test(
         "AssertionError: Dynamic shape check failed for graph inputs",
-        skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
     )
     def test_fake_tensor_mode_huggingface_openai_whisper(self):
         config = transformers.WhisperConfig(