pytorch/test/onnx/exporter/test_verification.py
Justin Chu 50e827b3df [ONNX] Create VerificationInterpreter (#148396)
An FX interpreter for comparing the values computed by an exported ONNX program with the corresponding values computed by PyTorch.

```py
import torch
from torch.onnx._internal.exporter._verification import VerificationInterpreter

class Model(torch.nn.Module):
    def forward(self, query, key, value):
        res = torch.nn.functional.scaled_dot_product_attention(
            query, key, value
        )
        rest = res.transpose(0, 1)
        return rest.view(8, 32, 128 * 64)

model = Model()

query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

onnx_program = torch.onnx.export(model, (query, key, value), dynamo=True)
interpreter = VerificationInterpreter(onnx_program)
interpreter.run(query, key, value)
for info in interpreter.verification_infos:
    print(info)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148396
Approved by: https://github.com/titaiwangms
2025-03-05 19:18:52 +00:00

74 lines
2.8 KiB
Python

# Owner(s): ["module: onnx"]
"""Test the verification module."""
from __future__ import annotations
import torch
from torch.onnx._internal.exporter import _verification
from torch.testing._internal import common_utils
class VerificationInfoTest(common_utils.TestCase):
    """Tests for `_verification.VerificationInfo.from_tensors`."""

    def _assert_exact_match(
        self,
        info: _verification.VerificationInfo,
        name: str,
        element_count: float,
        dtype: torch.dtype,
    ) -> None:
        """Assert that `info` describes a perfect match between two values.

        Checks the recorded name, that both max diffs are zero, and that the
        first element of each diff histogram equals
        `[element_count, 0, 0, 0, 0, 0, 0, 0, 0]`. That means all elements
        land in the first of nine bins. Both dtypes must equal `dtype`.
        NOTE(review): `abs_diff_hist` / `rel_diff_hist` are indexed with
        `[0]`. They presumably hold a (counts, bin_edges) pair; confirm
        against `_verification`.
        """
        self.assertEqual(info.name, name)
        self.assertEqual(info.max_abs_diff, 0)
        self.assertEqual(info.max_rel_diff, 0)
        # Every element falls into the first bin; the other eight stay empty.
        hist_counts = torch.tensor([element_count] + [0.0] * 8)
        torch.testing.assert_close(info.abs_diff_hist[0], hist_counts)
        torch.testing.assert_close(info.rel_diff_hist[0], hist_counts)
        self.assertEqual(info.expected_dtype, dtype)
        self.assertEqual(info.actual_dtype, dtype)

    def test_from_tensors(self):
        # Two identical float32 tensors with three elements each.
        reference = torch.tensor([1.0, 2.0, 3.0])
        candidate = torch.tensor([1.0, 2.0, 3.0])
        info = _verification.VerificationInfo.from_tensors(
            "test_tensor", reference, candidate
        )
        self._assert_exact_match(info, "test_tensor", 3.0, torch.float32)

    def test_from_tensors_int(self):
        # A one-element int64 tensor compared against a plain Python int; both
        # sides are expected to report dtype int64.
        reference = torch.tensor([1])
        candidate = 1
        info = _verification.VerificationInfo.from_tensors(
            "test_tensor_int", reference, candidate
        )
        self._assert_exact_match(info, "test_tensor_int", 1.0, torch.int64)
class VerificationInterpreterTest(common_utils.TestCase):
    """Tests for `_verification.VerificationInterpreter`."""

    def test_interpreter_stores_correct_info(self):
        """Run the interpreter on a tiny add/sub model and check every
        recorded comparison reports zero difference."""

        class Model(torch.nn.Module):
            def forward(self, a, b):
                c = a + b
                return c - 1

        model = Model()
        args = (torch.tensor([1.0]), torch.tensor([2.0]))
        onnx_program = torch.onnx.export(model, args, dynamo=True, verbose=False)
        # Narrow the (presumably Optional) export result before handing it on;
        # this also fails the test early if export produced nothing.
        assert onnx_program is not None
        interpreter = _verification.VerificationInterpreter(onnx_program)
        # `run` takes the model inputs positionally, like the model itself.
        # Passing the tuple as a single argument would only work by accident
        # of input flattening.
        results = interpreter.run(*args)
        torch.testing.assert_close(results, model(*args))
        verification_infos = interpreter.verification_infos
        # NOTE(review): 3 is the number of entries recorded for this two-op
        # graph; confirm which graph nodes they correspond to.
        self.assertEqual(len(verification_infos), 3)
        for info in verification_infos:
            # subTest labels a failure with the entry that caused it.
            with self.subTest(name=info.name):
                self.assertEqual(info.max_abs_diff, 0)
                self.assertEqual(info.max_rel_diff, 0)
if __name__ == "__main__":
    # Entry point when this file is executed directly: hand off to the runner
    # from PyTorch's internal testing utilities (torch.testing._internal).
    common_utils.run_tests()