[ONNX] Add acceptable_error_percentage to backend tests (#82622)

Enables more test coverage for numerically unstable tests that
occasionally fail with a sporadic element mismatch. Currently such
tests have to be disabled in CI. This PR adds a configurable tolerance
threshold, acceptable_error_percentage, so a test can tolerate a bounded
fraction of mismatched output elements instead of being skipped.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82622
Approved by: https://github.com/justinchuby, https://github.com/abock
Author: BowenBao
Date: 2022-08-11 14:27:27 -07:00
Committed by: PyTorch MergeBot
Parent: 7896621f94
Commit: 404c1c04ff

5 changed files with 122 additions and 23 deletions
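
From a test author's perspective, the new knob flows straight through to the export-verification entry points. A minimal usage sketch (the SoftmaxModel module and its inputs are hypothetical, made up for illustration; verify's leading model and input arguments are assumed from its existing signature, while rtol, atol, and acceptable_error_percentage are the parameters shown in this diff):

    import torch
    from torch.onnx import verification

    class SoftmaxModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.softmax(x, dim=1)

    # Export the model, run it under ONNX Runtime, and compare against
    # PyTorch eager outputs. With acceptable_error_percentage=0.01, up to
    # 1% of output elements may violate rtol/atol; such mismatches are
    # reported as a warning instead of raising an AssertionError.
    verification.verify(
        SoftmaxModel(),
        (torch.randn(4, 10),),
        rtol=1e-3,
        atol=1e-7,
        acceptable_error_percentage=0.01,
    )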


@@ -48,7 +48,7 @@ class TestModels(common_utils.TestCase):
     opset_version = 9  # Caffe2 doesn't support the default.
     keep_initializers_as_inputs = False
 
-    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
+    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
         import caffe2.python.onnx.backend as backend
 
         with torch.onnx.select_model_mode_for_export(
@@ -143,13 +143,11 @@ class TestModels(common_utils.TestCase):
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
         self.exportTest(toC(resnet50()), toC(x), atol=1e-6)
 
-    @unittest.skip(
-        "This test has been flaky on trunk and PRs. See https://github.com/pytorch/pytorch/issues/79540"
-    )
     @skipScriptTest(min_opset_version=15)  # None type in outputs
+    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
     def test_inception(self):
         x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
-        self.exportTest(toC(inception_v3()), toC(x))
+        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)
 
     def test_squeezenet(self):
         # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
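
For inception_v3 this re-enables the previously skipped test: acceptable_error_percentage=0.01 tolerates up to 1% of output elements falling outside rtol/atol, which covers the sporadic single-element mismatch that made the test flaky.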


@@ -8,11 +8,11 @@ from typing import List, Mapping, Tuple
 import onnx_test_common
 import parameterized
 import PIL
+import test_models
 import torch
 import torchvision
 from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
-from test_models import TestModels
 from torch import nn
 from torch.testing._internal import common_utils
 from torchvision import ops
@@ -27,20 +27,38 @@ from torchvision.models.detection import (
 )
 
 
-def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
+def exportTest(
+    self,
+    model,
+    inputs,
+    rtol=1e-2,
+    atol=1e-7,
+    opset_versions=None,
+    acceptable_error_percentage=None,
+):
     opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12, 13, 14]
 
     for opset_version in opset_versions:
         self.opset_version = opset_version
         self.onnx_shape_inference = True
         onnx_test_common.run_model_test(
-            self, model, input_args=inputs, rtol=rtol, atol=atol
+            self,
+            model,
+            input_args=inputs,
+            rtol=rtol,
+            atol=atol,
+            acceptable_error_percentage=acceptable_error_percentage,
         )
 
         if self.is_script_test_enabled and opset_version > 11:
             script_model = torch.jit.script(model)
             onnx_test_common.run_model_test(
-                self, script_model, input_args=inputs, rtol=rtol, atol=atol
+                self,
+                script_model,
+                input_args=inputs,
+                rtol=rtol,
+                atol=atol,
+                acceptable_error_percentage=acceptable_error_percentage,
             )
@@ -48,7 +66,7 @@ TestModels = type(
     "TestModels",
     (common_utils.TestCase,),
     dict(
-        TestModels.__dict__,
+        test_models.TestModels.__dict__,
         is_script_test_enabled=False,
         is_script=False,
         exportTest=exportTest,
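
Two things change here besides the new parameter: run_model_test now receives the same acceptable_error_percentage for both the eager model and its torch.jit.script counterpart, and the module is imported as test_models rather than via from test_models import TestModels, since this file rebinds the name TestModels with type(...) and a directly imported class would be shadowed by that rebinding.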


@@ -69,6 +69,7 @@ class _TestJITIRToONNX:
             atol=1e-7,
             check_shape=self.check_shape,
             check_dtype=self.check_dtype,
+            acceptable_error_percentage=None,
         )
 
     def test_example_ir(self):
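
The JIT IR to ONNX test path simply pins the new argument to None, preserving the previous strict element-wise comparison for those tests.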


@@ -1,5 +1,7 @@
 # Owner(s): ["module: onnx"]
 
+import numpy as np
+import torch
+
 from torch.onnx import _experimental, verification
 from torch.testing._internal import common_utils
@@ -78,3 +80,34 @@ class TestVerification(common_utils.TestCase):
             SupportedModel(), test_input_groups
         )
         self.assertEqual(results, "")
+
+    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
+        self,
+    ):
+        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
+        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
+        verification._compare_ort_pytorch_outputs(
+            ort_outs,
+            pytorch_outs,
+            rtol=1e-5,
+            atol=1e-6,
+            check_shape=True,
+            check_dtype=False,
+            acceptable_error_percentage=0.3,
+        )
+
+    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
+        self,
+    ):
+        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
+        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
+        with self.assertRaises(AssertionError):
+            verification._compare_ort_pytorch_outputs(
+                ort_outs,
+                pytorch_outs,
+                rtol=1e-5,
+                atol=1e-6,
+                check_shape=True,
+                check_dtype=False,
+                acceptable_error_percentage=None,
+            )
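
Both tests feed outputs that disagree in exactly one of four elements (4.0 vs. 1.0), an error percentage of 0.25: the first passes because 0.25 is within the 0.3 budget and only a warning is emitted, while the second re-raises the AssertionError because no budget is set.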


@@ -122,27 +122,64 @@ def _compare_ort_pytorch_outputs(
     atol: float,
     check_shape: bool,
     check_dtype: bool,
+    acceptable_error_percentage: Optional[float],
 ):
     """
     Compare ONNX Runtime and PyTorch outputs.
 
     Args:
         ort_outs: outputs from ONNX Runtime.
         pt_outs: outputs from PyTorch.
         rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
         atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
+        acceptable_error_percentage (float, optional): acceptable percentage of element mismatches in comparison.
+            It should be a float of value between 0.0 and 1.0.
 
     Raises:
         AssertionError: if outputs from ONNX model and PyTorch model are not
             equal up to specified precision.
+        ValueError: if arguments provided are invalid.
     """
     pt_outs, _ = torch.jit._flatten(pt_outs)
     pt_outs = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
 
     assert len(ort_outs) == len(
         pt_outs
     ), f"Number of outputs differ ONNX runtime: ({len(ort_outs)}) PyTorch: ({len(pt_outs)})"
+    if acceptable_error_percentage and (
+        acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
+    ):
+        raise ValueError(
+            "If set, acceptable_error_percentage should be between 0.0 and 1.0"
+        )
+
     for ort_out, pt_out in zip(ort_outs, pt_outs):
-        # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
-        if not check_shape:
-            # Allow different but broadcastable output shapes.
-            ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
-        torch.testing.assert_close(
-            ort_out,
-            pt_out,
-            rtol=rtol,
-            atol=atol,
-            check_dtype=check_dtype,
-            equal_nan=True,
-        )
+        try:
+            # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
+            if not check_shape:
+                # Allow different but broadcastable output shapes.
+                ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
+            torch.testing.assert_close(
+                ort_out,
+                pt_out,
+                rtol=rtol,
+                atol=atol,
+                check_dtype=check_dtype,
+                equal_nan=True,
+            )
+        except AssertionError as e:
+            if acceptable_error_percentage:
+                error_percentage = 1 - np.sum(
+                    np.isclose(ort_out, pt_out, rtol=rtol, atol=atol)
+                ) / np.prod(ort_out.shape)
+                if error_percentage <= acceptable_error_percentage:
+                    warnings.warn(
+                        f"Suppressed AssertionError:\n{e}.\n"
+                        f"Error percentage {error_percentage} "
+                        f"within acceptable range {acceptable_error_percentage}."
+                    )
+                    continue
+            raise
 
 
 def _prepare_input_for_pytorch(args, kwargs):
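
The suppression logic above hinges on a single formula: the fraction of elements falling outside rtol/atol. A standalone sketch of that computation, reusing the 2x2 data from the new unit tests (numpy only, nothing assumed beyond what the diff shows):

    import numpy as np

    ort_out = np.array([[1.0, 2.0], [3.0, 4.0]])
    pt_out = np.array([[1.0, 2.0], [3.0, 1.0]])

    # Count elements within rtol/atol; 3 of 4 match here, so the
    # error percentage is 1 - 3/4 = 0.25.
    matches = np.isclose(ort_out, pt_out, rtol=1e-5, atol=1e-6)
    error_percentage = 1 - np.sum(matches) / np.prod(ort_out.shape)

    assert error_percentage == 0.25
    assert error_percentage <= 0.3  # tolerated when the budget is 0.3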
@@ -243,6 +280,7 @@ def _compare_ort_pytorch_model(
     atol,
     check_shape,
     check_dtype,
+    acceptable_error_percentage: Optional[float],
 ):
     """Compare outputs from ONNX model runs with outputs from PyTorch model runs.
@@ -265,7 +303,13 @@ def _compare_ort_pytorch_model(
         ort_outs = _run_ort(ort_session, ort_inputs)
 
         _compare_ort_pytorch_outputs(
-            ort_outs, pt_outs, rtol, atol, check_shape, check_dtype
+            ort_outs,
+            pt_outs,
+            rtol,
+            atol,
+            check_shape,
+            check_dtype,
+            acceptable_error_percentage,
         )
 
     compare_ort_pytorch_model_with_input(input_args, input_kwargs)
@@ -548,6 +592,7 @@ def verify(
     ort_providers: Sequence[str] = _ORT_PROVIDERS,
     rtol: float = 0.001,
     atol: float = 1e-7,
+    acceptable_error_percentage: Optional[float] = None,
     **_,
 ):
     """Verify model export to ONNX with ONNX Runtime.
@@ -586,10 +631,13 @@ def verify(
         ort_providers (sequence, optional): ONNX Runtime providers to use.
         rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
         atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.
+        acceptable_error_percentage (float, optional): acceptable percentage of element mismatches in comparison.
+            It should be a float of value between 0.0 and 1.0.
 
     Raises:
         AssertionError: if outputs from ONNX model and PyTorch model are not
             equal up to specified precision.
+        ValueError: if arguments provided are invalid.
     """
     if training == torch.onnx.TrainingMode.TRAINING:
         model.train()
@@ -634,4 +682,5 @@ def verify(
         atol,
         check_shape,
         check_dtype,
+        acceptable_error_percentage,
     )