mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
- Rename `test_pytorch_common` -> `pytorch_test_common`, `test_onnx_common` -> `onnx_test_common`, removing the test_ prefix to show that the files are not test cases - Remove import * in `test_pytorch_common` and adjust to import from `testing._internal.common_utils` (where functions are actually defined) instead - Import modules only in `test_pytorch_onnx_onnxruntime` (too many to handle in a single PR in other tests) (The skips are exceptions) Pull Request resolved: https://github.com/pytorch/pytorch/pull/81141 Approved by: https://github.com/BowenBao
130 lines
4.2 KiB
Python
130 lines
4.2 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import random
|
|
from typing import Any, Mapping, Type
|
|
|
|
import numpy as np
|
|
import onnxruntime
|
|
|
|
import torch
|
|
from torch.onnx import _constants, verification
|
|
from torch.testing._internal import common_utils
|
|
|
|
# Directory of ONNX backend test data, resolved from the parent of this
# file's directory into ``repos/onnx/onnx/backend/test/data``.
# NOTE(review): presumably an ONNX source checkout placed there by CI/setup;
# nothing in this module creates it.
onnx_model_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), os.pardir,
    "repos", "onnx", "onnx", "backend", "test", "data",
)

# Sub-directories of the backend test data holding PyTorch-derived models.
pytorch_converted_dir, pytorch_operator_dir = (
    os.path.join(onnx_model_dir, subdir)
    for subdir in ("pytorch-converted", "pytorch-operator")
)

# ONNX Runtime execution providers forced onto every run_model_test call.
_ORT_PROVIDERS = ("CPUExecutionProvider",)
|
|
|
|
|
|
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    """Call ``verification.verify`` with the export settings of ``test_suite``.

    ``ort_providers`` is always ``_ORT_PROVIDERS``; ``opset_version`` and
    ``keep_initializers_as_inputs`` are read from ``test_suite``. These three
    values override any same-named entries the caller put in ``kwargs``; all
    other positional and keyword arguments are forwarded untouched.

    Returns:
        Whatever ``verification.verify`` returns.
    """
    kwargs.update(
        ort_providers=_ORT_PROVIDERS,
        opset_version=test_suite.opset_version,
        keep_initializers_as_inputs=test_suite.keep_initializers_as_inputs,
    )
    return verification.verify(*args, **kwargs)
|
|
|
|
|
|
def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Build the name of a parameterized test class.

    Intended as the ``class_name_func`` hook of
    ``parameterized.parameterized_class``. Every ``key: value`` pair of
    ``input_dicts`` becomes ``"<key>_<value>"``; the pieces are joined with
    underscores and appended to ``cls.__name__`` after a separating
    underscore. ``idx`` belongs to the hook signature and is ignored.

    Returns:
        e.g. ``"Cls_opset_version_9_is_script_True"``; an empty
        ``input_dicts`` yields ``"Cls_"``.
    """
    pieces = [f"{key}_{value}" for key, value in input_dicts.items()]
    return cls.__name__ + "_" + "_".join(pieces)
|
|
|
|
|
|
def set_rng_seed(seed):
    """Seed the torch, Python ``random`` and NumPy global RNGs with ``seed``.

    The generators are seeded in that order; nothing is returned.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
|
|
|
|
|
|
class _TestONNXRuntime(common_utils.TestCase):
    """Base test case that exports PyTorch models to ONNX and checks them with ONNX Runtime.

    Concrete suites choose the export configuration through the class
    attributes below and call ``run_test`` from their test methods.
    """

    # Opset forwarded to the exporter by ``run_model_test``.
    opset_version = _constants.onnx_default_opset
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    # Selects the export path in ``run_test``: scripting when True, tracing when False.
    is_script = False

    def setUp(self):
        """Reset RNG seeds and export-related state before each test."""
        # Run the base TestCase per-test setup first; it was previously
        # skipped entirely by this override. (Presumably it performs
        # enable/skip checks and its own seeding -- confirm against
        # common_utils.) Seeding afterwards keeps seed 0 authoritative.
        super().setUp()
        set_rng_seed(0)
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        # NOTE(review): mutates the process environment and is never restored;
        # presumably lets the exporter use opsets beyond released ones.
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        # Read by run_test; a test may set it to False to skip the scripting export.
        self.is_script_test_enabled = True

    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        """Export ``model`` to ONNX and verify it via ``run_model_test``.

        The export path depends on ``self.is_script``:

        * Scripting suites (``is_script`` True) run only while
          ``self.is_script_test_enabled`` is True. ``model`` is passed through
          ``torch.jit.script`` unless it already is a ``ScriptModule`` or
          ``ScriptFunction``, and verification uses ``flatten=False``.
        * Tracing suites (``is_script`` False) verify ``model`` as-is with
          ``flatten=True``. If ``model`` is already scripted, nothing is
          run and the call returns silently.

        Args:
            remained_onnx_input_idx: The exported ONNX model may have fewer
                inputs than the PyTorch model because of constant folding.
                This mostly happens in unit tests that rely on
                ``torch.size``/``torch.shape``, where outputs depend only on
                input shapes, not values. This argument lists which PyTorch
                input indices remain in the ONNX model. Pass either a single
                value used for both paths, or a dict with ``"scripting"``
                and ``"tracing"`` keys.
            All other arguments are forwarded unchanged to ``run_model_test``.
        """

        def _run_test(m, remained_onnx_input_idx, flatten=True):
            # Shared call into run_model_test; only the model, the remaining
            # input indices and ``flatten`` differ between the two paths.
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                verbose=verbose,
            )

        # Scripting and tracing may fold away different inputs, so a dict
        # lets the caller specify each path separately.
        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            _run_test(script_model, scripting_remained_onnx_input_idx, flatten=False)

        # Already-scripted models cannot be traced, so tracing suites skip them.
        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)
|