Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Enable runtime type checking for all torch.onnx public APIs, symbolic functions, and most helpers (minus two that do not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to make unit tests green.

Profile: export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking:    21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes
+ 2.48%
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
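As a hedged sketch of the pattern this PR applies (the helper name below is hypothetical; `_beartype.beartype` validates call arguments against the annotations at runtime and flags mismatches, raising or warning depending on configuration):

```
from torch.onnx._internal import _beartype

@_beartype.beartype
def _example_helper(opset_version: int) -> int:  # hypothetical helper
    return opset_version

_example_helper("13")  # flagged at call time: str is not an int
```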
730 lines · 26 KiB · Python
"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.
|
|
|
|
ONNX Runtime is required, and is used as the ONNX backend for export verification.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import copy
|
|
import difflib
|
|
import io
|
|
import itertools
|
|
import os
|
|
import tempfile
|
|
import warnings
|
|
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch._C._onnx as _C_onnx
|
|
from torch import _C
|
|
from torch.onnx import _constants, _experimental, utils
|
|
from torch.onnx._globals import GLOBALS
|
|
from torch.onnx._internal import _beartype
|
|
from torch.types import Number
|
|
|
|
_ORT_PROVIDERS = ("CPUExecutionProvider",)
|
|
|
|
_NumericType = Union[Number, torch.Tensor, np.ndarray]
|
|
|
|
|
|
@_beartype.beartype
|
|
def _flatten_tuples(elem):
|
|
flattened = []
|
|
for t in elem:
|
|
if isinstance(t, tuple):
|
|
flattened.extend(_flatten_tuples(t))
|
|
else:
|
|
flattened.append(t)
|
|
return flattened
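
# A sketch of the flattening behavior above (hypothetical values): nested
# tuples are expanded recursively, other elements pass through unchanged.
#
#     _flatten_tuples([(1, (2, 3)), 4])  # -> [1, 2, 3, 4]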


# TODO(justinchuby): Add type checking by narrowing down the return type when input is None
def _to_numpy(elem) -> Union[list, np.ndarray]:
    if isinstance(elem, torch.Tensor):
        if elem.requires_grad:
            return elem.detach().cpu().numpy()
        else:
            return elem.cpu().numpy()
    elif isinstance(elem, (list, tuple)):
        return [_to_numpy(inp) for inp in elem]
    elif isinstance(elem, (bool, int, float)):
        return np.array(elem)
    elif isinstance(elem, dict):
        flattened = []
        for k in elem:
            flattened.extend([_to_numpy(k), _to_numpy(elem[k])])
        return flattened
    return elem
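
# A sketch of the conversion rules above (hypothetical values; dict entries
# are flattened into an interleaved [key, value, ...] list, and anything
# unrecognized passes through unchanged):
#
#     _to_numpy(torch.tensor([1.0]))     # -> np.ndarray of shape (1,)
#     _to_numpy([torch.tensor(1), 2.0])  # -> [array(1), array(2.0)]
#     _to_numpy({"a": torch.tensor(3)})  # -> ["a", array(3)]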


@_beartype.beartype
def _inline_flatten_list(inputs, res_list) -> list:
    for i in inputs:
        res_list.append(i) if not isinstance(
            i, (list, tuple)
        ) else _inline_flatten_list(i, res_list)
    return res_list


@_beartype.beartype
def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
    value_unpacked = []
    for value in values:
        value_unpacked.extend(
            utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted)
        )
    return [_to_numpy(v) for v in value_unpacked]


@_beartype.beartype
def _run_ort(ort_session, inputs):
    kw_inputs = {}
    if inputs and isinstance(inputs[-1], dict):
        kw_inputs = inputs[-1]
        inputs = inputs[:-1]
    inputs = _unpack_to_numpy(_flatten_tuples(inputs))
    ort_inputs = {}
    for input_name, input in kw_inputs.items():
        ort_inputs[input_name] = _to_numpy(input)
    inputs = _to_numpy(inputs)
    ort_session_inputs = ort_session.get_inputs()
    for i, input in enumerate(inputs):
        if i == len(ort_session_inputs) or ort_session_inputs[i].name in ort_inputs:
            raise ValueError(
                f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}"
            )
        ort_inputs[ort_session_inputs[i].name] = input
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs
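
# Usage sketch (hypothetical session and input name "y"): a trailing dict is
# treated as keyword inputs; everything else is matched positionally against
# ort_session.get_inputs().
#
#     outs = _run_ort(session, (torch.randn(2, 3), {"y": torch.randn(2, 3)}))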


@_beartype.beartype
def _ort_session(
    model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
):
    try:
        import onnxruntime  # type: ignore[import]
    except ImportError:
        raise ImportError("onnxruntime is required for export verification.")

    if ort_providers is None:
        ort_providers = _ORT_PROVIDERS

    session_options = onnxruntime.SessionOptions()
    # Suppress ort warnings.
    # 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal. Default is 2.
    session_options.log_severity_level = 3
    ort_session = onnxruntime.InferenceSession(
        model if isinstance(model, str) else model.getvalue(),
        session_options,
        providers=ort_providers,
    )
    return ort_session
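
# Usage sketch: build a session from an in-memory export (a hedged example,
# not part of the original module; assumes onnxruntime is installed):
#
#     buffer = io.BytesIO()
#     torch.onnx.export(model, args, buffer)
#     session = _ort_session(buffer)  # defaults to CPUExecutionProvider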


@_beartype.beartype
def _compare_ort_pytorch_outputs(
    ort_outs: Union[Sequence[_NumericType], Sequence],
    pt_outs: Optional[Union[_NumericType, Sequence[_NumericType], Sequence, Dict]],
    rtol: float,
    atol: float,
    check_shape: bool,
    check_dtype: bool,
    ignore_none: bool,
    acceptable_error_percentage: Optional[float],
):
    """Compare ONNX Runtime and PyTorch outputs.

    Args:
        ort_outs: outputs from ONNX Runtime.
        pt_outs: outputs from PyTorch.
        rtol: relative tolerance in comparison between ONNX and PyTorch outputs.
        atol: absolute tolerance in comparison between ONNX and PyTorch outputs.
        check_shape: whether to check that ONNX Runtime and PyTorch output
            shapes are exactly the same.
        check_dtype: whether to check that ONNX Runtime and PyTorch output
            dtypes are consistent.
        ignore_none: Whether to ignore None type in torch output, which is
            usually the case with tracing. Set this to False if torch output
            should keep None type, which is usually the case with exporting
            ScriptModules.
        acceptable_error_percentage: acceptable percentage of element mismatches
            in comparison. It should be a float between 0.0 and 1.0.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
    """
    if ignore_none:
        # torch.jit._flatten filters None type
        pt_outs, _ = torch.jit._flatten(pt_outs)
    else:
        pt_outs = _inline_flatten_list([pt_outs], [])
    pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
    ort_outs = _inline_flatten_list(ort_outs, [])
    assert len(ort_outs) == len(
        pt_outs_np
    ), f"Number of outputs differ ONNX runtime: ({len(ort_outs)}) PyTorch: ({len(pt_outs_np)})"
    if acceptable_error_percentage and (
        acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
    ):
        raise ValueError(
            "If set, acceptable_error_percentage should be between 0.0 and 1.0"
        )

    for ort_out, pt_out in zip(ort_outs, pt_outs_np):
        try:
            # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
            if not check_shape:
                # Allow different but broadcastable output shapes.
                ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
            torch.testing.assert_close(
                ort_out,
                pt_out,
                rtol=rtol,
                atol=atol,
                check_dtype=check_dtype,
                equal_nan=True,
            )
        except AssertionError as e:
            if acceptable_error_percentage:
                error_percentage = 1 - np.sum(
                    np.isclose(ort_out, pt_out, rtol=rtol, atol=atol)
                ) / np.prod(ort_out.shape)
                if error_percentage <= acceptable_error_percentage:
                    warnings.warn(
                        f"Suppressed AssertionError:\n{e}.\n"
                        f"Error percentage {error_percentage} "
                        f"within acceptable range {acceptable_error_percentage}."
                    )
                    continue
            raise
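
# A sketch of a direct comparison (hypothetical values; the numpy dtype is
# pinned to float32 so the strict dtype check against the float32 tensor
# passes):
#
#     _compare_ort_pytorch_outputs(
#         ort_outs=[np.array([1.0, 2.0], dtype=np.float32)],
#         pt_outs=torch.tensor([1.0, 2.0]),
#         rtol=1e-3,
#         atol=1e-7,
#         check_shape=True,
#         check_dtype=True,
#         ignore_none=True,
#         acceptable_error_percentage=None,
#     )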


@_beartype.beartype
def _prepare_input_for_pytorch(args, kwargs):
    """Prepare input for PyTorch model execution.

    Any future changes/formatting to the input before dispatching to the PyTorch
    model should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
    """
    if isinstance(args, (torch.Tensor, dict)):
        args = (args,)
    # In-place operators will update input tensor data as well.
    # Thus inputs are replicated before every forward call.
    args = copy.deepcopy(args)
    if kwargs:
        kwargs = copy.deepcopy(kwargs)
    else:
        kwargs = {}
    return args, kwargs


@_beartype.beartype
def _prepare_input_for_export(args, kwargs):
    """Prepare input for ONNX model export.

    Any future changes/formatting to the input before dispatching to the
    :func:`torch.onnx.export` API should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        onnx_inputs: positional arguments for ONNX model export, as `args` in
            :func:`torch.onnx.export`.
    """
    args, kwargs = _prepare_input_for_pytorch(args, kwargs)
    if not kwargs and isinstance(args[-1], dict):
        onnx_inputs = args + ({},)
    elif kwargs:
        onnx_inputs = args + (kwargs,)
    else:
        onnx_inputs = args
    return onnx_inputs
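
# Packing sketch (hypothetical values): kwargs become a trailing dict in the
# tuple handed to torch.onnx.export, and a trailing dict already in args is
# guarded with an extra empty dict so it is not mistaken for kwargs.
#
#     _prepare_input_for_export((torch.randn(2),), {"flag": True})
#     # -> (tensor, {"flag": True})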


@_beartype.beartype
def _prepare_input_for_ort(args, kwargs, remained_onnx_input_idx, flatten):
    """Prepare input for ONNX model execution in ONNX Runtime.

    Any future changes/formatting to the input before dispatching to the ONNX Runtime
    InferenceSession run should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
        remained_onnx_input_idx: if provided, only the inputs at these indices
            are passed to the ONNX model.
        flatten: whether to unpack nested list/tuple/dict inputs into a
            flattened list of Tensors.

    Returns:
        onnx_inputs: positional arguments for ONNX model execution in ONNX Runtime.
    """
    onnx_inputs = _prepare_input_for_export(args, kwargs)
    if flatten:
        onnx_inputs, _ = torch.jit._flatten(onnx_inputs)
    elif onnx_inputs and onnx_inputs[-1] == {}:
        # Handle empty kwargs (normally removed by flatten).
        onnx_inputs = onnx_inputs[:-1]
    if remained_onnx_input_idx is not None:
        return [onnx_inputs[i] for i in remained_onnx_input_idx]
    else:
        return onnx_inputs


@_beartype.beartype
def _try_clone_model(model):
    """Used for preserving original model in case forward mutates model states."""
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


@_beartype.beartype
def _compare_ort_pytorch_model(
    model,
    ort_session,
    input_args,
    input_kwargs,
    additional_test_inputs,
    remained_onnx_input_idx,
    flatten,
    ignore_none,
    rtol,
    atol,
    check_shape,
    check_dtype,
    acceptable_error_percentage: Optional[float],
):
    """Compare outputs from ONNX model runs with outputs from PyTorch model runs.

    ONNX Runtime is used as the model execution backend for the ONNX model.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """

    @_beartype.beartype
    def compare_ort_pytorch_model_with_input(input_args, input_kwargs):
        pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
        # TODO: remove this and treat mutating model separately. See #77679
        model_copy = _try_clone_model(model)
        pt_outs = model_copy(*pt_args, **pt_kwargs)

        ort_inputs = _prepare_input_for_ort(
            input_args, input_kwargs, remained_onnx_input_idx, flatten
        )
        ort_outs = _run_ort(ort_session, ort_inputs)

        _compare_ort_pytorch_outputs(
            ort_outs=ort_outs,
            pt_outs=pt_outs,
            rtol=rtol,
            atol=atol,
            check_shape=check_shape,
            check_dtype=check_dtype,
            ignore_none=ignore_none,
            acceptable_error_percentage=acceptable_error_percentage,
        )

    compare_ort_pytorch_model_with_input(input_args, input_kwargs)

    if additional_test_inputs:
        for test_input_args in additional_test_inputs:
            compare_ort_pytorch_model_with_input(test_input_args, {})


class _GraphDiff:
    """A class to represent the difference between two graphs."""

    @_beartype.beartype
    def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
        """Construct a _GraphDiff object.

        Args:
            graph_a (_C.Graph): First graph to compare.
            graph_b (_C.Graph): Second graph to compare.
        """
        self.graph_a = graph_a
        self.graph_b = graph_b

    @_beartype.beartype
    def __str__(self):
        """See function :func:`diff_report`."""
        return self.diff_report()

    @_beartype.beartype
    def _indent(self, lines: str) -> str:
        return "\n".join(["\t" + line for line in lines.splitlines()])

    @_beartype.beartype
    def diff_report(self) -> str:
        """Return a string representation of the graph difference.

        The report shows the first pair of nodes that diverges. It also shows the source
        location of the pair of nodes.

        Returns:
            graph_diff_report (str): A string representation of the graph difference.
        """
        graph_a = self.graph_a
        graph_b = self.graph_b

        graph_a_str = str(graph_a)
        graph_b_str = str(graph_b)

        if graph_a_str == graph_b_str:
            return ""

        graph_diff = difflib.ndiff(
            graph_a_str.splitlines(True), graph_b_str.splitlines(True)
        )
        graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))]

        for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()):
            if str(node_a) != str(node_b):
                graph_diff_report.append("First diverging operator:")
                node_diff = difflib.ndiff(
                    str(node_a).splitlines(True), str(node_b).splitlines(True)
                )
                source_printout = ["node diff:", self._indent("".join(node_diff))]

                stack_a = node_a.sourceRange() if node_a else None
                if stack_a:
                    source_printout.extend(
                        ["Former source location:", self._indent(str(stack_a))]
                    )
                stack_b = node_b.sourceRange() if node_b else None
                if stack_b:
                    source_printout.extend(
                        ["Latter source location:", self._indent(str(stack_b))]
                    )

                graph_diff_report.extend(source_printout)

                break

        return "\n".join(graph_diff_report)


@_beartype.beartype
def _check_graph_diff(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
    export_options: _experimental.ExportOptions,
    model_to_graph_func: Callable[
        [
            torch.nn.Module,
            Tuple[Any, ...],
            Mapping[str, Any],
            _experimental.ExportOptions,
        ],
        _C.Graph,
    ],
) -> str:
    """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`.

    Args:
        model: See :func:`check_export_model_diff`.
        test_input_groups: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.
        model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph.

    Returns:
        graph_diff_report (str): A string representation of the graph difference.
    """
    if len(test_input_groups) < 2:
        raise ValueError("Need at least two groups of test inputs to compare.")

    ref_jit_graph = None
    for args, kwargs in test_input_groups:
        jit_graph = model_to_graph_func(model, args, kwargs, export_options)
        if ref_jit_graph is None:
            ref_jit_graph = jit_graph
            continue

        graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report()
        if graph_diff_report:
            return graph_diff_report
    return ""


@_beartype.beartype
def _traced_graph_from_model(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    args: Tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        jit_graph (_C.Graph): A traced JIT graph.
    """
    training = export_options.training
    verbose = export_options.verbose

    with utils.exporter_context(model, training, verbose):
        export_inputs = _prepare_input_for_export(args, kwargs)
        model = utils._pre_trace_quant_model(model, export_inputs)
        jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs)
        return jit_graph


@_beartype.beartype
def _onnx_graph_from_model(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    args: Tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names

    if opset_version is None:
        opset_version = _constants.onnx_default_opset

    utils._setup_trace_module_map(model, export_modules_as_functions)

    if not operator_export_type:
        if _C_onnx._CAFFE2_ATEN_FALLBACK:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )

        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return onnx_graph


@_beartype.beartype
def check_export_model_diff(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
    export_options: Optional[_experimental.ExportOptions] = None,
) -> str:
    """Verify exported model discrepancy between different groups of inputs.

    A graph is exported for each group of inputs. The exported graphs are then compared
    to each other, and discrepancies in the first pair of diverging nodes are reported.
    This function first checks the JIT graph. If no discrepancies were found, it then
    checks the ONNX graph.

    Unless otherwise specified, the JIT/ONNX graph is expected to be the same, regardless
    of the inputs used for exporting. A discrepancy implies the exported graph is not
    accurate when run on other groups of inputs, which typically results in runtime
    errors or mismatched outputs.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
        test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence
            of input groups to be used to export the model. Each input group is a pair of
            (args, kwargs).
        export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions
            object that controls the export behavior.

    Returns:
        str: A string containing the diff of the exported models.
    """
    export_options = (
        _experimental.ExportOptions() if export_options is None else export_options
    )

    jit_diff_report = _check_graph_diff(
        model, test_input_groups, export_options, _traced_graph_from_model
    )
    if jit_diff_report:
        return jit_diff_report

    return _check_graph_diff(
        model, test_input_groups, export_options, _onnx_graph_from_model
    )
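
# Usage sketch (a hedged example with a toy model): two input groups with
# different batch sizes should export to the same graph, so the report is
# expected to be empty.
#
#     model = torch.nn.Linear(3, 4)
#     report = check_export_model_diff(
#         model,
#         [((torch.randn(2, 3),), {}), ((torch.randn(5, 3),), {})],
#     )
#     assert not report, report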


@_beartype.beartype
def verify(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    input_args: Union[torch.Tensor, Tuple[Any, ...]],
    input_kwargs: Optional[Mapping[str, Any]] = None,
    do_constant_folding: bool = True,
    dynamic_axes: Optional[
        Mapping[str, Union[Mapping[int, str], Mapping[str, Sequence[int]]]]
    ] = None,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
    opset_version: Optional[int] = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    fixed_batch_size: bool = False,
    use_external_data: bool = False,
    additional_test_inputs: Optional[
        Sequence[Union[torch.Tensor, Tuple[Any, ...]]]
    ] = None,
    remained_onnx_input_idx: Optional[Sequence[int]] = None,
    flatten: bool = True,
    ignore_none: bool = True,
    check_shape: bool = True,
    check_dtype: bool = True,
    ort_providers: Sequence[str] = _ORT_PROVIDERS,
    rtol: float = 0.001,
    atol: float = 1e-7,
    acceptable_error_percentage: Optional[float] = None,
    **_,
):
    """Verify model export to ONNX with ONNX Runtime.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
        input_args (tuple): See :func:`torch.onnx.export`.
        input_kwargs (dict): See :func:`torch.onnx.export`.
        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
        input_names (list, optional): See :func:`torch.onnx.export`.
        output_names (list, optional): See :func:`torch.onnx.export`.
        training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
        opset_version (int, optional): See :func:`torch.onnx.export`.
        keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
        verbose (bool, optional): See :func:`torch.onnx.export`.
        fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
        use_external_data (bool, optional): Explicitly specify whether to export the
            model with external data.
        additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
            input arguments to test. Currently only *args are supported.
        remained_onnx_input_idx (list, optional): If provided, only the specified inputs
            will be passed to the ONNX model. Supply a list when there are unused inputs
            in the model. Since unused inputs will be removed in the exported ONNX
            model, supplying all inputs will cause an error on unexpected inputs.
            This parameter tells the verifier which inputs to pass into the ONNX model.
        flatten (bool, optional): Default True. If True, unpack nested list/tuple/dict
            inputs into a flattened list of Tensors for ONNX. Set this to False if nested
            structures are to be preserved for ONNX, which is usually the case with
            exporting ScriptModules.
        ignore_none (bool, optional): Whether to ignore None type in torch output,
            which is usually the case with tracing. Set this to False if torch output
            should keep None type, which is usually the case with exporting
            ScriptModules. Defaults to True.
        check_shape (bool, optional): Whether to check that the shapes of PyTorch and
            ONNX Runtime outputs are exactly the same. Set this to False to allow
            output shape broadcasting. Defaults to True.
        check_dtype (bool, optional): Whether to check that the dtypes of PyTorch and
            ONNX Runtime outputs are consistent. Defaults to True.
        ort_providers (sequence, optional): ONNX Runtime providers to use.
        rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
        atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
        acceptable_error_percentage (float, optional): acceptable percentage of element
            mismatches in comparison. It should be a float between 0.0 and 1.0.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
        ValueError: if arguments provided are invalid.
    """
    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        model.eval()
    with torch.no_grad(), contextlib.ExitStack() as stack:
        model_f: Union[str, io.BytesIO] = io.BytesIO()
        if use_external_data:
            tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
            model_f = os.path.join(tmpdir_path, "model.onnx")

        inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)

        # TODO(#77679): remove this and treat mutating model separately.
        model_copy = _try_clone_model(model)
        utils._export(
            model,
            inputs_for_export,
            model_f,
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            fixed_batch_size=fixed_batch_size,
            training=training,
            verbose=verbose,
        )

        ort_session = _ort_session(model_f, ort_providers)

        _compare_ort_pytorch_model(
            model=model_copy,
            ort_session=ort_session,
            input_args=input_args,
            input_kwargs=input_kwargs,
            additional_test_inputs=additional_test_inputs,
            remained_onnx_input_idx=remained_onnx_input_idx,
            flatten=flatten,
            ignore_none=ignore_none,
            rtol=rtol,
            atol=atol,
            check_shape=check_shape,
            check_dtype=check_dtype,
            acceptable_error_percentage=acceptable_error_percentage,
        )
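
# Usage sketch (a hedged example, assuming onnxruntime is installed): export a
# toy model in memory and check that ONNX Runtime reproduces the PyTorch
# outputs within the default tolerances; raises AssertionError on mismatch.
#
#     model = torch.nn.Linear(3, 4)
#     verify(model, (torch.randn(2, 3),))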