pytorch/test/onnx/test_verification.py
titaiwang 5bceaadb70 [ONNX] Add script/trace different flatten and move optional type tests to runtime (#83184)
fix #78119

Why:
In the ONNX test verification code, we previously considered only the traced output, which ignores the None type. This PR enables runtime tests to keep None values from torch outputs in script mode.

1. Move Optional Type tests from no runtime to runtime, as it's supported by ONNXRUNTIME.
2. Add ignoreNone flag for output comparison of internal tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83184
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-08-30 18:23:24 +00:00

116 lines
3.7 KiB
Python

# Owner(s): ["module: onnx"]
import numpy as np
import torch
from torch.onnx import _experimental, verification
from torch.testing._internal import common_utils
class TestVerification(common_utils.TestCase):
    """Tests for helpers in ``torch.onnx.verification``.

    Covers ``check_export_model_diff`` (reporting graph differences between
    exports under different input groups) and ``_compare_ort_pytorch_outputs``
    (tolerance handling when comparing ONNX Runtime and PyTorch outputs).
    """

    def setUp(self) -> None:
        super().setUp()
        # Seed the RNGs so the randomly generated test inputs are reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        """An input baked into the graph as a constant must produce a diff."""

        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # `tensor.data` will be exported as a constant,
                # leading to wrong model output under different inputs.
                return x + y.data

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]
        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups
        )
        # `(?s)` (DOTALL) lets `.*` cross newlines; it matches the same text
        # as the `(.|\n)*` idiom without a capturing alternation per character.
        self.assertRegex(
            results,
            r"(?s)Graph diff:.*"
            r"First diverging operator:.*"
            r"prim::Constant.*"
            r"Former source location:.*"
            r"Latter source location:",
        )

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        """A loop whose trip count depends on an input shape must produce a diff."""

        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # The trip count depends on x.size(0), which differs between
                # the two input groups below (2 vs 4).
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(4, 3), torch.randn(2, 3)), {}),
        ]
        export_options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )
        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups, export_options
        )
        self.assertRegex(
            results,
            r"(?s)Graph diff:.*"
            r"First diverging operator:.*"
            r"prim::Constant.*"
            r"Latter source location:",
        )

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        """A model that exports identically for every input group yields ""."""

        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]
        results = verification.check_export_model_diff(
            SupportedModel(), test_input_groups
        )
        self.assertEqual(results, "")

    @staticmethod
    def _compare_outputs_with_one_mismatch(acceptable_error_percentage):
        """Compare a fixed ORT/PyTorch output pair that differs in one element.

        The two 2x2 outputs agree everywhere except the last element
        (4.0 vs 1.0), so 1 of 4 elements (25%) mismatches.

        Args:
            acceptable_error_percentage: forwarded unchanged to
                ``_compare_ort_pytorch_outputs``; ``None`` disables the tolerance.
        """
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        verification._compare_ort_pytorch_outputs(
            ort_outs,
            pytorch_outs,
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            # The numpy array defaults to float64 and the torch tensor to
            # float32, so the dtype check is disabled.
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=acceptable_error_percentage,
        )

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        """A 25% element mismatch passes when 30% is declared acceptable."""
        self._compare_outputs_with_one_mismatch(acceptable_error_percentage=0.3)

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        """Without an acceptable error percentage, the mismatch raises."""
        with self.assertRaises(AssertionError):
            self._compare_outputs_with_one_mismatch(acceptable_error_percentage=None)