[ONNX] Expose verification utilities (#148603)

Expose verification utilities to public documentation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148603
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu 2025-03-18 02:10:34 +00:00 committed by PyTorch MergeBot
parent c36ac16da1
commit ebabd0efdd
5 changed files with 94 additions and 29 deletions

View File

@ -88,6 +88,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
:hidden:
onnx_dynamo
onnx_verification
onnx_dynamo_onnxruntime_backend
onnx_torchscript
@ -99,6 +100,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
.. py:module:: torch.onnx.symbolic_helper
.. py:module:: torch.onnx.symbolic_opset10
.. py:module:: torch.onnx.symbolic_opset11
.. py:module:: torch.onnx.symbolic_opset12
.. py:module:: torch.onnx.symbolic_opset13
.. py:module:: torch.onnx.symbolic_opset14
.. py:module:: torch.onnx.symbolic_opset15
@ -111,5 +113,3 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
.. py:module:: torch.onnx.symbolic_opset8
.. py:module:: torch.onnx.symbolic_opset9
.. py:module:: torch.onnx.utils
.. py:module:: torch.onnx.verification
.. py:module:: torch.onnx.symbolic_opset12

View File

@ -701,7 +701,6 @@ Functions
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export
.. autofunction:: torch.onnx.verification.find_mismatch
Classes
^^^^^^^
@ -712,5 +711,3 @@ Classes
:template: classtemplate.rst
JitScalarType
verification.GraphInfo
verification.VerificationOptions

View File

@ -0,0 +1,26 @@
torch.onnx.verification
=======================
.. automodule:: torch.onnx.verification
.. autofunction:: verify_onnx_program
.. autoclass:: VerificationInfo
:members:
.. autofunction:: verify
Deprecated
----------
The following classes and functions are deprecated.
.. Some deprecated members are not publicly shown
.. py:class:: check_export_model_diff
.. py:class:: GraphInfo
.. py:class:: GraphInfoPrettyPrinter
.. py:class:: OnnxBackend
.. py:class:: OnnxTestCaseRepro
.. py:class:: VerificationOptions
.. py:function:: find_mismatch
.. py:function:: verify_aten_graph

View File

@ -26,6 +26,27 @@ logger = logging.getLogger(__name__)
@dataclasses.dataclass
class VerificationInfo:
"""Verification information for a value in the ONNX program.
This class contains the maximum absolute difference, maximum relative difference,
and histograms of absolute and relative differences between the expected and actual
values. It also includes the expected and actual data types.
The histograms are represented as tuples of tensors, where the first tensor is the
histogram counts and the second tensor is the bin edges.
Attributes:
name: The name of the value (output or intermediate).
max_abs_diff: The maximum absolute difference between the expected and actual values.
max_rel_diff: The maximum relative difference between the expected and actual values.
abs_diff_hist: A tuple of tensors representing the histogram of absolute differences.
The first tensor is the histogram counts and the second tensor is the bin edges.
rel_diff_hist: A tuple of tensors representing the histogram of relative differences.
The first tensor is the histogram counts and the second tensor is the bin edges.
expected_dtype: The data type of the expected value.
actual_dtype: The data type of the actual value.
"""
name: str
max_abs_diff: float
max_rel_diff: float
@ -40,8 +61,8 @@ class VerificationInfo:
def from_tensors(
cls,
name: str,
expected: torch.Tensor | int | float | bool,
actual: torch.Tensor | int | float | bool,
expected: torch.Tensor | torch.types.Number,
actual: torch.Tensor | torch.types.Number,
) -> VerificationInfo:
"""Create a VerificationInfo object from two tensors.

View File

@ -1,11 +1,23 @@
# mypy: allow-untyped-defs
"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.
ONNX Runtime is required, and is used as the ONNX backend for export verification.
"""
"""The ONNX verification module provides a set of tools to verify the correctness of ONNX models."""
from __future__ import annotations
__all__ = [
"OnnxBackend",
"VerificationOptions",
"verify",
"check_export_model_diff",
"VerificationInfo",
"verify_onnx_program",
"GraphInfo",
"GraphInfoPrettyPrinter",
"OnnxTestCaseRepro",
"find_mismatch",
"verify_aten_graph",
]
import contextlib
import copy
import dataclasses
@ -31,9 +43,20 @@ from torch import _C
from torch.onnx import _constants, _experimental, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.onnx._internal.exporter._verification import (
VerificationInfo,
verify_onnx_program,
)
from torch.types import Number
# TODO: Update deprecation messages to recommend the new classes
VerificationInfo.__module__ = "torch.onnx.verification"
verify_onnx_program.__module__ = "torch.onnx.verification"
# Everything below are deprecated ##############################################
_ORT_PROVIDERS = ("CPUExecutionProvider",)
_NumericType = Union[Number, torch.Tensor, np.ndarray]
@ -811,24 +834,22 @@ def verify(
``ONNXProgram`` to test the ONNX model.
Args:
model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
input_args (tuple): See :func:`torch.onnx.export`.
input_kwargs (dict): See :func:`torch.onnx.export`.
do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
input_names (list, optional): See :func:`torch.onnx.export`.
output_names (list, optional): See :func:`torch.onnx.export`.
training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
opset_version (int, optional): See :func:`torch.onnx.export`.
keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
verbose (bool, optional): See :func:`torch.onnx.export`.
fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
use_external_data (bool, optional): Explicitly specify whether to export the
model with external data.
additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
input arguments to test. Currently only *args are supported.
options (_VerificationOptions, optional): A _VerificationOptions object that
controls the verification behavior.
model: See :func:`torch.onnx.export`.
input_args: See :func:`torch.onnx.export`.
input_kwargs: See :func:`torch.onnx.export`.
do_constant_folding: See :func:`torch.onnx.export`.
dynamic_axes: See :func:`torch.onnx.export`.
input_names: See :func:`torch.onnx.export`.
output_names: See :func:`torch.onnx.export`.
training: See :func:`torch.onnx.export`.
opset_version: See :func:`torch.onnx.export`.
keep_initializers_as_inputs: See :func:`torch.onnx.export`.
verbose: See :func:`torch.onnx.export`.
fixed_batch_size: Legacy argument, used only by rnn test cases.
use_external_data: Explicitly specify whether to export the model with external data.
additional_test_inputs: List of tuples. Each tuple is a group of
input arguments to test. Currently only ``*args`` are supported.
options: A VerificationOptions object that controls the verification behavior.
Raises:
AssertionError: if outputs from ONNX model and PyTorch model are not