mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Expose verification utilities (#148603)
Expose verification utilities to public documentation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148603 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
c36ac16da1
commit
ebabd0efdd
|
|
@ -88,6 +88,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
|
|||
:hidden:
|
||||
|
||||
onnx_dynamo
|
||||
onnx_verification
|
||||
onnx_dynamo_onnxruntime_backend
|
||||
onnx_torchscript
|
||||
|
||||
|
|
@ -99,6 +100,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
|
|||
.. py:module:: torch.onnx.symbolic_helper
|
||||
.. py:module:: torch.onnx.symbolic_opset10
|
||||
.. py:module:: torch.onnx.symbolic_opset11
|
||||
.. py:module:: torch.onnx.symbolic_opset12
|
||||
.. py:module:: torch.onnx.symbolic_opset13
|
||||
.. py:module:: torch.onnx.symbolic_opset14
|
||||
.. py:module:: torch.onnx.symbolic_opset15
|
||||
|
|
@ -111,5 +113,3 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
|
|||
.. py:module:: torch.onnx.symbolic_opset8
|
||||
.. py:module:: torch.onnx.symbolic_opset9
|
||||
.. py:module:: torch.onnx.utils
|
||||
.. py:module:: torch.onnx.verification
|
||||
.. py:module:: torch.onnx.symbolic_opset12
|
||||
|
|
@ -701,7 +701,6 @@ Functions
|
|||
.. autofunction:: unregister_custom_op_symbolic
|
||||
.. autofunction:: select_model_mode_for_export
|
||||
.. autofunction:: is_in_onnx_export
|
||||
.. autofunction:: torch.onnx.verification.find_mismatch
|
||||
|
||||
Classes
|
||||
^^^^^^^
|
||||
|
|
@ -712,5 +711,3 @@ Classes
|
|||
:template: classtemplate.rst
|
||||
|
||||
JitScalarType
|
||||
verification.GraphInfo
|
||||
verification.VerificationOptions
|
||||
|
|
|
|||
26
docs/source/onnx_verification.rst
Normal file
26
docs/source/onnx_verification.rst
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
torch.onnx.verification
|
||||
=======================
|
||||
|
||||
.. automodule:: torch.onnx.verification
|
||||
|
||||
.. autofunction:: verify_onnx_program
|
||||
|
||||
.. autoclass:: VerificationInfo
|
||||
:members:
|
||||
|
||||
.. autofunction:: verify
|
||||
|
||||
Deprecated
|
||||
----------
|
||||
|
||||
The following classes and functions are deprecated.
|
||||
|
||||
.. Some deprecated members are not publicly shown
|
||||
.. py:class:: check_export_model_diff
|
||||
.. py:class:: GraphInfo
|
||||
.. py:class:: GraphInfoPrettyPrinter
|
||||
.. py:class:: OnnxBackend
|
||||
.. py:class:: OnnxTestCaseRepro
|
||||
.. py:class:: VerificationOptions
|
||||
.. py:function:: find_mismatch
|
||||
.. py:function:: verify_aten_graph
|
||||
|
|
@ -26,6 +26,27 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@dataclasses.dataclass
|
||||
class VerificationInfo:
|
||||
"""Verification information for a value in the ONNX program.
|
||||
|
||||
This class contains the maximum absolute difference, maximum relative difference,
|
||||
and histograms of absolute and relative differences between the expected and actual
|
||||
values. It also includes the expected and actual data types.
|
||||
|
||||
The histograms are represented as tuples of tensors, where the first tensor is the
|
||||
histogram counts and the second tensor is the bin edges.
|
||||
|
||||
Attributes:
|
||||
name: The name of the value (output or intermediate).
|
||||
max_abs_diff: The maximum absolute difference between the expected and actual values.
|
||||
max_rel_diff: The maximum relative difference between the expected and actual values.
|
||||
abs_diff_hist: A tuple of tensors representing the histogram of absolute differences.
|
||||
The first tensor is the histogram counts and the second tensor is the bin edges.
|
||||
rel_diff_hist: A tuple of tensors representing the histogram of relative differences.
|
||||
The first tensor is the histogram counts and the second tensor is the bin edges.
|
||||
expected_dtype: The data type of the expected value.
|
||||
actual_dtype: The data type of the actual value.
|
||||
"""
|
||||
|
||||
name: str
|
||||
max_abs_diff: float
|
||||
max_rel_diff: float
|
||||
|
|
@ -40,8 +61,8 @@ class VerificationInfo:
|
|||
def from_tensors(
|
||||
cls,
|
||||
name: str,
|
||||
expected: torch.Tensor | int | float | bool,
|
||||
actual: torch.Tensor | int | float | bool,
|
||||
expected: torch.Tensor | torch.types.Number,
|
||||
actual: torch.Tensor | torch.types.Number,
|
||||
) -> VerificationInfo:
|
||||
"""Create a VerificationInfo object from two tensors.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,23 @@
|
|||
# mypy: allow-untyped-defs
|
||||
"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.
|
||||
|
||||
ONNX Runtime is required, and is used as the ONNX backend for export verification.
|
||||
"""
|
||||
"""The ONNX verification module provides a set of tools to verify the correctness of ONNX models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OnnxBackend",
|
||||
"VerificationOptions",
|
||||
"verify",
|
||||
"check_export_model_diff",
|
||||
"VerificationInfo",
|
||||
"verify_onnx_program",
|
||||
"GraphInfo",
|
||||
"GraphInfoPrettyPrinter",
|
||||
"OnnxTestCaseRepro",
|
||||
"find_mismatch",
|
||||
"verify_aten_graph",
|
||||
]
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
|
|
@ -31,9 +43,20 @@ from torch import _C
|
|||
from torch.onnx import _constants, _experimental, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import onnx_proto_utils
|
||||
from torch.onnx._internal.exporter._verification import (
|
||||
VerificationInfo,
|
||||
verify_onnx_program,
|
||||
)
|
||||
from torch.types import Number
|
||||
|
||||
|
||||
# TODO: Update deprecation messages to recommend the new classes
|
||||
|
||||
VerificationInfo.__module__ = "torch.onnx.verification"
|
||||
verify_onnx_program.__module__ = "torch.onnx.verification"
|
||||
|
||||
# Everything below are deprecated ##############################################
|
||||
|
||||
_ORT_PROVIDERS = ("CPUExecutionProvider",)
|
||||
|
||||
_NumericType = Union[Number, torch.Tensor, np.ndarray]
|
||||
|
|
@ -811,24 +834,22 @@ def verify(
|
|||
``ONNXProgram`` to test the ONNX model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
|
||||
input_args (tuple): See :func:`torch.onnx.export`.
|
||||
input_kwargs (dict): See :func:`torch.onnx.export`.
|
||||
do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
|
||||
dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
|
||||
input_names (list, optional): See :func:`torch.onnx.export`.
|
||||
output_names (list, optional): See :func:`torch.onnx.export`.
|
||||
training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
|
||||
opset_version (int, optional): See :func:`torch.onnx.export`.
|
||||
keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
|
||||
verbose (bool, optional): See :func:`torch.onnx.export`.
|
||||
fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
|
||||
use_external_data (bool, optional): Explicitly specify whether to export the
|
||||
model with external data.
|
||||
additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
|
||||
input arguments to test. Currently only *args are supported.
|
||||
options (_VerificationOptions, optional): A _VerificationOptions object that
|
||||
controls the verification behavior.
|
||||
model: See :func:`torch.onnx.export`.
|
||||
input_args: See :func:`torch.onnx.export`.
|
||||
input_kwargs: See :func:`torch.onnx.export`.
|
||||
do_constant_folding: See :func:`torch.onnx.export`.
|
||||
dynamic_axes: See :func:`torch.onnx.export`.
|
||||
input_names: See :func:`torch.onnx.export`.
|
||||
output_names: See :func:`torch.onnx.export`.
|
||||
training: See :func:`torch.onnx.export`.
|
||||
opset_version: See :func:`torch.onnx.export`.
|
||||
keep_initializers_as_inputs: See :func:`torch.onnx.export`.
|
||||
verbose: See :func:`torch.onnx.export`.
|
||||
fixed_batch_size: Legacy argument, used only by rnn test cases.
|
||||
use_external_data: Explicitly specify whether to export the model with external data.
|
||||
additional_test_inputs: List of tuples. Each tuple is a group of
|
||||
input arguments to test. Currently only ``*args`` are supported.
|
||||
options: A VerificationOptions object that controls the verification behavior.
|
||||
|
||||
Raises:
|
||||
AssertionError: if outputs from ONNX model and PyTorch model are not
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user