mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
1. Expand `additional_test_inputs` to include kwargs.
2. Revisit and update test statuses by adding ops.
3. Disable the dtype -1 assignment, which avoids potential bugs.
4. Expand inputs/outputs to accept built-in types. These are not dynamically captured by `dynamo.export` right now; they would be added as constant inputs to `op.targets`.
5. Move `run_test_with_fx_to_onnx_exporter_and_onnx_runtime` to `onnx_test_common.py`.

### 🤖 Generated by Copilot at 3c03579

### Summary
🛠️🧪🚀

This pull request improves the ONNX export support for scalar types and some ATen operators in PyTorch. It updates the test framework, the input and output adapters, the function dispatcher and the ONNX script generator to handle these cases. It also fixes or removes some failing or outdated tests.

> _We defy the limits of the ONNX script_
> _We export the models with scalar and copy_
> _We filter and convert the kwargs of dtype_
> _We run the tests with FX and docstring_

### Walkthrough
* Update the `_InputArgsType` type annotation and the `_run_test_with_fx_to_onnx_exporter_and_onnx_runtime` function signature and docstring to handle int, float and bool inputs for some ONNX operators ([link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL44-R46), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL144-R157), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL155-R164),
[link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL162-R172), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL201-R224), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L197-R199), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L291-R293))
* Update the `filter_incompatible_and_dtype_convert_kwargs` function to omit the `dtype` argument if it is None ([link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-cabc3e58713d6fe7ab764ade4f2692f6753402322a7b542397cad16fcc72cf4bL203-R205))
* Update the test cases in `test_fx_to_onnx_with_onnxruntime.py` to use the `input_kwargs` parameter as a mapping, to fix the format of the `additional_test_inputs` parameter, and to add or remove `xfail`, `skip_dynamic_fx_test` and `skip_min_ort_version` decorators as needed ([link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL320-R336), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL330-R353), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL357-R380), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL452-L470), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL488-R486),
[link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL509-R510), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbR543), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL559-R565), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL578-R580), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL597-R599), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL611-R620), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL636-R636), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL656-R659), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL672-R675), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL691-R698), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL709-R714), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL732-R730), 
[link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL752-R750), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL773-R771), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbR797-R803), [link](https://github.com/pytorch/pytorch/pull/99434/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL807-R816))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99434
Approved by: https://github.com/justinchuby
353 lines
12 KiB
Python
# Owner(s): ["module: onnx"]

from __future__ import annotations

import copy

import dataclasses
import io
import os
import warnings
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union

import numpy as np

import onnxruntime
import pytorch_test_common
import torch
from torch.onnx import _constants, verification
from torch.onnx._internal import _beartype
from torch.types import Number

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable]
_InputArgsType = Optional[
    Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]]
]
_OutputsType = Sequence[_NumericType]

onnx_model_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    os.pardir,
    "repos",
    "onnx",
    "onnx",
    "backend",
    "test",
    "data",
)


pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")


pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")


def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    options = verification.VerificationOptions()

    kwargs["opset_version"] = test_suite.opset_version
    kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
    if hasattr(test_suite, "check_shape"):
        options.check_shape = test_suite.check_shape
    if hasattr(test_suite, "check_dtype"):
        options.check_dtype = test_suite.check_dtype

    # Keyword arguments that name a VerificationOptions field are moved onto
    # `options`; the remaining ones are passed through to `verification.verify`.
    # Keys are collected first and popped afterwards because a dict cannot
    # change size while it is being iterated.
    names = {f.name for f in dataclasses.fields(options)}
    keywords_to_pop = []
    for k, v in kwargs.items():
        if k in names:
            setattr(options, k, v)
            keywords_to_pop.append(k)
    for k in keywords_to_pop:
        kwargs.pop(k)

    return verification.verify(*args, options=options, **kwargs)


def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Combine class name with the parameterized arguments.

    This function is passed to `parameterized.parameterized_class` as the
    `class_name_func` argument.
    """
    suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items())
    return f"{cls.__name__}_{suffix}"


class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
    opset_version = _constants.ONNX_DEFAULT_OPSET
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    is_script = False
    check_shape = True
    check_dtype = True

    def setUp(self):
        super().setUp()
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        self.is_script_test_enabled = True

    # The exported ONNX model may have fewer inputs than the PyTorch model because of
    # constant folding. This mostly happens in unit tests, where torch.size or
    # torch.shape is widely used, so the output depends only on the input shape, not
    # on its values. `remained_onnx_input_idx` indicates which PyTorch model input
    # indices remain in the ONNX model.
    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                verbose=verbose,
            )

        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            _run_test(
                script_model,
                scripting_remained_onnx_input_idx,
                flatten=False,
                ignore_none=False,
            )
        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)

    @_beartype.beartype
    def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
        self,
        model: _ModelType,
        input_args: Sequence[_InputArgsType],
        input_kwargs: Optional[Mapping[str, _InputArgsType]] = None,
        rtol: float = 1e-3,
        atol: float = 1e-7,
        opset_version: int = 18,
        has_mutation: bool = False,
        additional_test_inputs: Optional[
            List[
                Union[
                    Tuple[Sequence[_InputArgsType], Mapping[str, _InputArgsType]],
                    Tuple[Sequence[_InputArgsType]],
                ]
            ]
        ] = None,
    ):
        """Compare the results of a PyTorch model with those of its exported ONNX model.

        Args:
            model (_ModelType): PyTorch model.
            input_args (Sequence[_InputArgsType]): Positional inputs to the model.
            input_kwargs (Mapping[str, _InputArgsType], optional): Keyword inputs to
                the model. Defaults to None.
            rtol (float, optional): Relative tolerance. Defaults to 1e-3.
            atol (float, optional): Absolute tolerance. Defaults to 1e-7.
            opset_version (int, optional): ONNX opset version. Defaults to 18.
            has_mutation (bool, optional): Whether the model mutates its inputs or
                state. Setting `has_mutation` to `True` incurs the extra overhead of
                cloning the inputs and the model. Defaults to False.
            additional_test_inputs (optional): Extra input sets used to test the
                exported model with different input shapes, intended for dynamic-axes
                testing; they are only used when `self.dynamic_shapes` is True. Each
                entry is a tuple whose first element is a tuple of args and whose
                optional second element is a dict of kwargs. Keep the trailing comma
                when the kwargs dict is omitted, so that the entry is still a tuple.
                Defaults to None.
                For example,
                additional_test_inputs = [((args1, args2), {"kwargs":1}), ((args1,),), ((), {"kwargs":1})]

        """

        # Avoid a mutable default argument.
        if input_kwargs is None:
            input_kwargs = {}

        if has_mutation:
            ref_model = _try_clone_model(model)
            ref_input_args, ref_input_kwargs = _try_clone_inputs(
                input_args, input_kwargs
            )
        else:
            ref_model = model
            ref_input_args = input_args
            ref_input_kwargs = input_kwargs

        # Feed args and kwargs into the exporter.
        # Note that the exporter should flatten kwargs into positional args of the
        # exported model, since ONNX does not represent kwargs.
        export_output = torch.onnx.dynamo_export(
            ref_model,
            *ref_input_args,
            **ref_input_kwargs,
            export_options=torch.onnx.ExportOptions(
                opset_version=opset_version,
                op_level_debug=self.op_level_debug,
                dynamic_shapes=self.dynamic_shapes,
            ),
        )

        _compare_pytorch_onnx_with_ort(
            export_output,
            model,
            input_args,
            input_kwargs,
            atol,
            rtol,
            has_mutation=has_mutation,
        )
        # This confirms the exported model accepts different input shapes
        # when dynamic shape is enabled.
        if additional_test_inputs and self.dynamic_shapes:
            for another_input in additional_test_inputs:
                if len(another_input) > 2:
                    raise ValueError(
                        f"Each entry of additional_test_inputs should contain only tuple args and optional dict kwargs, but received {len(another_input)} elements."
                    )
                additional_input_args = another_input[0]
                additional_input_kwargs = (
                    another_input[1]
                    if len(another_input) == 2 and another_input[1] is not None
                    else {}
                )
                _compare_pytorch_onnx_with_ort(
                    export_output,
                    model,
                    additional_input_args,
                    additional_input_kwargs,
                    atol,
                    rtol,
                    has_mutation=has_mutation,
                )


@_beartype.beartype
def run_ort(
    onnx_model: Union[str, torch.onnx.ExportOutput],
    pytorch_inputs: Sequence[_InputArgsType],
) -> _OutputsType:
    """Run ORT on the given ONNX model and inputs.

    Used in test_fx_to_onnx_with_onnxruntime.py.

    Args:
        onnx_model (Union[str, torch.onnx.ExportOutput]): Converted ONNX model.
        pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs.

    Raises:
        AssertionError: If the ONNX model expects a different number of inputs
            than were given.

    Returns:
        _OutputsType: ONNX model predictions.
    """
    if isinstance(onnx_model, torch.onnx.ExportOutput):
        buffer = io.BytesIO()
        onnx_model.save(buffer)
        ort_model = buffer.getvalue()
    else:
        ort_model = onnx_model
    session = onnxruntime.InferenceSession(
        ort_model, providers=["CPUExecutionProvider"]
    )
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    if len(input_names) != len(pytorch_inputs):
        raise AssertionError(
            f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
        )
    return session.run(
        None, {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
    )


@_beartype.beartype
def _try_clone_model(model: _ModelType) -> _ModelType:
    """Preserve the original model in case `forward` mutates the model state."""
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


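@_beartype.beartype
def _try_clone_inputs(input_args, input_kwargs):
    """Deep-copy the inputs so that a mutating model cannot alter the originals.

    Unlike `_try_clone_model`, errors raised by `copy.deepcopy` are not caught.
    """
    ref_input_args = copy.deepcopy(input_args)
    ref_input_kwargs = copy.deepcopy(input_kwargs)
    return ref_input_args, ref_input_kwargs

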
@_beartype.beartype
def _compare_pytorch_onnx_with_ort(
    export_output: torch.onnx.ExportOutput,
    model: _ModelType,
    input_args: Sequence[_InputArgsType],
    input_kwargs: Mapping[str, _InputArgsType],
    atol: float,
    rtol: float,
    has_mutation: bool = False,
):
    """Run the PyTorch model and the exported model (through ONNX Runtime) on the
    same inputs and assert that their outputs match within `rtol` and `atol`."""
    if has_mutation:
        ref_model = _try_clone_model(model)
        ref_input_args, ref_input_kwargs = _try_clone_inputs(input_args, input_kwargs)
    else:
        ref_model = model
        ref_input_args = input_args
        ref_input_kwargs = input_kwargs

    # Convert the original model inputs into the format expected by the exported ONNX model.
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(
        *input_args, **input_kwargs
    )

    ref_outputs = export_output.adapt_torch_outputs_to_onnx(
        ref_model(*ref_input_args, **ref_input_kwargs)
    )
    ort_outputs = run_ort(export_output, onnx_format_args)
    if len(ref_outputs) != len(ort_outputs):
        raise AssertionError(
            f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
        )
    for ref_output, ort_output in zip(ref_outputs, ort_outputs):
        torch.testing.assert_close(
            ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
        )
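`run_model_test` routes its keyword arguments in two directions. Keys that match a field of `verification.VerificationOptions` are set on the options object, and everything else is forwarded to `verification.verify`. The sketch below reproduces that routing using only the standard library. `_Options` and its four fields are hypothetical stand-ins, not the real `VerificationOptions` definition, and `route_kwargs` is likewise an illustrative name.

```python
import dataclasses
from typing import Any, Dict, Tuple


@dataclasses.dataclass
class _Options:
    # Hypothetical stand-in for verification.VerificationOptions.
    check_shape: bool = True
    check_dtype: bool = True
    rtol: float = 1e-3
    atol: float = 1e-7


def route_kwargs(kwargs: Dict[str, Any]) -> Tuple[_Options, Dict[str, Any]]:
    """Move kwargs that name an _Options field onto an _Options instance.

    Returns the populated options and the remaining pass-through kwargs.
    The caller's dict is left unmodified.
    """
    options = _Options()
    field_names = {f.name for f in dataclasses.fields(options)}
    remaining: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in field_names:
            setattr(options, key, value)  # option for the verifier
        else:
            remaining[key] = value  # forwarded unchanged, e.g. to the exporter
    return options, remaining


options, export_kwargs = route_kwargs(
    {"rtol": 1e-2, "opset_version": 18, "check_dtype": False}
)
print(options)  # _Options(check_shape=True, check_dtype=False, rtol=0.01, atol=1e-07)
print(export_kwargs)  # {'opset_version': 18}
```

Building a separate dict leaves the caller's mapping untouched. `run_model_test` takes a different approach: it edits `kwargs` in place, first collecting the matching keys and then popping them, because a dict cannot change size while it is being iterated.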