BowenBao 2023-01-11 19:45:12 +00:00 committed by PyTorch MergeBot
parent ed7885c254
commit c537f5bee8
2 changed files with 128 additions and 26 deletions

View File

@ -706,6 +706,7 @@ Functions
.. autofunction:: is_in_onnx_export
.. autofunction:: enable_log
.. autofunction:: disable_log
.. autofunction:: torch.onnx.verification.find_mismatch
Classes
-------
@ -716,3 +717,5 @@ Classes
:template: classtemplate.rst
JitScalarType
torch.onnx.verification.GraphInfo
torch.onnx.verification.VerificationOptions

View File

@ -681,6 +681,9 @@ def _onnx_graph_from_aten_graph(
do_constant_folding = export_options.do_constant_folding
opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
GLOBALS.export_onnx_opset_version = opset_version
GLOBALS.operator_export_type = operator_export_type
do_constant_folding = utils._decide_constant_folding(
do_constant_folding, operator_export_type, training
)
@ -816,7 +819,7 @@ def verify(
] = None,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
opset_version: Optional[int] = None,
keep_initializers_as_inputs: bool = True,
verbose: bool = False,
@ -1197,6 +1200,8 @@ class OnnxTestCaseRepro:
@dataclasses.dataclass
class GraphInfo:
"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""
graph: torch.Graph
input_args: Tuple[Any, ...]
params_dict: Dict[str, Any]
@ -1235,28 +1240,28 @@ class GraphInfo:
The id of the subgraph is shown under the node. The `GraphInfo` object for any
subgraph can be retrieved by calling `graph_info.find_partition(id)`.
Example:
```
==================================== Tree: =====================================
5 X __2 X __1
id: | id: 0 | id: 00
| |
| |__1 X (aten::relu)
| id: 01
|
|__3 X __1
id: 1 | id: 10
|
|__2 X __1 X (aten::relu)
id: 11 | id: 110
|
|__1
id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
```
Example::
==================================== Tree: =====================================
5 X __2 X __1
id: | id: 0 | id: 00
| |
| |__1 X (aten::relu)
| id: 01
|
|__3 X __1
id: 1 | id: 10
|
|__2 X __1 X (aten::relu)
id: 11 | id: 110
|
|__1
id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
"""
GraphInfoPrettyPrinter(self).pretty_print()
@ -1345,7 +1350,7 @@ class GraphInfo:
) -> str:
"""Export the subgraph to ONNX along with the input/output data for repro.
The repro directory will contain the following files:
The repro directory will contain the following files::
dir
test_<name>
@ -1612,6 +1617,24 @@ class GraphInfo:
def verify_export(
self, options: VerificationOptions
) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
"""
Verify the export from TorchScript IR graph to ONNX.
Export the TorchScript IR graph to ONNX, with the inputs, parameters and export
options recorded in this object. Then verify the exported ONNX graph against
the original TorchScript IR graph under the provided verification options.
Args:
options: The verification options.
Returns:
error: The AssertionError raised during the verification. Returns None if no
error is raised.
onnx_graph: The exported ONNX graph in TorchScript IR format.
onnx_outs: The outputs from running exported ONNX model under the onnx
backend in `options`.
pt_outs: The outputs from running the TorchScript IR graph.
"""
return verify_aten_graph(
self.graph,
input_args=self.input_args,
@ -1625,6 +1648,16 @@ class GraphInfo:
self,
options: Optional[VerificationOptions] = None,
):
"""
Find all mismatches between the TorchScript IR graph and the exported onnx model.
Binary searches the model graph to find the minimal subgraph that exhibits the
mismatch. A `GraphInfo` object is created for each subgraph, recording the test
inputs and export options, as well as the validation results.
Args:
options: The verification options.
"""
self.clear()
if options is None:
@ -1724,26 +1757,92 @@ def find_mismatch(
model: Union[torch.nn.Module, torch.jit.ScriptModule],
input_args: Tuple[Any, ...],
do_constant_folding: bool = True,
training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
opset_version: Optional[int] = None,
keep_initializers_as_inputs: bool = True,
verbose: bool = False,
options: Optional[VerificationOptions] = None,
) -> GraphInfo:
r"""Find all mismatches between the original model and the exported model.
TODO: Fill in docstring.
Experimental. The API is subject to change.
This tool helps debug the mismatch between the original PyTorch model and exported
ONNX model. It binary searches the model graph to find the minimal subgraph that
exhibits the mismatch.
Args:
model: The model to be exported.
input_args: The input arguments to the model.
do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`.
training: Same as `training` in :func:`torch.onnx.export`.
opset_version: Same as `opset_version` in :func:`torch.onnx.export`.
keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`.
verbose: Same as `verbose` in :func:`torch.onnx.export`.
options: The options for the mismatch verification.
Returns:
A GraphInfo object that contains the mismatch information.
Example::
>>> import torch
>>> import torch.onnx.verification
>>> torch.manual_seed(0)
>>> opset_version = 15
>>> # Define a custom symbolic function for aten::relu.
>>> # The custom symbolic function is incorrect, which will result in mismatches.
>>> def incorrect_relu_symbolic_function(g, self):
... return self
>>> torch.onnx.register_custom_op_symbolic(
... "aten::relu",
... incorrect_relu_symbolic_function,
... opset_version=opset_version,
... )
>>> class Model(torch.nn.Module):
... def __init__(self):
... super().__init__()
... self.layers = torch.nn.Sequential(
... torch.nn.Linear(3, 4),
... torch.nn.ReLU(),
... torch.nn.Linear(4, 5),
... torch.nn.ReLU(),
... torch.nn.Linear(5, 6),
... )
... def forward(self, x):
... return self.layers(x)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> graph_info = torch.onnx.verification.find_mismatch(
... Model(),
... (torch.randn(2, 3),),
... opset_version=opset_version,
... )
===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X __2 X __1
id: | id: 0 | id: 00
| |
| |__1 X (aten::relu)
| id: 01
|
|__3 X __1
id: 1 | id: 10
|
|__2 X __1 X (aten::relu)
id: 11 | id: 110
|
|__1
id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
"""
if options is None:
options = VerificationOptions()