Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[opcheck] Improve error reporting; allow atol/rtol overrides (#146488)
This PR improves opcheck to:
1. Directly use torch.testing.assert_close (without a msg override). This allows it to print the absolute and relative differences and the number of mismatched elements.
2. Take in an atol/rtol tolerance (for when someone just wants to use opcheck in their own testing).

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146488
Approved by: https://github.com/williamwen42
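A minimal usage sketch of the new keyword arguments (the custom op defined here is hypothetical and only for illustration; the call shape follows the test added in the first hunk below):

    import torch

    # Hypothetical custom op, defined only to illustrate the new opcheck arguments.
    @torch.library.custom_op("mylib::scaled_sin", mutates_args=())
    def scaled_sin(x: torch.Tensor) -> torch.Tensor:
        return 2 * torch.sin(x)

    @scaled_sin.register_fake
    def _(x):
        return torch.empty_like(x)

    x = torch.tensor(3.14159 / 3)
    # New in this PR: optional atol/rtol are forwarded to torch.testing.assert_close.
    # Both must be given together; passing only one raises a ValueError.
    torch.library.opcheck(scaled_sin, (x,), atol=1e-4, rtol=1e-4)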
This commit is contained in:
parent 1f6b566d74
commit 98b5d455fd
@@ -363,10 +363,20 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
         x = torch.tensor(3.14159 / 3, requires_grad=True)
         with self.assertRaisesRegex(
-            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
+            optests.OpCheckError, "eager-mode PyTorch vs AOTDispatcher"
         ):
             torch.library.opcheck(op, (x,), {})
 
+        # Test that we can actually see the absolute difference numbers
+        try:
+            torch.library.opcheck(op, (x,), {})
+        except optests.OpCheckError as err:
+            orig = err.__context__.__context__
+            self.assertIn("Absolute difference:", str(orig))
+
+        # Test atol/rtol overrides
+        torch.library.opcheck(op, (x,), {}, atol=3, rtol=0.01)
+
     @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
     def test_opcheck_opinfo(self, device, dtype, op):
         for sample_input in op.sample_inputs(
@@ -1449,6 +1449,8 @@ def opcheck(
     *,
     test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
     raise_exception: bool = True,
+    atol=None,
+    rtol=None,
 ) -> dict[str, str]:
     """Given an operator and some sample arguments, tests if the operator is
     registered correctly.
@@ -1507,6 +1509,14 @@ def opcheck(
         raise_exception: If we should raise an exception on the first
             error. If False, we will return a dict with information
             on if each test passed or not.
+        rtol (Optional[float]): Relative tolerance for floating point comparisons.
+            If specified ``atol`` must also be specified.
+            If omitted, default values based on the ``dtype`` are selected
+            (see the table in :func:`torch.testing.assert_close`).
+        atol (Optional[float]): Absolute tolerance for floating point comparisons.
+            If specified ``rtol`` must also be specified.
+            If omitted, default values based on the ``dtype`` are selected
+            (see the table in :func:`torch.testing.assert_close`).
 
     .. warning::
 
@@ -1552,5 +1562,11 @@ def opcheck(
     import torch.testing._internal.optests as optests
 
     return optests.opcheck(
-        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+        op,
+        args,
+        kwargs,
+        test_utils=test_utils,
+        raise_exception=raise_exception,
+        rtol=rtol,
+        atol=atol,
     )
@@ -35,7 +35,7 @@ def aot_autograd_check(
         kwargs,
         dynamic,
         assert_raises_regex_fn=assert_raises_regex,
-        assert_equals_fn=torch.testing._comparison.assert_close,
+        assert_equals_fn=torch.testing.assert_close,
         check_gradients=True,
         try_check_data_specialization=False,
         skip_correctness_check=False):
@@ -82,9 +82,9 @@ def aot_autograd_check(
     outputs_msg = (
         "Outputs of the operator are different in eager-mode PyTorch vs "
-        "AOTAutograd. This means the operator will have incorrect output "
+        "AOTDispatcher tracing. This means the operator will have incorrect output "
         "underneath torch.compile. This could be because the operator's "
-        "implementation not traceable or that there is a bug in AOTAutograd."
+        "implementation not traceable."
     )
 
 
@@ -128,16 +128,21 @@ def _test_aot_autograd_forwards_backwards_helper(
         msg = (
             "Gradients of the operator are different in eager-mode PyTorch vs "
-            "AOTAutograd. This means the operator will have incorrect gradients "
+            "AOTDispatcher. This means the operator will have incorrect gradients "
             "underneath torch.compile. This could be because the operator's "
-            "backward is incorrectly registered or not traceable or that there "
-            "is a bug in AOTAutograd."
+            "backward is incorrectly registered or not traceable."
         )
 
         compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args)
         if not skip_correctness_check:
-            assert_equals_fn(compiled_out, orig_out, msg=outputs_msg)
-            assert_equals_fn(compiled_grad, orig_grad, msg=msg)
+            try:
+                assert_equals_fn(compiled_out, orig_out)
+            except Exception as e:
+                raise type(e)(outputs_msg) from e
+            try:
+                assert_equals_fn(compiled_grad, orig_grad)
+            except Exception as e:
+                raise type(e)(msg) from e
 
     check(args, ignore_failure=False)
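The try/except re-raise in the hunk above replaces assert_close's msg override, so the detailed comparison report (absolute/relative difference, number of mismatched elements) stays reachable on the exception chain; the new test in the first hunk reads it via err.__context__.__context__. A minimal standalone sketch of that pattern (not PyTorch code; the helper name is made up):

    import torch

    def check_close(actual, expected, short_msg):
        # Re-raise with a short, user-facing message while keeping the full
        # assert_close report available via the __cause__/__context__ chain.
        try:
            torch.testing.assert_close(actual, expected)
        except Exception as e:
            raise type(e)(short_msg) from e

    try:
        check_close(torch.tensor(1.0), torch.tensor(2.0), "Outputs differ under torch.compile")
    except AssertionError as err:
        print(err)            # the short message
        print(err.__cause__)  # full report, including "Absolute difference: ..."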
@@ -50,6 +50,8 @@ def safe_schema_check(
     kwargs: dict[str, Any],
     *,
     copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
 ) -> Any:
     if copy_inputs:
         args, kwargs = deepcopy_tensors((args, kwargs))
@@ -66,6 +68,8 @@ def safe_autograd_registration_check(
     kwargs: dict[str, Any],
     *,
     copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
 ) -> None:
     if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
         return
@@ -85,6 +89,8 @@ def safe_fake_check(
     kwargs: dict[str, Any],
     *,
     copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
 ) -> None:
     if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
         return None
@@ -100,6 +106,8 @@ def safe_aot_autograd_check(
     dynamic: bool,
     *,
     copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
 ) -> Any:
     # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy
     # inputs.
@@ -112,7 +120,20 @@ def safe_aot_autograd_check(
     # aot_autograd_check runs func(*args, **kwargs) multiple times
     # and assumes `func` does not modify its inputs.
-    return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto")
+    if rtol and atol:
+        assert_equals_fn = functools.partial(
+            torch.testing.assert_close, rtol=rtol, atol=atol
+        )
+    else:
+        assert_equals_fn = torch.testing.assert_close
+    return aot_autograd_check(
+        func,
+        args,
+        kwargs,
+        dynamic,
+        check_gradients="auto",
+        assert_equals_fn=assert_equals_fn,
+    )
 
 
 def deepcopy_tensors(inputs: Any) -> Any:
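For reference, the functools.partial binding in the hunk above can also be used standalone to get a comparison function with fixed tolerances (the tolerance values below are arbitrary, not from this PR):

    import functools

    import torch

    # A drop-in comparison function that behaves like torch.testing.assert_close
    # but always applies the bound tolerances.
    assert_close_loose = functools.partial(torch.testing.assert_close, rtol=1e-2, atol=1e-2)

    assert_close_loose(torch.tensor(1.000), torch.tensor(1.005))  # passes under the loose tolerances
    # torch.testing.assert_close(torch.tensor(1.000), torch.tensor(1.005))  # would fail with float32 defaults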
@@ -624,9 +645,16 @@ def opcheck(
     *,
     test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
     raise_exception: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
 ) -> dict[str, str]:
     """See torch.library.opcheck for docstring"""
 
+    if (rtol is None) ^ (atol is None):
+        raise ValueError(
+            "opcheck(op, ...): if you specify one of rtol/atol, you must specify both"
+        )
+
     if kwargs is None:
         kwargs = {}
     if isinstance(op, CustomOpDef):
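The XOR check above makes a lone tolerance fail fast, before any of the individual tests run. A quick sketch (torch.ops.aten.sin.default is just a stand-in for whatever operator you would normally pass):

    import torch

    try:
        # Only rtol is given; atol is missing, so the new validation trips.
        torch.library.opcheck(torch.ops.aten.sin.default, (torch.randn(3),), rtol=1e-3)
    except ValueError as e:
        print(e)  # opcheck(op, ...): if you specify one of rtol/atol, you must specify both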
@@ -654,7 +682,7 @@ def opcheck(
     for test_util in test_utils:
         tester = ALL_TEST_UTILS[test_util]
         try:
-            tester(op, args, kwargs)
+            tester(op, args, kwargs, rtol=rtol, atol=atol)
             results_dict[test_util] = "SUCCESS"
         except Exception as ex:
             if raise_exception: