Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[ONNX] Create VerificationInterpreter (#148396)
An FX interpreter for comparing intermediate ONNX values with PyTorch ones.
```py
import torch
from torch.onnx._internal.exporter._verification import VerificationInterpreter


class Model(torch.nn.Module):
    def forward(self, query, key, value):
        res = torch.nn.functional.scaled_dot_product_attention(
            query, key, value
        )
        rest = res.transpose(0, 1)
        return rest.view(8, 32, 128 * 64)


model = Model()
query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

onnx_program = torch.onnx.export(model, (query, key, value), dynamo=True)
interpreter = VerificationInterpreter(onnx_program)
interpreter.run(query, key, value)
for info in interpreter.verification_infos:
    print(info)
```
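A quick way to act on the results is to filter the stored infos by their `max_abs_diff` / `max_rel_diff` fields. This is a minimal sketch building on the example above, not part of the PR; the `0.01` and `0.1` cutoffs are only illustrative (they happen to match the logging thresholds used in `run_node` below).

```py
# Sketch: list only the values whose ONNX results drift noticeably from PyTorch.
# The 0.01 / 0.1 cutoffs are illustrative, not part of the API.
mismatched = [
    info
    for info in interpreter.verification_infos
    if info.max_abs_diff > 0.01 or info.max_rel_diff > 0.1
]
for info in mismatched:
    print(f"{info.name}: abs={info.max_abs_diff}, rel={info.max_rel_diff}")
```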
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148396
Approved by: https://github.com/titaiwangms
This commit is contained in:
parent 8af79b7ec8
commit 50e827b3df
test/onnx/exporter/test_verification.py (new file, 73 lines)

@@ -0,0 +1,73 @@
# Owner(s): ["module: onnx"]
"""Test the verification module."""

from __future__ import annotations

import torch
from torch.onnx._internal.exporter import _verification
from torch.testing._internal import common_utils


class VerificationInfoTest(common_utils.TestCase):
    def test_from_tensors(self):
        # Test with tensors
        expected = torch.tensor([1.0, 2.0, 3.0])
        actual = torch.tensor([1.0, 2.0, 3.0])
        verification_info = _verification.VerificationInfo.from_tensors(
            "test_tensor", expected, actual
        )
        self.assertEqual(verification_info.name, "test_tensor")
        self.assertEqual(verification_info.max_abs_diff, 0)
        self.assertEqual(verification_info.max_rel_diff, 0)
        torch.testing.assert_close(
            verification_info.abs_diff_hist[0], torch.tensor([3.0] + [0.0] * 8)
        )
        torch.testing.assert_close(
            verification_info.rel_diff_hist[0], torch.tensor([3.0] + [0.0] * 8)
        )
        self.assertEqual(verification_info.expected_dtype, torch.float32)
        self.assertEqual(verification_info.actual_dtype, torch.float32)

    def test_from_tensors_int(self):
        # Test with int tensors
        expected = torch.tensor([1])
        actual = 1
        verification_info = _verification.VerificationInfo.from_tensors(
            "test_tensor_int", expected, actual
        )
        self.assertEqual(verification_info.name, "test_tensor_int")
        self.assertEqual(verification_info.max_abs_diff, 0)
        self.assertEqual(verification_info.max_rel_diff, 0)
        torch.testing.assert_close(
            verification_info.abs_diff_hist[0], torch.tensor([1.0] + [0.0] * 8)
        )
        torch.testing.assert_close(
            verification_info.rel_diff_hist[0], torch.tensor([1.0] + [0.0] * 8)
        )
        self.assertEqual(verification_info.expected_dtype, torch.int64)
        self.assertEqual(verification_info.actual_dtype, torch.int64)


class VerificationInterpreterTest(common_utils.TestCase):
    def test_interpreter_stores_correct_info(self):
        class Model(torch.nn.Module):
            def forward(self, a, b):
                c = a + b
                return c - 1

        model = Model()
        args = (torch.tensor([1.0]), torch.tensor([2.0]))
        onnx_program = torch.onnx.export(model, args, dynamo=True, verbose=False)
        assert onnx_program is not None
        interpreter = _verification.VerificationInterpreter(onnx_program)
        results = interpreter.run(args)
        torch.testing.assert_close(results, model(*args))
        verification_infos = interpreter.verification_infos
        self.assertEqual(len(verification_infos), 3)
        for info in verification_infos:
            self.assertEqual(info.max_abs_diff, 0)
            self.assertEqual(info.max_rel_diff, 0)


if __name__ == "__main__":
    common_utils.run_tests()
torch/onnx/_internal/exporter/_onnx_program.py

@@ -5,6 +5,7 @@ from __future__ import annotations

__all__ = ["ONNXProgram"]

import contextlib
import copy
import gc
import logging
@@ -61,6 +62,53 @@ def _count_initializer_size(graph: ir.Graph) -> int:
    )


@contextlib.contextmanager
def _set_graph_outputs(
    graph: ir.Graph,
    outputs: list[ir.Value],
):
    """Temporarily set the outputs of the graph.

    Args:
        graph: The graph to set the outputs for.
        outputs: The outputs to set.
    """
    original_outputs = graph.outputs.copy()
    graph.outputs.clear()
    graph.outputs.extend(outputs)
    try:
        yield
    finally:
        graph.outputs.clear()
        graph.outputs.extend(original_outputs)


def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
    """Return a dictionary mapping names to values in the graph.

    The mapping does not include values from subgraphs.

    Args:
        graph: The graph to extract the mapping from.

    Returns:
        A dictionary mapping names to values.
    """
    values = {}
    values.update(graph.initializers)
    # The names of the values can be None or "", which we need to exclude
    for input in graph.inputs:
        if not input.name:
            continue
        values[input.name] = input
    for node in graph:
        for value in node.outputs:
            if not value.name:
                continue
            values[value.name] = value
    return values


class ONNXProgram:
    """A class to represent an ONNX program that is callable with torch tensors."""
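For readers new to these helpers, here is a small sketch (outside this diff) of how the output-swapping context manager behaves, reusing the `onnx_program` from the commit-message example; the `finally` branch guarantees the original outputs come back even if the body raises.

```py
# Sketch only: exercise the module-private helpers added above.
from torch.onnx._internal.exporter import _onnx_program

graph = onnx_program.model.graph
values = _onnx_program._create_value_mapping(graph)  # name -> ir.Value
candidate = next(iter(values.values()))              # any named value, picked arbitrarily
original = list(graph.outputs)
with _onnx_program._set_graph_outputs(graph, [candidate]):
    assert list(graph.outputs) == [candidate]        # outputs temporarily rewired
assert list(graph.outputs) == original               # restored on exit, even on error
```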
@@ -112,6 +160,38 @@ ONNXProgram(
        # TODO(justinchuby): Maybe output complex tensors as needed
        return tuple(torch.from_numpy(output) for output in outputs)

    def compute_values(
        self, value_names: Sequence[str], args=(), kwargs=None
    ) -> Sequence[torch.Tensor]:
        """Compute the values of the specified names in the ONNX model.

        This method temporarily sets the requested values as graph outputs and runs the model.
        The values are returned as a sequence of tensors, in the order of ``value_names``.

        Args:
            value_names: The names of the values to compute.

        Returns:
            A sequence of tensors corresponding to ``value_names``.
        """
        if kwargs is None:
            kwargs = {}
        self.release()
        values = _create_value_mapping(self.model.graph)
        for name in value_names:
            if name not in values:
                raise ValueError(
                    f"Value '{name}' not found in the model. "
                    "Please provide a valid value name."
                )
        temporary_outputs = [values[name] for name in value_names]
        with _set_graph_outputs(self.model.graph, temporary_outputs):
            try:
                result = self(*args, **kwargs)
            finally:
                self.release()
        return result

    @property
    def model_proto(self) -> onnx.ModelProto:
        """Return the ONNX ``ModelProto`` object."""
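A hedged usage sketch of the new `ONNXProgram.compute_values` API, reusing the `onnx_program`, `query`, `key`, and `value` objects from the commit-message example. The value name below is hypothetical; valid names are the graph's inputs, initializers, and node outputs, as collected by `_create_value_mapping`.

```py
# Sketch: fetch a single intermediate ONNX value by name.
# "val_attn_output" is a made-up name; pick one that exists in your model.
(intermediate,) = onnx_program.compute_values(
    ["val_attn_output"], args=(query, key, value)
)
print(intermediate.shape, intermediate.dtype)
```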
torch/onnx/_internal/exporter/_verification.py

@@ -1,13 +1,14 @@
# mypy: allow-untyped-defs
from __future__ import annotations


__all__ = [
    "VerificationInfo",
    "VerificationInterpreter",
    "verify_onnx_program",
]

import dataclasses
import logging
import math
from typing import Any, TYPE_CHECKING
@@ -16,9 +17,14 @@ from torch.utils import _pytree


if TYPE_CHECKING:
    from onnxscript import ir

    from torch.onnx._internal.exporter import _onnx_program


logger = logging.getLogger(__name__)


@dataclasses.dataclass
class VerificationInfo:
    name: str
@@ -31,6 +37,47 @@ class VerificationInfo:
    # NOTE: We don't need to include shape because the expected shape is already known
    # and checked by the runtime

    @classmethod
    def from_tensors(
        cls,
        name: str,
        expected: torch.Tensor | int | float | bool,
        actual: torch.Tensor | int | float | bool,
    ) -> VerificationInfo:
        """Create a VerificationInfo object from two tensors.

        Args:
            name: The name of the value.
            expected: The expected tensor.
            actual: The actual tensor.

        Returns:
            VerificationInfo: The VerificationInfo object.
        """
        if not isinstance(expected, torch.Tensor):
            expected = torch.tensor(expected)
        if not isinstance(actual, torch.Tensor):
            actual = torch.tensor(actual)

        max_abs_diff, max_rel_diff, abs_diff, rel_diff = _compare_tensors(
            expected, actual
        )
        bins = torch.tensor(
            [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000],
            dtype=torch.float,
        )
        abs_diff_hist = torch.histogram(abs_diff.float(), bins=bins)
        rel_diff_hist = torch.histogram(rel_diff.float(), bins=bins)
        return cls(
            name=name,
            max_abs_diff=max_abs_diff,
            max_rel_diff=max_rel_diff,
            abs_diff_hist=abs_diff_hist,
            rel_diff_hist=rel_diff_hist,
            expected_dtype=expected.dtype,
            actual_dtype=actual.dtype,
        )


def _compare_tensors(
    expected: torch.Tensor,
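The unit test earlier in this commit exercises this path directly; the sketch below (not part of the diff) shows how the result reads. Note that `torch.histogram` returns a `(counts, bin_edges)` pair, so index `[0]` gives the per-bin counts over the fixed bins from `0` up to `1e6`.

```py
# Sketch: build a VerificationInfo directly from two tensors and inspect it.
import torch
from torch.onnx._internal.exporter import _verification

info = _verification.VerificationInfo.from_tensors(
    "example", expected=torch.tensor([1.0, 2.0]), actual=torch.tensor([1.0, 2.5])
)
print(info.max_abs_diff, info.max_rel_diff)
print(info.abs_diff_hist[0])  # counts per bin; bin edges are abs_diff_hist[1]
```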
@@ -86,26 +133,137 @@ def verify_onnx_program(
        torch_outputs, onnx_outputs, onnx_program.model.graph.outputs
    ):
        name = output_val.name
-       max_abs_diff, max_rel_diff, abs_diff, rel_diff = _compare_tensors(
-           torch_output, onnx_output
-       )
-       abs_diff = abs_diff.flatten()
-       rel_diff = rel_diff.flatten()
-       bins = torch.tensor(
-           [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000],
-           dtype=abs_diff.dtype,
-       )
-       abs_diff_hist = torch.histogram(abs_diff, bins=bins)
-       rel_diff_hist = torch.histogram(rel_diff, bins=bins)
        results.append(
-           VerificationInfo(
+           VerificationInfo.from_tensors(
                name=str(name),
-               max_abs_diff=max_abs_diff,
-               max_rel_diff=max_rel_diff,
-               abs_diff_hist=abs_diff_hist,
-               rel_diff_hist=rel_diff_hist,
-               expected_dtype=torch_output.dtype,
-               actual_dtype=onnx_output.dtype,
+               expected=torch_output,
+               actual=onnx_output,
            )
        )
    return results


def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
    """Return a dictionary mapping names to values in the graph.

    The mapping does not include values from subgraphs.

    Args:
        graph: The graph to extract the mapping from.

    Returns:
        A dictionary mapping names to values.
    """
    values = {}
    values.update(graph.initializers)
    # The names of the values can be None or "", which we need to exclude
    for input in graph.inputs:
        if not input.name:
            continue
        values[input.name] = input
    for node in graph:
        for value in node.outputs:
            if not value.name:
                continue
            values[value.name] = value
    return values


class VerificationInterpreter(torch.fx.Interpreter):
    """Interpreter for verifying converted ONNX model accuracy by comparing intermediate values.

    To compare models, first initialize the interpreter with an ONNX program.
    Then, call the :meth:`run` method with the input arguments to execute the model.
    The :meth:`run` method will execute the model and populate the
    :attr:`verification_infos` attribute with the verification information for each value.

    ::

        onnx_program = torch.onnx.export(model, args, dynamo=True)
        interpreter = VerificationInterpreter(onnx_program)
        interpreter.run(*args)
        verification_infos = interpreter.verification_infos
        for info in verification_infos:
            print("value name:", info.name, info)

    The verification information includes the maximum absolute difference, maximum relative
    difference, and histograms of absolute and relative differences between the expected
    and actual values. See :class:`VerificationInfo` for more details.

    Attributes:
        verification_infos: A list of verification information for each value.
            It is populated when the `run` method is called.
    """

    def __init__(self, onnx_program: torch.onnx.ONNXProgram) -> None:
        """Initialize the VerificationInterpreter with an ONNX program.

        Args:
            onnx_program: The ONNX program to verify.
        """
        if onnx_program.exported_program is None:
            raise ValueError(
                "The ONNX program does not contain an exported_program. "
                "Please provide an exported_program to verify the ONNX program."
            )
        super().__init__(onnx_program.exported_program.module())
        self._onnx_program = onnx_program
        self._onnx_values = _create_value_mapping(onnx_program.model.graph)
        self._args: list[Any] = []
        self.verification_infos: list[VerificationInfo] = []

    def run(
        self,
        *args: Any,
        initial_env: dict[torch.fx.Node, Any] | None = None,
        enable_io_processing: bool = True,
    ) -> Any:
        """Run the interpreter with the given input arguments.

        This method executes the model and populates the :attr:`verification_infos` attribute
        with the verification information for each value.

        Args:
            args: The input arguments for the model.
            initial_env: The initial environment for the interpreter.
            enable_io_processing: Whether to enable IO processing.

        Returns:
            Any: The result of executing the model.
        """
        self.verification_infos = []
        self.args = args
        return super().run(
            *args,
            initial_env=initial_env,
            enable_io_processing=enable_io_processing,
        )

    def run_node(self, n: torch.fx.Node) -> Any:
        result = super().run_node(n)
        if n.op != "call_function":
            return result
        node_name = n.name
        if node_name not in self._onnx_values:
            return result
        (onnx_result,) = self._onnx_program.compute_values([node_name], self.args)
        info = VerificationInfo.from_tensors(
            name=node_name,
            expected=result,
            actual=onnx_result,
        )
        self.verification_infos.append(info)
        if info.max_abs_diff > 0.01 or info.max_rel_diff > 0.1:
            logger.warning(
                "Verification info for node %s: max_abs_diff: %s, max_rel_diff: %s",
                node_name,
                info.max_abs_diff,
                info.max_rel_diff,
            )
        else:
            logger.info(
                "Verification info for node %s: max_abs_diff: %s, max_rel_diff: %s",
                node_name,
                info.max_abs_diff,
                info.max_rel_diff,
            )
        return result