[ONNX] Create VerificationInterpreter (#148396)

An fx interpreter for comparing ONNX values with pytorch ones.

```py
import torch
from torch.onnx._internal.exporter._verification import VerificationInterpreter

class Model(torch.nn.Module):
    def forward(self, query, key, value):
        res = torch.nn.functional.scaled_dot_product_attention(
            query, key, value
        )
        rest = res.transpose(0, 1)
        return rest.view(8, 32, 128 * 64)

model = Model()

query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

onnx_program = torch.onnx.export(model, (query, key, value), dynamo=True)
interpreter = VerificationInterpreter(onnx_program)
interpreter.run(query, key, value)
for info in interpreter.verification_infos:
    print(info)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148396
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu 2025-03-05 19:18:49 +00:00 committed by PyTorch MergeBot
parent 8af79b7ec8
commit 50e827b3df
3 changed files with 330 additions and 19 deletions

View File

@ -0,0 +1,73 @@
# Owner(s): ["module: onnx"]
"""Test the verification module."""
from __future__ import annotations
import torch
from torch.onnx._internal.exporter import _verification
from torch.testing._internal import common_utils
class VerificationInfoTest(common_utils.TestCase):
def test_from_tensors(self):
# Test with tensors
expected = torch.tensor([1.0, 2.0, 3.0])
actual = torch.tensor([1.0, 2.0, 3.0])
verification_info = _verification.VerificationInfo.from_tensors(
"test_tensor", expected, actual
)
self.assertEqual(verification_info.name, "test_tensor")
self.assertEqual(verification_info.max_abs_diff, 0)
self.assertEqual(verification_info.max_rel_diff, 0)
torch.testing.assert_close(
verification_info.abs_diff_hist[0], torch.tensor([3.0] + [0.0] * 8)
)
torch.testing.assert_close(
verification_info.rel_diff_hist[0], torch.tensor([3.0] + [0.0] * 8)
)
self.assertEqual(verification_info.expected_dtype, torch.float32)
self.assertEqual(verification_info.actual_dtype, torch.float32)
def test_from_tensors_int(self):
# Test with int tensors
expected = torch.tensor([1])
actual = 1
verification_info = _verification.VerificationInfo.from_tensors(
"test_tensor_int", expected, actual
)
self.assertEqual(verification_info.name, "test_tensor_int")
self.assertEqual(verification_info.max_abs_diff, 0)
self.assertEqual(verification_info.max_rel_diff, 0)
torch.testing.assert_close(
verification_info.abs_diff_hist[0], torch.tensor([1.0] + [0.0] * 8)
)
torch.testing.assert_close(
verification_info.rel_diff_hist[0], torch.tensor([1.0] + [0.0] * 8)
)
self.assertEqual(verification_info.expected_dtype, torch.int64)
self.assertEqual(verification_info.actual_dtype, torch.int64)
class VerificationInterpreterTest(common_utils.TestCase):
def test_interpreter_stores_correct_info(self):
class Model(torch.nn.Module):
def forward(self, a, b):
c = a + b
return c - 1
model = Model()
args = (torch.tensor([1.0]), torch.tensor([2.0]))
onnx_program = torch.onnx.export(model, args, dynamo=True, verbose=False)
assert onnx_program is not None
interpreter = _verification.VerificationInterpreter(onnx_program)
results = interpreter.run(args)
torch.testing.assert_close(results, model(*args))
verification_infos = interpreter.verification_infos
self.assertEqual(len(verification_infos), 3)
for info in verification_infos:
self.assertEqual(info.max_abs_diff, 0)
self.assertEqual(info.max_rel_diff, 0)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -5,6 +5,7 @@ from __future__ import annotations
__all__ = ["ONNXProgram"]
import contextlib
import copy
import gc
import logging
@ -61,6 +62,53 @@ def _count_initializer_size(graph: ir.Graph) -> int:
)
@contextlib.contextmanager
def _set_graph_outputs(
graph: ir.Graph,
outputs: list[ir.Value],
):
"""Temporarily set the outputs of the graph.
Args:
graph: The graph to set the outputs for.
outputs: The outputs to set.
"""
original_outputs = graph.outputs.copy()
graph.outputs.clear()
graph.outputs.extend(outputs)
try:
yield
finally:
graph.outputs.clear()
graph.outputs.extend(original_outputs)
def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
"""Return a dictionary mapping names to values in the graph.
The mapping does not include values from subgraphs.
Args:
graph: The graph to extract the mapping from.
Returns:
A dictionary mapping names to values.
"""
values = {}
values.update(graph.initializers)
# The names of the values can be None or "", which we need to exclude
for input in graph.inputs:
if not input.name:
continue
values[input.name] = input
for node in graph:
for value in node.outputs:
if not value.name:
continue
values[value.name] = value
return values
class ONNXProgram:
"""A class to represent an ONNX program that is callable with torch tensors."""
@ -112,6 +160,38 @@ ONNXProgram(
# TODO(justinchuby): Maybe output complex tensors as needed
return tuple(torch.from_numpy(output) for output in outputs)
def compute_values(
self, value_names: Sequence[str], args=(), kwargs=None
) -> Sequence[torch.Tensor]:
"""Compute the values of the specified names in the ONNX model.
This method is used to compute the values of the specified names in the ONNX model.
The values are returned as a dictionary mapping names to tensors.
Args:
value_names: The names of the values to compute.
Returns:
A dictionary mapping names to tensors.
"""
if kwargs is None:
kwargs = {}
self.release()
values = _create_value_mapping(self.model.graph)
for name in value_names:
if name not in values:
raise ValueError(
f"Value '{name}' not found in the model. "
"Please provide a valid value name."
)
temporary_outputs = [values[name] for name in value_names]
with _set_graph_outputs(self.model.graph, temporary_outputs):
try:
result = self(*args, **kwargs)
finally:
self.release()
return result
@property
def model_proto(self) -> onnx.ModelProto:
"""Return the ONNX ``ModelProto`` object."""

View File

@ -1,13 +1,14 @@
# mypy: allow-untyped-defs
from __future__ import annotations
__all__ = [
"VerificationInfo",
"VerificationInterpreter",
"verify_onnx_program",
]
import dataclasses
import logging
import math
from typing import Any, TYPE_CHECKING
@ -16,9 +17,14 @@ from torch.utils import _pytree
if TYPE_CHECKING:
from onnxscript import ir
from torch.onnx._internal.exporter import _onnx_program
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class VerificationInfo:
name: str
@ -31,6 +37,47 @@ class VerificationInfo:
# NOTE: We don't need to include shape because the expected shape is already known
# and checked by the runtime
@classmethod
def from_tensors(
cls,
name: str,
expected: torch.Tensor | int | float | bool,
actual: torch.Tensor | int | float | bool,
) -> VerificationInfo:
"""Create a VerificationInfo object from two tensors.
Args:
name: The name of the value.
expected: The expected tensor.
actual: The actual tensor.
Returns:
VerificationInfo: The VerificationInfo object.
"""
if not isinstance(expected, torch.Tensor):
expected = torch.tensor(expected)
if not isinstance(actual, torch.Tensor):
actual = torch.tensor(actual)
max_abs_diff, max_rel_diff, abs_diff, rel_diff = _compare_tensors(
expected, actual
)
bins = torch.tensor(
[0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000],
dtype=torch.float,
)
abs_diff_hist = torch.histogram(abs_diff.float(), bins=bins)
rel_diff_hist = torch.histogram(rel_diff.float(), bins=bins)
return cls(
name=name,
max_abs_diff=max_abs_diff,
max_rel_diff=max_rel_diff,
abs_diff_hist=abs_diff_hist,
rel_diff_hist=rel_diff_hist,
expected_dtype=expected.dtype,
actual_dtype=actual.dtype,
)
def _compare_tensors(
expected: torch.Tensor,
@ -86,26 +133,137 @@ def verify_onnx_program(
torch_outputs, onnx_outputs, onnx_program.model.graph.outputs
):
name = output_val.name
max_abs_diff, max_rel_diff, abs_diff, rel_diff = _compare_tensors(
torch_output, onnx_output
)
abs_diff = abs_diff.flatten()
rel_diff = rel_diff.flatten()
bins = torch.tensor(
[0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000],
dtype=abs_diff.dtype,
)
abs_diff_hist = torch.histogram(abs_diff, bins=bins)
rel_diff_hist = torch.histogram(rel_diff, bins=bins)
results.append(
VerificationInfo(
VerificationInfo.from_tensors(
name=str(name),
max_abs_diff=max_abs_diff,
max_rel_diff=max_rel_diff,
abs_diff_hist=abs_diff_hist,
rel_diff_hist=rel_diff_hist,
expected_dtype=torch_output.dtype,
actual_dtype=onnx_output.dtype,
expected=torch_output,
actual=onnx_output,
)
)
return results
def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
"""Return a dictionary mapping names to values in the graph.
The mapping does not include values from subgraphs.
Args:
graph: The graph to extract the mapping from.
Returns:
A dictionary mapping names to values.
"""
values = {}
values.update(graph.initializers)
# The names of the values can be None or "", which we need to exclude
for input in graph.inputs:
if not input.name:
continue
values[input.name] = input
for node in graph:
for value in node.outputs:
if not value.name:
continue
values[value.name] = value
return values
class VerificationInterpreter(torch.fx.Interpreter):
"""Interpreter for verifying converted ONNX model accuracy by comparing intermediate values.
To compare models, first initialize the interpreter with an ONNX program.
Then, call the :meth:`run` method with the input arguments to execute the model.
The :meth:`run` method will execute the model and populate the
:attr:`verification_infos` attribute with the verification information for each value.
::
onnx_program = torch.onnx.export(model, args, dynamo=True)
interpreter = VerificationInterpreter(onnx_program)
interpreter.run(*args)
verification_infos = interpreter.verification_infos
for info in verification_infos:
print("value name:", info.name, info)
The verification information includes the maximum absolute difference, maximum relative
difference, and histograms of absolute and relative differences between the expected
and actual values. See :class:`VerificationInfo` for more details.
Attributes:
verification_infos: A list of verification information for each value.
It is populated when the `run` method is called.
"""
def __init__(self, onnx_program: torch.onnx.ONNXProgram) -> None:
"""Initialize the VerificationInterpreter with an ONNX program.
Args:
onnx_program: The ONNX program to verify.
"""
if onnx_program.exported_program is None:
raise ValueError(
"The ONNX program does not contain an exported_program. "
"Please provide an exported_program to verify the ONNX program."
)
super().__init__(onnx_program.exported_program.module())
self._onnx_program = onnx_program
self._onnx_values = _create_value_mapping(onnx_program.model.graph)
self._args: list[Any] = []
self.verification_infos: list[VerificationInfo] = []
def run(
self,
*args: Any,
initial_env: dict[torch.fx.Node, Any] | None = None,
enable_io_processing: bool = True,
) -> Any:
"""Run the interpreter with the given input arguments.
This method executes the model and populates the :attr:`verification_infos` attribute
with the verification information for each value.
Args:
args: The input arguments for the model.
initial_env: The initial environment for the interpreter.
enable_io_processing: Whether to enable IO processing.
Returns:
Any: The result of executing the model.
"""
self.verification_infos = []
self.args = args
return super().run(
*args,
initial_env=initial_env,
enable_io_processing=enable_io_processing,
)
def run_node(self, n: torch.fx.Node) -> Any:
result = super().run_node(n)
if n.op != "call_function":
return result
node_name = n.name
if node_name not in self._onnx_values:
return result
(onnx_result,) = self._onnx_program.compute_values([node_name], self.args)
info = VerificationInfo.from_tensors(
name=node_name,
expected=result,
actual=onnx_result,
)
self.verification_infos.append(info)
if info.max_abs_diff > 0.01 or info.max_rel_diff > 0.1:
logger.warning(
"Verification info for node %s: max_abs_diff: %s, max_rel_diff: %s",
node_name,
info.max_abs_diff,
info.max_rel_diff,
)
else:
logger.info(
"Verification info for node %s: max_abs_diff: %s, max_rel_diff: %s",
node_name,
info.max_abs_diff,
info.max_rel_diff,
)
return result