[opcheck] Improve error reporting; allow atol/rtol overrides (#146488)

This PR improves opcheck to:
1. directly use torch.testing.assert_close (without a msg override).
   This allows it to print the absolute and relative differences and the
   number of mismatched elements.
2. accept atol/rtol tolerance overrides (for cases where someone just wants to
   use opcheck in their own tests); see the usage sketch below.
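
A minimal usage sketch of the new keywords (the mylib::scaled_tanh op below is
illustrative and not part of this PR; the tolerance values are arbitrary):

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::scaled_tanh", mutates_args=())
    def scaled_tanh(x: Tensor, scale: float) -> Tensor:
        return torch.tanh(x) * scale

    @scaled_tanh.register_fake
    def _(x: Tensor, scale: float) -> Tensor:
        return torch.empty_like(x)

    x = torch.randn(3)  # no requires_grad, so only forward correctness is checked

    # Default per-dtype tolerances (chosen by torch.testing.assert_close):
    torch.library.opcheck(scaled_tanh, (x, 2.0), {})

    # New in this PR: explicit tolerances, forwarded to assert_close.
    # atol and rtol must be passed together.
    torch.library.opcheck(scaled_tanh, (x, 2.0), {}, atol=1e-4, rtol=1e-4)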

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146488
Approved by: https://github.com/williamwen42
Author: rzou
Date: 2025-02-05 08:16:25 -08:00
Committed by: PyTorch MergeBot
Parent: 1f6b566d74
Commit: 98b5d455fd
4 changed files with 71 additions and 12 deletions

View File

@ -363,10 +363,20 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
x = torch.tensor(3.14159 / 3, requires_grad=True)
with self.assertRaisesRegex(
- optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
+ optests.OpCheckError, "eager-mode PyTorch vs AOTDispatcher"
):
torch.library.opcheck(op, (x,), {})
+ # Test that we can actually see the absolute difference numbers
+ try:
+ torch.library.opcheck(op, (x,), {})
+ except optests.OpCheckError as err:
+ orig = err.__context__.__context__
+ self.assertIn("Absolute difference:", str(orig))
+ # Test atol/rtol overrides
+ torch.library.opcheck(op, (x,), {}, atol=3, rtol=0.01)
@ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
def test_opcheck_opinfo(self, device, dtype, op):
for sample_input in op.sample_inputs(
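
For context on the new assertion above: with no msg override, the failure text
produced by torch.testing.assert_close (absolute/relative differences, plus a
mismatched-element count for multi-element tensors) is what surfaces in the
exception chain the test walks via err.__context__.__context__. A standalone
sketch of that report (exact wording may differ across PyTorch versions):

    import torch

    try:
        # Two one-element tensors that differ by 0.05, under a tight tolerance.
        torch.testing.assert_close(
            torch.tensor(1.00), torch.tensor(1.05), atol=1e-3, rtol=0.0
        )
    except AssertionError as e:
        # The message contains lines such as "Absolute difference: ..." and
        # "Relative difference: ... (up to ... allowed)".
        print(e)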

View File

@ -1449,6 +1449,8 @@ def opcheck(
*,
test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
raise_exception: bool = True,
+ atol=None,
+ rtol=None,
) -> dict[str, str]:
"""Given an operator and some sample arguments, tests if the operator is
registered correctly.
@ -1507,6 +1509,14 @@ def opcheck(
raise_exception: If we should raise an exception on the first
error. If False, we will return a dict with information
on if each test passed or not.
+ rtol (Optional[float]): Relative tolerance for floating point comparisons.
+ If specified ``atol`` must also be specified.
+ If omitted, default values based on the ``dtype`` are selected
+ (see the table in :func:`torch.testing.assert_close`).
+ atol (Optional[float]): Absolute tolerance for floating point comparisons.
+ If specified ``rtol`` must also be specified.
+ If omitted, default values based on the ``dtype`` are selected
+ (see the table in :func:`torch.testing.assert_close`).
.. warning::
@ -1552,5 +1562,11 @@ def opcheck(
import torch.testing._internal.optests as optests
return optests.opcheck(
- op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+ op,
+ args,
+ kwargs,
+ test_utils=test_utils,
+ raise_exception=raise_exception,
+ rtol=rtol,
+ atol=atol,
)

View File

@ -35,7 +35,7 @@ def aot_autograd_check(
kwargs,
dynamic,
assert_raises_regex_fn=assert_raises_regex,
- assert_equals_fn=torch.testing._comparison.assert_close,
+ assert_equals_fn=torch.testing.assert_close,
check_gradients=True,
try_check_data_specialization=False,
skip_correctness_check=False):
@ -82,9 +82,9 @@ def aot_autograd_check(
outputs_msg = (
"Outputs of the operator are different in eager-mode PyTorch vs "
"AOTAutograd. This means the operator will have incorrect output "
"AOTDispatcher tracing. This means the operator will have incorrect output "
"underneath torch.compile. This could be because the operator's "
"implementation not traceable or that there is a bug in AOTAutograd."
"implementation not traceable."
)
@ -128,16 +128,21 @@ def _test_aot_autograd_forwards_backwards_helper(
msg = (
"Gradients of the operator are different in eager-mode PyTorch vs "
"AOTAutograd. This means the operator will have incorrect gradients "
"AOTDispatcher. This means the operator will have incorrect gradients "
"underneath torch.compile. This could be because the operator's "
"backward is incorrectly registered or not traceable or that there "
"is a bug in AOTAutograd."
"backward is incorrectly registered or not traceable."
)
compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args)
if not skip_correctness_check:
- assert_equals_fn(compiled_out, orig_out, msg=outputs_msg)
- assert_equals_fn(compiled_grad, orig_grad, msg=msg)
+ try:
+ assert_equals_fn(compiled_out, orig_out)
+ except Exception as e:
+ raise type(e)(outputs_msg) from e
+ try:
+ assert_equals_fn(compiled_grad, orig_grad)
+ except Exception as e:
+ raise type(e)(msg) from e
check(args, ignore_failure=False)
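
The try/except rewrite above is what makes the numeric report reachable from
the test earlier in this PR: re-raising with raise type(e)(outputs_msg) from e
keeps the detailed AssertionError attached as __cause__/__context__ rather than
replacing it. A minimal standalone sketch of the pattern (the function and
message here are illustrative):

    import torch

    def check_outputs(actual, expected):
        summary = "Outputs differ between eager-mode PyTorch and AOTDispatcher"
        try:
            torch.testing.assert_close(actual, expected)
        except Exception as e:
            # Same exception type, short summary message; the original error
            # (with the absolute/relative differences) rides along as __cause__.
            raise type(e)(summary) from e

    try:
        check_outputs(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.5]))
    except AssertionError as err:
        print(err)            # the summary
        print(err.__cause__)  # assert_close's detailed mismatch report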

View File

@ -50,6 +50,8 @@ def safe_schema_check(
kwargs: dict[str, Any],
*,
copy_inputs: bool = True,
+ rtol: Optional[float] = None,
+ atol: Optional[float] = None,
) -> Any:
if copy_inputs:
args, kwargs = deepcopy_tensors((args, kwargs))
@ -66,6 +68,8 @@ def safe_autograd_registration_check(
kwargs: dict[str, Any],
*,
copy_inputs: bool = True,
+ rtol: Optional[float] = None,
+ atol: Optional[float] = None,
) -> None:
if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
return
@ -85,6 +89,8 @@ def safe_fake_check(
kwargs: dict[str, Any],
*,
copy_inputs: bool = True,
+ rtol: Optional[float] = None,
+ atol: Optional[float] = None,
) -> None:
if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
return None
@ -100,6 +106,8 @@ def safe_aot_autograd_check(
dynamic: bool,
*,
copy_inputs: bool = True,
+ rtol: Optional[float] = None,
+ atol: Optional[float] = None,
) -> Any:
# NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy
# inputs.
@ -112,7 +120,20 @@ def safe_aot_autograd_check(
# aot_autograd_check runs func(*args, **kwargs) multiple times
# and assumes `func` does not modify its inputs.
- return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto")
+ if rtol and atol:
+ assert_equals_fn = functools.partial(
+ torch.testing.assert_close, rtol=rtol, atol=atol
+ )
+ else:
+ assert_equals_fn = torch.testing.assert_close
+ return aot_autograd_check(
+ func,
+ args,
+ kwargs,
+ dynamic,
+ check_gradients="auto",
+ assert_equals_fn=assert_equals_fn,
+ )
def deepcopy_tensors(inputs: Any) -> Any:
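
The tolerance plumbing above pre-binds rtol/atol onto
torch.testing.assert_close with functools.partial before handing the result to
aot_autograd_check as assert_equals_fn. The same idea in isolation
(run_equality_check is a hypothetical stand-in for that call):

    import functools

    import torch

    def run_equality_check(assert_equals_fn, actual, expected):
        assert_equals_fn(actual, expected)

    rtol, atol = 1e-2, 1e-2
    if rtol is not None and atol is not None:
        # Comparing against None (rather than relying on truthiness) also
        # honors an explicit tolerance of zero.
        assert_equals_fn = functools.partial(
            torch.testing.assert_close, rtol=rtol, atol=atol
        )
    else:
        assert_equals_fn = torch.testing.assert_close

    # Passes under the looser tolerances; the float32 defaults would reject
    # a 5e-3 difference.
    run_equality_check(assert_equals_fn, torch.tensor([1.000]), torch.tensor([1.005]))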
@ -624,9 +645,16 @@ def opcheck(
*,
test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
raise_exception: bool = True,
+ rtol: Optional[float] = None,
+ atol: Optional[float] = None,
) -> dict[str, str]:
"""See torch.library.opcheck for docstring"""
+ if (rtol is None) ^ (atol is None):
+ raise ValueError(
+ "opcheck(op, ...): if you specify one of rtol/atol, you must specify both"
+ )
if kwargs is None:
kwargs = {}
if isinstance(op, CustomOpDef):
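
The XOR guard above is a compact both-or-neither check: (rtol is None) ^
(atol is None) is True exactly when one tolerance is supplied without the
other. Illustrated on its own:

    def validate_tolerances(rtol, atol):
        # Mirrors the guard above: reject "exactly one specified".
        if (rtol is None) ^ (atol is None):
            raise ValueError("specify both rtol and atol, or neither")

    validate_tolerances(None, None)    # OK: fall back to per-dtype defaults
    validate_tolerances(1e-3, 1e-3)    # OK: both overridden
    # validate_tolerances(1e-3, None)  # would raise ValueError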
@ -654,7 +682,7 @@ def opcheck(
for test_util in test_utils:
tester = ALL_TEST_UTILS[test_util]
try:
- tester(op, args, kwargs)
+ tester(op, args, kwargs, rtol=rtol, atol=atol)
results_dict[test_util] = "SUCCESS"
except Exception as ex:
if raise_exception: