mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix #78119. Why: the ONNX test verification code used to consider only the tracing output, which drops None values. This PR makes the runtime test keep None values from torch in script mode. 1. Move the Optional type tests from the no-runtime suite to the runtime suite, since ONNX Runtime supports them. 2. Add an `ignore_none` flag for output comparison in internal tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83184 Approved by: https://github.com/justinchuby, https://github.com/BowenBao
116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.onnx import _experimental, verification
|
|
from torch.testing._internal import common_utils
|
|
|
|
|
|
class TestVerification(common_utils.TestCase):
|
|
def setUp(self) -> None:
|
|
super().setUp()
|
|
torch.manual_seed(0)
|
|
if torch.cuda.is_available():
|
|
torch.cuda.manual_seed_all(0)
|
|
|
|
def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
|
|
class UnexportableModel(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
# tensor.data() will be exported as a constant,
|
|
# leading to wrong model output under different inputs.
|
|
return x + y.data
|
|
|
|
test_input_groups = [
|
|
((torch.randn(2, 3), torch.randn(2, 3)), {}),
|
|
((torch.randn(2, 3), torch.randn(2, 3)), {}),
|
|
]
|
|
|
|
results = verification.check_export_model_diff(
|
|
UnexportableModel(), test_input_groups
|
|
)
|
|
self.assertRegex(
|
|
results,
|
|
r"Graph diff:(.|\n)*"
|
|
r"First diverging operator:(.|\n)*"
|
|
r"prim::Constant(.|\n)*"
|
|
r"Former source location:(.|\n)*"
|
|
r"Latter source location:",
|
|
)
|
|
|
|
def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
|
|
self,
|
|
):
|
|
class UnexportableModel(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
for i in range(x.size(0)):
|
|
y = x[i] + y
|
|
return y
|
|
|
|
test_input_groups = [
|
|
((torch.randn(2, 3), torch.randn(2, 3)), {}),
|
|
((torch.randn(4, 3), torch.randn(2, 3)), {}),
|
|
]
|
|
|
|
export_options = _experimental.ExportOptions(
|
|
input_names=["x", "y"], dynamic_axes={"x": [0]}
|
|
)
|
|
results = verification.check_export_model_diff(
|
|
UnexportableModel(), test_input_groups, export_options
|
|
)
|
|
self.assertRegex(
|
|
results,
|
|
r"Graph diff:(.|\n)*"
|
|
r"First diverging operator:(.|\n)*"
|
|
r"prim::Constant(.|\n)*"
|
|
r"Latter source location:(.|\n)*",
|
|
)
|
|
|
|
def test_check_export_model_diff_returns_empty_when_correct_export(self):
|
|
class SupportedModel(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
test_input_groups = [
|
|
((torch.randn(2, 3), torch.randn(2, 3)), {}),
|
|
((torch.randn(2, 3), torch.randn(2, 3)), {}),
|
|
]
|
|
|
|
results = verification.check_export_model_diff(
|
|
SupportedModel(), test_input_groups
|
|
)
|
|
self.assertEqual(results, "")
|
|
|
|
def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
|
|
self,
|
|
):
|
|
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
|
|
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
|
|
verification._compare_ort_pytorch_outputs(
|
|
ort_outs,
|
|
pytorch_outs,
|
|
rtol=1e-5,
|
|
atol=1e-6,
|
|
check_shape=True,
|
|
check_dtype=False,
|
|
ignore_none=True,
|
|
acceptable_error_percentage=0.3,
|
|
)
|
|
|
|
def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
|
|
self,
|
|
):
|
|
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
|
|
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
|
|
with self.assertRaises(AssertionError):
|
|
verification._compare_ort_pytorch_outputs(
|
|
ort_outs,
|
|
pytorch_outs,
|
|
rtol=1e-5,
|
|
atol=1e-6,
|
|
check_shape=True,
|
|
check_dtype=False,
|
|
ignore_none=True,
|
|
acceptable_error_percentage=None,
|
|
)
|