[ONNX] Add acceptable_error_percentage to backend tests (#82622)
This enables more test coverage for numerically unstable tests that occasionally fail with sporadic element mismatches; currently such tests have to be disabled in CI. This PR adds a configurable tolerance threshold, acceptable_error_percentage, for these cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82622
Approved by: https://github.com/justinchuby, https://github.com/abock
This commit is contained in:
parent
7896621f94
commit
404c1c04ff
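
For illustration only (not part of the commit): a minimal sketch of how the new knob behaves, mirroring the TestVerification unit test added below. One of four output elements mismatches (25%), which is within an acceptable_error_percentage of 0.3, so the internal comparison warns and passes instead of raising. This assumes a build with this commit applied, since the parameter does not exist before it.

import numpy as np
import torch
from torch.onnx import verification

# One of the four elements differs between the two outputs.
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]          # stand-in for ONNX Runtime outputs
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]  # stand-in for PyTorch outputs

# Internal helper exercised by the new tests: 25% of elements mismatch,
# which is below the 30% threshold, so this emits a warning instead of raising.
verification._compare_ort_pytorch_outputs(
    ort_outs,
    pytorch_outs,
    rtol=1e-5,
    atol=1e-6,
    check_shape=True,
    check_dtype=False,
    acceptable_error_percentage=0.3,
)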
@@ -48,7 +48,7 @@ class TestModels(common_utils.TestCase):
     opset_version = 9  # Caffe2 doesn't support the default.
     keep_initializers_as_inputs = False

-    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
+    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
         import caffe2.python.onnx.backend as backend

         with torch.onnx.select_model_mode_for_export(
@@ -143,13 +143,11 @@ class TestModels(common_utils.TestCase):
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
         self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

-    @unittest.skip(
-        "This test has been flaky on trunk and PRs. See https://github.com/pytorch/pytorch/issues/79540"
-    )
     @skipScriptTest(min_opset_version=15)  # None type in outputs
+    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
     def test_inception(self):
         x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
-        self.exportTest(toC(inception_v3()), toC(x))
+        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

     def test_squeezenet(self):
         # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
@@ -8,11 +8,11 @@ from typing import List, Mapping, Tuple
 import onnx_test_common
 import parameterized
 import PIL
+import test_models

 import torch
 import torchvision
 from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
-from test_models import TestModels
 from torch import nn
 from torch.testing._internal import common_utils
 from torchvision import ops
@@ -27,20 +27,38 @@ from torchvision.models.detection import (
 )


-def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
+def exportTest(
+    self,
+    model,
+    inputs,
+    rtol=1e-2,
+    atol=1e-7,
+    opset_versions=None,
+    acceptable_error_percentage=None,
+):
     opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12, 13, 14]

     for opset_version in opset_versions:
         self.opset_version = opset_version
         self.onnx_shape_inference = True
         onnx_test_common.run_model_test(
-            self, model, input_args=inputs, rtol=rtol, atol=atol
+            self,
+            model,
+            input_args=inputs,
+            rtol=rtol,
+            atol=atol,
+            acceptable_error_percentage=acceptable_error_percentage,
         )

         if self.is_script_test_enabled and opset_version > 11:
             script_model = torch.jit.script(model)
             onnx_test_common.run_model_test(
-                self, script_model, input_args=inputs, rtol=rtol, atol=atol
+                self,
+                script_model,
+                input_args=inputs,
+                rtol=rtol,
+                atol=atol,
+                acceptable_error_percentage=acceptable_error_percentage,
             )
@@ -48,7 +66,7 @@ TestModels = type(
     "TestModels",
     (common_utils.TestCase,),
     dict(
-        TestModels.__dict__,
+        test_models.TestModels.__dict__,
         is_script_test_enabled=False,
         is_script=False,
         exportTest=exportTest,
@@ -69,6 +69,7 @@ class _TestJITIRToONNX:
             atol=1e-7,
             check_shape=self.check_shape,
             check_dtype=self.check_dtype,
+            acceptable_error_percentage=None,
         )

     def test_example_ir(self):
@@ -1,5 +1,7 @@
 # Owner(s): ["module: onnx"]

+import numpy as np
+
 import torch
 from torch.onnx import _experimental, verification
 from torch.testing._internal import common_utils
@@ -78,3 +80,34 @@ class TestVerification(common_utils.TestCase):
             SupportedModel(), test_input_groups
         )
         self.assertEqual(results, "")
+
+    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
+        self,
+    ):
+        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
+        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
+        verification._compare_ort_pytorch_outputs(
+            ort_outs,
+            pytorch_outs,
+            rtol=1e-5,
+            atol=1e-6,
+            check_shape=True,
+            check_dtype=False,
+            acceptable_error_percentage=0.3,
+        )
+
+    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
+        self,
+    ):
+        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
+        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
+        with self.assertRaises(AssertionError):
+            verification._compare_ort_pytorch_outputs(
+                ort_outs,
+                pytorch_outs,
+                rtol=1e-5,
+                atol=1e-6,
+                check_shape=True,
+                check_dtype=False,
+                acceptable_error_percentage=None,
+            )
@@ -122,27 +122,64 @@ def _compare_ort_pytorch_outputs(
     atol: float,
     check_shape: bool,
     check_dtype: bool,
+    acceptable_error_percentage: Optional[float],
 ):
+    """
+    Compare ONNX Runtime and PyTorch outputs.
+
+    Args:
+        ort_outs: outputs from ONNX Runtime.
+        pt_outs: outputs from PyTorch.
+        rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
+        atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
+        acceptable_error_percentage (float, optional): acceptable percentage of element mismatches in comparison.
+            It should be a float of value between 0.0 and 1.0.
+
+    Raises:
+        AssertionError: if outputs from ONNX model and PyTorch model are not
+            equal up to specified precision.
+        ValueError: if arguments provided are invalid.
+    """
     pt_outs, _ = torch.jit._flatten(pt_outs)
     pt_outs = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)

     assert len(ort_outs) == len(
         pt_outs
     ), f"Number of outputs differ ONNX runtime: ({len(ort_outs)}) PyTorch: ({len(pt_outs)})"
+    if acceptable_error_percentage and (
+        acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
+    ):
+        raise ValueError(
+            "If set, acceptable_error_percentage should be between 0.0 and 1.0"
+        )

     for ort_out, pt_out in zip(ort_outs, pt_outs):
-        # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
-        if not check_shape:
-            # Allow different but broadcastable output shapes.
-            ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
-        torch.testing.assert_close(
-            ort_out,
-            pt_out,
-            rtol=rtol,
-            atol=atol,
-            check_dtype=check_dtype,
-            equal_nan=True,
-        )
+        try:
+            # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
+            if not check_shape:
+                # Allow different but broadcastable output shapes.
+                ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
+            torch.testing.assert_close(
+                ort_out,
+                pt_out,
+                rtol=rtol,
+                atol=atol,
+                check_dtype=check_dtype,
+                equal_nan=True,
+            )
+        except AssertionError as e:
+            if acceptable_error_percentage:
+                error_percentage = 1 - np.sum(
+                    np.isclose(ort_out, pt_out, rtol=rtol, atol=atol)
+                ) / np.prod(ort_out.shape)
+                if error_percentage <= acceptable_error_percentage:
+                    warnings.warn(
+                        f"Suppressed AssertionError:\n{e}.\n"
+                        f"Error percentage {error_percentage} "
+                        f"within acceptable range {acceptable_error_percentage}."
+                    )
+                    continue
+            raise


 def _prepare_input_for_pytorch(args, kwargs):
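
A quick, illustrative sanity check of the suppression math above (not part of the commit): for the 2x2 arrays used in the new TestVerification test, np.isclose marks three of four elements as close, so the computed error fraction is 0.25, and any acceptable_error_percentage of at least 0.25 suppresses the AssertionError.

import numpy as np

ort_out = np.array([[1.0, 2.0], [3.0, 4.0]])
pt_out = np.array([[1.0, 2.0], [3.0, 1.0]])

# Same expression as in _compare_ort_pytorch_outputs above:
# the fraction of elements that are NOT close under the given tolerances.
error_percentage = 1 - np.sum(
    np.isclose(ort_out, pt_out, rtol=1e-5, atol=1e-6)
) / np.prod(ort_out.shape)

print(error_percentage)  # 0.25 -> suppressed when acceptable_error_percentage >= 0.25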
@@ -243,6 +280,7 @@ def _compare_ort_pytorch_model(
     atol,
     check_shape,
     check_dtype,
+    acceptable_error_percentage: Optional[float],
 ):
     """Compare outputs from ONNX model runs with outputs from PyTorch model runs.

@@ -265,7 +303,13 @@ def _compare_ort_pytorch_model(
         ort_outs = _run_ort(ort_session, ort_inputs)

         _compare_ort_pytorch_outputs(
-            ort_outs, pt_outs, rtol, atol, check_shape, check_dtype
+            ort_outs,
+            pt_outs,
+            rtol,
+            atol,
+            check_shape,
+            check_dtype,
+            acceptable_error_percentage,
         )

     compare_ort_pytorch_model_with_input(input_args, input_kwargs)
@@ -548,6 +592,7 @@ def verify(
     ort_providers: Sequence[str] = _ORT_PROVIDERS,
     rtol: float = 0.001,
     atol: float = 1e-7,
+    acceptable_error_percentage: Optional[float] = None,
     **_,
 ):
     """Verify model export to ONNX with ONNX Runtime.
@@ -586,10 +631,13 @@ def verify(
         ort_providers (sequence, optional): ONNX Runtime providers to use.
         rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
        atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
+        acceptable_error_percentage (float, optional): acceptable percentage of element mismatches in comparison.
+            It should be a float of value between 0.0 and 1.0.

     Raises:
         AssertionError: if outputs from ONNX model and PyTorch model are not
             equal up to specified precision.
+        ValueError: if arguments provided are invalid.
     """
     if training == torch.onnx.TrainingMode.TRAINING:
         model.train()
@@ -634,4 +682,5 @@ def verify(
         atol,
         check_shape,
         check_dtype,
+        acceptable_error_percentage,
     )
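
For completeness, a hypothetical end-to-end call through the public verify() entry point. The model and inputs below are made up, and passing them positionally as verify(model, input_args, ...) is an assumption based on the surrounding signature rather than something shown in this diff; only the rtol, atol, and acceptable_error_percentage keywords appear in the hunks above. Running it also requires onnxruntime to be installed.

import torch
from torch.onnx import verification

# Hypothetical model and inputs; only acceptable_error_percentage comes from this PR.
model = torch.nn.Linear(4, 4)
inputs = (torch.randn(2, 4),)

# Tolerate up to 1% mismatched output elements instead of failing the whole
# comparison on a single sporadic off-by-epsilon element.
verification.verify(
    model,
    inputs,
    rtol=1e-3,
    atol=1e-7,
    acceptable_error_percentage=0.01,
)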