mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Documentation for torch.onnx.find_mismatch (#90728)
Doc preview: * `find_mismatch`: https://docs-preview.pytorch.org/90728/onnx.html#torch.onnx.verification.find_mismatch * `GraphInfo`: https://docs-preview.pytorch.org/90728/onnx.html#classes and https://docs-preview.pytorch.org/90728/generated/torch.onnx.verification.GraphInfo.html#torch.onnx.verification.GraphInfo * `VerificationOptions`: https://docs-preview.pytorch.org/90728/onnx.html#classes and https://docs-preview.pytorch.org/90728/generated/torch.onnx.verification.VerificationOptions.html#torch.onnx.verification.VerificationOptions Pull Request resolved: https://github.com/pytorch/pytorch/pull/90728 Approved by: https://github.com/titaiwangms, https://github.com/justinchuby
This commit is contained in:
parent
ed7885c254
commit
c537f5bee8
|
|
@ -706,6 +706,7 @@ Functions
|
|||
.. autofunction:: is_in_onnx_export
|
||||
.. autofunction:: enable_log
|
||||
.. autofunction:: disable_log
|
||||
.. autofunction:: torch.onnx.verification.find_mismatch
|
||||
|
||||
Classes
|
||||
-------
|
||||
|
|
@ -716,3 +717,5 @@ Classes
|
|||
:template: classtemplate.rst
|
||||
|
||||
JitScalarType
|
||||
torch.onnx.verification.GraphInfo
|
||||
torch.onnx.verification.VerificationOptions
|
||||
|
|
|
|||
|
|
@ -681,6 +681,9 @@ def _onnx_graph_from_aten_graph(
|
|||
do_constant_folding = export_options.do_constant_folding
|
||||
opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
|
||||
|
||||
GLOBALS.export_onnx_opset_version = opset_version
|
||||
GLOBALS.operator_export_type = operator_export_type
|
||||
|
||||
do_constant_folding = utils._decide_constant_folding(
|
||||
do_constant_folding, operator_export_type, training
|
||||
)
|
||||
|
|
@ -816,7 +819,7 @@ def verify(
|
|||
] = None,
|
||||
input_names: Optional[Sequence[str]] = None,
|
||||
output_names: Optional[Sequence[str]] = None,
|
||||
training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
|
||||
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
|
||||
opset_version: Optional[int] = None,
|
||||
keep_initializers_as_inputs: bool = True,
|
||||
verbose: bool = False,
|
||||
|
|
@ -1197,6 +1200,8 @@ class OnnxTestCaseRepro:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class GraphInfo:
|
||||
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""
|
||||
|
||||
graph: torch.Graph
|
||||
input_args: Tuple[Any, ...]
|
||||
params_dict: Dict[str, Any]
|
||||
|
|
@ -1235,8 +1240,8 @@ class GraphInfo:
|
|||
The id of the subgraph is shown under the node. The `GraphInfo` object for any
|
||||
subgraph can be retrieved by calling `graph_info.find_partition(id)`.
|
||||
|
||||
Example:
|
||||
```
|
||||
Example::
|
||||
|
||||
==================================== Tree: =====================================
|
||||
5 X __2 X __1 ✓
|
||||
id: | id: 0 | id: 00
|
||||
|
|
@ -1256,7 +1261,7 @@ class GraphInfo:
|
|||
['01', '110']
|
||||
============================= Mismatch node kinds: =============================
|
||||
{'aten::relu': 2}
|
||||
```
|
||||
|
||||
"""
|
||||
GraphInfoPrettyPrinter(self).pretty_print()
|
||||
|
||||
|
|
@ -1345,7 +1350,7 @@ class GraphInfo:
|
|||
) -> str:
|
||||
"""Export the subgraph to ONNX along with the input/output data for repro.
|
||||
|
||||
The repro directory will contain the following files:
|
||||
The repro directory will contain the following files::
|
||||
|
||||
dir
|
||||
├── test_<name>
|
||||
|
|
@ -1612,6 +1617,24 @@ class GraphInfo:
|
|||
def verify_export(
|
||||
self, options: VerificationOptions
|
||||
) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
|
||||
"""
|
||||
Verify the export from TorchScript IR graph to ONNX.
|
||||
|
||||
Export the TorchScript IR graph to ONNX, with the inputs, parameters and export
|
||||
options recorded in this object. Then verify the exported ONNX graph against
|
||||
the original TorchScript IR graph under the provided verification options.
|
||||
|
||||
Args:
|
||||
options: The verification options.
|
||||
|
||||
Returns:
|
||||
error: The AssertionError raised during the verification. Returns None if no
|
||||
error is raised.
|
||||
onnx_graph: The exported ONNX graph in TorchScript IR format.
|
||||
onnx_outs: The outputs from running exported ONNX model under the onnx
|
||||
backend in `options`.
|
||||
pt_outs: The outputs from running the TorchScript IR graph.
|
||||
"""
|
||||
return verify_aten_graph(
|
||||
self.graph,
|
||||
input_args=self.input_args,
|
||||
|
|
@ -1625,6 +1648,16 @@ class GraphInfo:
|
|||
self,
|
||||
options: Optional[VerificationOptions] = None,
|
||||
):
|
||||
"""
|
||||
Find all mismatches between the TorchScript IR graph and the exported onnx model.
|
||||
|
||||
Binary searches the model graph to find the minimal subgraph that exhibits the
|
||||
mismatch. A `GraphInfo` object is created for each subgraph, recording the test
|
||||
inputs and export options, as well as the validation results.
|
||||
|
||||
Args:
|
||||
options: The verification options.
|
||||
"""
|
||||
self.clear()
|
||||
|
||||
if options is None:
|
||||
|
|
@ -1724,26 +1757,92 @@ def find_mismatch(
|
|||
model: Union[torch.nn.Module, torch.jit.ScriptModule],
|
||||
input_args: Tuple[Any, ...],
|
||||
do_constant_folding: bool = True,
|
||||
training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
|
||||
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
|
||||
opset_version: Optional[int] = None,
|
||||
keep_initializers_as_inputs: bool = True,
|
||||
verbose: bool = False,
|
||||
options: Optional[VerificationOptions] = None,
|
||||
) -> GraphInfo:
|
||||
r"""Find all mismatches between the original model and the exported model.
|
||||
TODO: Fill in docstring.
|
||||
|
||||
Experimental. The API is subject to change.
|
||||
|
||||
This tool helps debug the mismatch between the original PyTorch model and exported
|
||||
ONNX model. It binary searches the model graph to find the minimal subgraph that
|
||||
exhibits the mismatch.
|
||||
|
||||
Args:
|
||||
model: The model to be exported.
|
||||
input_args: The input arguments to the model.
|
||||
do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`.
|
||||
training: Same as `training` in :func:`torch.onnx.export`.
|
||||
opset_version: Same as `opset_version` in :func:`torch.onnx.export`.
|
||||
keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`.
|
||||
verbose: Same as `verbose` in :func:`torch.onnx.export`.
|
||||
options: The options for the mismatch verification.
|
||||
|
||||
Returns:
|
||||
A GraphInfo object that contains the mismatch information.
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> import torch.onnx.verification
|
||||
>>> torch.manual_seed(0)
|
||||
>>> opset_version = 15
|
||||
>>> # Define a custom symbolic function for aten::relu.
|
||||
>>> # The custom symbolic function is incorrect, which will result in mismatches.
|
||||
>>> def incorrect_relu_symbolic_function(g, self):
|
||||
... return self
|
||||
>>> torch.onnx.register_custom_op_symbolic(
|
||||
... "aten::relu",
|
||||
... incorrect_relu_symbolic_function,
|
||||
... opset_version=opset_version,
|
||||
... )
|
||||
>>> class Model(torch.nn.Module):
|
||||
... def __init__(self):
|
||||
... super().__init__()
|
||||
... self.layers = torch.nn.Sequential(
|
||||
... torch.nn.Linear(3, 4),
|
||||
... torch.nn.ReLU(),
|
||||
... torch.nn.Linear(4, 5),
|
||||
... torch.nn.ReLU(),
|
||||
... torch.nn.Linear(5, 6),
|
||||
... )
|
||||
... def forward(self, x):
|
||||
... return self.layers(x)
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
|
||||
>>> graph_info = torch.onnx.verification.find_mismatch(
|
||||
... Model(),
|
||||
... (torch.randn(2, 3),),
|
||||
... opset_version=opset_version,
|
||||
... )
|
||||
===================== Mismatch info for graph partition : ======================
|
||||
================================ Mismatch error ================================
|
||||
Tensor-likes are not close!
|
||||
Mismatched elements: 12 / 12 (100.0%)
|
||||
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
|
||||
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
|
||||
==================================== Tree: =====================================
|
||||
5 X __2 X __1 ✓
|
||||
id: | id: 0 | id: 00
|
||||
| |
|
||||
| |__1 X (aten::relu)
|
||||
| id: 01
|
||||
|
|
||||
|__3 X __1 ✓
|
||||
id: 1 | id: 10
|
||||
|
|
||||
|__2 X __1 X (aten::relu)
|
||||
id: 11 | id: 110
|
||||
|
|
||||
|__1 ✓
|
||||
id: 111
|
||||
=========================== Mismatch leaf subgraphs: ===========================
|
||||
['01', '110']
|
||||
============================= Mismatch node kinds: =============================
|
||||
{'aten::relu': 2}
|
||||
|
||||
"""
|
||||
if options is None:
|
||||
options = VerificationOptions()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user