mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix #78119. Why: the ONNX test verification code used to consider only the tracing output, which ignores None types; this PR enables runtime tests to keep None types from torch in script mode. 1. Move Optional type tests from no-runtime to runtime, since they are supported by ONNX Runtime. 2. Add an `ignore_none` flag for output comparison in internal tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83184 Approved by: https://github.com/justinchuby, https://github.com/BowenBao
141 lines
4.5 KiB
Python
141 lines
4.5 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import random
|
|
from typing import Any, Mapping, Type
|
|
|
|
import numpy as np
|
|
import onnxruntime
|
|
|
|
import torch
|
|
from torch.onnx import _constants, verification
|
|
from torch.testing._internal import common_utils
|
|
|
|
# Root of the ONNX backend test data, resolved relative to the directory
# containing this file: <this dir>/../repos/onnx/onnx/backend/test/data.
onnx_model_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    *(os.pardir, "repos", "onnx", "onnx", "backend", "test", "data"),
)

# Two sub-directories of the backend test data root.
pytorch_converted_dir, pytorch_operator_dir = (
    os.path.join(onnx_model_dir, subdir)
    for subdir in ("pytorch-converted", "pytorch-operator")
)

# Execution providers handed to ONNX Runtime by run_model_test.
_ORT_PROVIDERS = ("CPUExecutionProvider",)
|
|
|
|
|
|
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    """Run `verification.verify` using the export settings of `test_suite`.

    The ORT providers, opset version and `keep_initializers_as_inputs` are
    always taken from the suite (overriding any caller-supplied values).
    `check_shape` / `check_dtype` are forwarded only when the suite defines them.
    All other positional and keyword arguments pass straight through.

    Returns:
        Whatever `verification.verify` returns.
    """
    kwargs.update(
        ort_providers=_ORT_PROVIDERS,
        opset_version=test_suite.opset_version,
        keep_initializers_as_inputs=test_suite.keep_initializers_as_inputs,
    )
    # Optional comparison switches: only forwarded when the suite has them.
    for option in ("check_shape", "check_dtype"):
        if hasattr(test_suite, option):
            kwargs[option] = getattr(test_suite, option)
    return verification.verify(*args, **kwargs)
|
|
|
|
|
|
def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Build a parameterized class name from the base class and its arguments.

    Used as the `class_name_func` argument of
    `parameterized.parameterized_class`. The result is
    `<ClassName>_<key1>_<value1>_<key2>_<value2>...`, following the mapping's
    iteration order. `idx` is part of the required callback signature but is
    not used.
    """
    pieces = [cls.__name__]
    pieces.append("_".join(f"{name}_{value}" for name, value in input_dicts.items()))
    return "_".join(pieces)
|
|
|
|
|
|
def set_rng_seed(seed):
    """Seed the torch, Python `random` and NumPy global RNGs with `seed`."""
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
|
|
|
|
|
|
class _TestONNXRuntime(common_utils.TestCase):
    """Base test case that exports models to ONNX and checks them with ONNX Runtime.

    Subclasses configure the export through the class attributes below.
    `run_test` exports the model via scripting or tracing, depending on
    `is_script`, and delegates the comparison to `run_model_test`.
    """

    # Opset version forwarded to the verifier by run_model_test.
    opset_version = _constants.onnx_default_opset
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    # When True, run_test exports through torch.jit.script; otherwise it traces.
    is_script = False
    # Output comparison switches forwarded to the verifier by run_model_test.
    check_shape = True
    check_dtype = True

    def setUp(self):
        # Run the base TestCase setup first (it was previously skipped), so the
        # deterministic seeds below are applied last and take effect.
        super().setUp()
        set_rng_seed(0)
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        # NOTE(review): this environment variable is set process-wide and is not
        # restored after the test finishes.
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        # Setting this to False makes run_test skip the scripting export path.
        self.is_script_test_enabled = True

    # The exported ONNX model may have fewer inputs than the PyTorch model because of
    # constant folding. This mostly happens in unit tests, where torch.size or
    # torch.shape is widely used, so the output depends only on the input shape, not
    # its value. remained_onnx_input_idx indicates which PyTorch model input indices
    # remain in the ONNX model.
    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        """Export `model` to ONNX and compare ONNX Runtime results with PyTorch.

        All arguments except `model` and `remained_onnx_input_idx` are forwarded
        unchanged to `run_model_test` (and from there to `verification.verify`).

        Args:
            model: The model to export. It may already be a `torch.jit.ScriptModule`
                or `torch.jit.ScriptFunction`.
            input_args: Positional inputs for the model.
            input_kwargs: Keyword inputs for the model, if any.
            rtol: Relative tolerance for the output comparison.
            atol: Absolute tolerance for the output comparison.
            remained_onnx_input_idx: Either a single value used for both export
                modes, or a dict with separate "scripting" and "tracing" entries.
            (remaining args): Export / verification options, forwarded as-is.

        Behavior by mode:
            - Scripting runs when `is_script_test_enabled` and `is_script` are
              both true. Outputs are compared unflattened and with None values
              kept (`flatten=False`, `ignore_none=False`).
            - Tracing runs when `is_script` is false and `model` is not already
              scripted, with the default `flatten=True, ignore_none=True`.
        """

        def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                verbose=verbose,
            )

        # A dict lets a test specify different remaining inputs per export mode.
        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            # Scripting preserves None outputs, so compare them unflattened.
            _run_test(
                script_model,
                scripting_remained_onnx_input_idx,
                flatten=False,
                ignore_none=False,
            )
        # NOTE(review): an already-scripted model with is_script=False runs neither
        # branch, so no check is performed in that case.
        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)
|