mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Move ONNX's TorchModelType to pytorch_test_common to fix circ. dep. (#115353)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115353 Approved by: https://github.com/BowenBao ghstack dependencies: #114407, #115281, #114762
This commit is contained in:
parent
13d2e3eba7
commit
960ad9d94e
|
|
@ -11,7 +11,6 @@ import logging
|
|||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from enum import auto, Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
|
|
@ -64,11 +63,6 @@ pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
|
|||
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")
|
||||
|
||||
|
||||
class TorchModelType(Enum):
|
||||
TORCH_NN_MODULE = auto()
|
||||
TORCH_EXPORT_EXPORTEDPROGRAM = auto()
|
||||
|
||||
|
||||
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
|
||||
options = verification.VerificationOptions()
|
||||
|
||||
|
|
@ -260,7 +254,8 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
|
|||
|
||||
if (
|
||||
has_mutation
|
||||
and self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
and self.model_type
|
||||
!= pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
ref_model = _try_clone_model(model)
|
||||
ref_input_args, ref_input_kwargs = _try_clone_inputs(
|
||||
|
|
@ -274,7 +269,10 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
|
|||
assert isinstance(ref_model, torch.nn.Module) or callable(
|
||||
ref_model
|
||||
), "Model must be a torch.nn.Module or callable"
|
||||
if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
|
||||
if (
|
||||
self.model_type
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
ref_model = torch.export.export(ref_model, args=ref_input_args)
|
||||
if (
|
||||
self.dynamic_shapes
|
||||
|
|
@ -533,7 +531,7 @@ class DecorateMeta:
|
|||
test_behavior: str
|
||||
matcher: Optional[Callable[[Any], bool]] = None
|
||||
enabled_if: bool = True
|
||||
model_type: Optional[TorchModelType] = None
|
||||
model_type: Optional[pytorch_test_common.TorchModelType] = None
|
||||
|
||||
def contains_opset(self, opset: int) -> bool:
|
||||
if self.opsets is None:
|
||||
|
|
@ -553,7 +551,7 @@ def xfail(
|
|||
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||
matcher: Optional[Callable[[Any], bool]] = None,
|
||||
enabled_if: bool = True,
|
||||
model_type: Optional[TorchModelType] = None,
|
||||
model_type: Optional[pytorch_test_common.TorchModelType] = None,
|
||||
):
|
||||
"""Expects a OpInfo test to fail.
|
||||
|
||||
|
|
@ -591,7 +589,7 @@ def skip(
|
|||
dtypes: Optional[Collection[torch.dtype]] = None,
|
||||
matcher: Optional[Callable[[Any], Any]] = None,
|
||||
enabled_if: bool = True,
|
||||
model_type: Optional[TorchModelType] = None,
|
||||
model_type: Optional[pytorch_test_common.TorchModelType] = None,
|
||||
):
|
||||
"""Skips a test case in OpInfo that we don't care about.
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import os
|
|||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from enum import auto, Enum
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -30,6 +31,11 @@ RNN_INPUT_SIZE = 5
|
|||
RNN_HIDDEN_SIZE = 3
|
||||
|
||||
|
||||
class TorchModelType(Enum):
|
||||
TORCH_NN_MODULE = auto()
|
||||
TORCH_EXPORT_EXPORTEDPROGRAM = auto()
|
||||
|
||||
|
||||
def _skipper(condition, reason):
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
|
|
@ -188,12 +194,12 @@ def skip_min_ort_version(reason: str, version: str, dynamic_only: bool = False):
|
|||
return skip_dec
|
||||
|
||||
|
||||
def skip_dynamic_fx_test(reason: str, skip_model_type=None):
|
||||
def skip_dynamic_fx_test(reason: str, skip_model_type: TorchModelType = None):
|
||||
"""Skip dynamic exporting test.
|
||||
|
||||
Args:
|
||||
reason: The reason for skipping dynamic exporting test.
|
||||
skip_model_type (onnx_test_common.TorchModelType): The model type to skip dynamic exporting test for.
|
||||
skip_model_type (TorchModelType): The model type to skip dynamic exporting test for.
|
||||
When None, model type is not used to skip dynamic tests.
|
||||
|
||||
Returns:
|
||||
|
|
@ -344,15 +350,10 @@ def xfail_if_model_type_is_exportedprogram(reason: str):
|
|||
A decorator for xfail tests.
|
||||
"""
|
||||
|
||||
import onnx_test_common
|
||||
|
||||
def xfail_dec(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if (
|
||||
self.model_type
|
||||
== onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
|
||||
pytest.xfail(
|
||||
reason=f"Xfail model_type==torch.export.ExportedProgram. {reason}"
|
||||
)
|
||||
|
|
@ -373,15 +374,10 @@ def xfail_if_model_type_is_not_exportedprogram(reason: str):
|
|||
A decorator for xfail tests.
|
||||
"""
|
||||
|
||||
import onnx_test_common
|
||||
|
||||
def xfail_dec(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if (
|
||||
self.model_type
|
||||
!= onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
if self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
|
||||
pytest.xfail(
|
||||
reason=f"Xfail model_type!=torch.export.ExportedProgram. {reason}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from typing import Any, Callable, Collection, Mapping, Optional, Tuple, Type, Un
|
|||
import onnx_test_common
|
||||
|
||||
import parameterized
|
||||
import pytorch_test_common
|
||||
|
||||
import torch
|
||||
from onnx_test_common import skip, xfail
|
||||
|
|
@ -561,7 +562,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
|||
"arange",
|
||||
matcher=lambda sample: not isinstance(sample.input, torch.Tensor),
|
||||
reason="torch.export.export does not support non-tensor input (https://github.com/pytorch/pytorch/issues/115110)",
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
skip(
|
||||
"cat",
|
||||
|
|
@ -572,7 +573,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
|||
"full",
|
||||
matcher=lambda sample: not isinstance(sample.input, torch.Tensor),
|
||||
reason="torch.export.export does not support non-tensor input (https://github.com/pytorch/pytorch/issues/115110)",
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
xfail(
|
||||
"index_put",
|
||||
|
|
@ -586,7 +587,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
|||
"native_batch_norm",
|
||||
matcher=lambda sample: sample.args[-3] is True
|
||||
and any(arg is not None for arg in sample.args[2:4]),
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
reason="https://github.com/pytorch/pytorch/issues/115106",
|
||||
),
|
||||
xfail(
|
||||
|
|
@ -625,7 +626,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
|||
"nn.functional.batch_norm",
|
||||
matcher=lambda sample: sample.kwargs.get("training") is True
|
||||
and any(arg is not None for arg in sample.args[2:4]),
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106",
|
||||
),
|
||||
skip(
|
||||
|
|
@ -646,7 +647,7 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
|||
xfail(
|
||||
"nn.functional.embedding",
|
||||
matcher=lambda sample: sample.kwargs.get("max_norm") is not None,
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
reason="https://github.com/pytorch/pytorch/issues/115106",
|
||||
),
|
||||
skip_torchlib_forward_compatibility(
|
||||
|
|
@ -669,11 +670,11 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
|||
matcher=lambda sample: len(sample.input.shape) == 0
|
||||
and sample.kwargs.get("as_tuple", False) is False,
|
||||
reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
),
|
||||
xfail(
|
||||
"nonzero",
|
||||
model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
reason=onnx_test_common.reason_onnx_script_does_not_support(
|
||||
"aten::_assert_async.msg",
|
||||
"https://github.com/pytorch/pytorch/issues/112443",
|
||||
|
|
@ -720,7 +721,7 @@ class SingleOpModel(torch.nn.Module):
|
|||
|
||||
|
||||
def _should_skip_xfail_test_sample(
|
||||
op_name: str, sample, model_type: onnx_test_common.TorchModelType
|
||||
op_name: str, sample, model_type: pytorch_test_common.TorchModelType
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Returns a reason if a test sample should be skipped."""
|
||||
if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
|
||||
|
|
@ -802,8 +803,8 @@ def _parameterized_class_attrs_and_values():
|
|||
itertools.product(
|
||||
(opset for opset in onnx_test_common.FX_TESTED_OPSETS),
|
||||
(
|
||||
onnx_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -838,8 +839,8 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
|||
opset_version = -1
|
||||
op_level_debug: bool = False
|
||||
dynamic_shapes: bool = False
|
||||
model_type: onnx_test_common.TorchModelType = (
|
||||
onnx_test_common.TorchModelType.TORCH_NN_MODULE
|
||||
model_type: pytorch_test_common.TorchModelType = (
|
||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE
|
||||
)
|
||||
|
||||
fp16_low_precision_list = [
|
||||
|
|
@ -872,7 +873,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
|||
# TODO(titaiwang): refactor this
|
||||
# https://github.com/pytorch/pytorch/issues/105338
|
||||
for opset in onnx_test_common.FX_TESTED_OPSETS:
|
||||
for model_type in onnx_test_common.TorchModelType:
|
||||
for model_type in pytorch_test_common.TorchModelType:
|
||||
# The name needs to match the parameterized_class name.
|
||||
test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}"
|
||||
onnx_test_common.add_decorate_info(
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ def _parameterized_class_attrs_and_values():
|
|||
(True, False),
|
||||
(True, False),
|
||||
(
|
||||
onnx_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -76,7 +76,7 @@ def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]
|
|||
class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
op_level_debug: bool
|
||||
dynamic_shapes: bool
|
||||
model_type: onnx_test_common.TorchModelType
|
||||
model_type: pytorch_test_common.TorchModelType
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
|
@ -919,8 +919,8 @@ def _parameterized_class_attrs_and_values_with_fake_options():
|
|||
(True, False),
|
||||
(True, False),
|
||||
(
|
||||
onnx_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -950,7 +950,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
dynamic_shapes: bool
|
||||
load_checkpoint_during_init: bool
|
||||
export_within_fake_mode: bool
|
||||
model_type: onnx_test_common.TorchModelType
|
||||
model_type: pytorch_test_common.TorchModelType
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
|
@ -965,7 +965,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
create_kwargs: Callable,
|
||||
load_checkpoint_during_init: bool,
|
||||
export_within_fake_mode: bool,
|
||||
model_type: onnx_test_common.TorchModelType,
|
||||
model_type: pytorch_test_common.TorchModelType,
|
||||
):
|
||||
"""Test helper for FakeTensorMode-enabled exporter.
|
||||
|
||||
|
|
@ -995,7 +995,10 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
# Create the toy model with real weight.
|
||||
real_model = create_model()
|
||||
state_dict = real_model.state_dict() # concrete (non-fake) state_dict
|
||||
if model_type == onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
|
||||
if (
|
||||
model_type
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
real_model = torch.export.export(
|
||||
real_model, args=create_args(), kwargs=create_kwargs()
|
||||
)
|
||||
|
|
@ -1024,7 +1027,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
if export_within_fake_mode:
|
||||
if (
|
||||
model_type
|
||||
== onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
fake_model = torch.export.export(
|
||||
fake_model, args=fake_args, kwargs=fake_kwargs
|
||||
|
|
@ -1039,7 +1042,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
if not export_within_fake_mode:
|
||||
if (
|
||||
model_type
|
||||
== onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
fake_model = torch.export.export(
|
||||
fake_model, args=fake_args, kwargs=fake_kwargs
|
||||
|
|
@ -1081,7 +1084,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
"AssertionError: Dynamic shape check failed for graph inputs",
|
||||
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_fake_tensor_mode_simple(self):
|
||||
def create_model() -> nn.Module:
|
||||
|
|
@ -1120,7 +1123,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
)
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
"AssertionError: Dynamic shape check failed for graph inputs",
|
||||
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_large_scale_exporter_with_tiny_gpt2(self):
|
||||
model_name = "sshleifer/tiny-gpt2"
|
||||
|
|
@ -1151,7 +1154,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
"AssertionError: Dynamic shape check failed for graph inputs",
|
||||
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_large_scale_exporter_with_toy_mlp(self):
|
||||
class MLPModel(nn.Module):
|
||||
|
|
@ -1193,7 +1196,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
"AssertionError: Dynamic shape check failed for graph inputs",
|
||||
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_fake_tensor_mode_huggingface_google_t5(self):
|
||||
config = transformers.T5Config(
|
||||
|
|
@ -1229,7 +1232,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
|||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
"AssertionError: Dynamic shape check failed for graph inputs",
|
||||
skip_model_type=onnx_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
skip_model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_fake_tensor_mode_huggingface_openai_whisper(self):
|
||||
config = transformers.WhisperConfig(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user