[dynamo] emit only 1 graph break message on unrecoverable data-dependent assert fail (#150471)

Addresses https://fb.workplace.com/groups/1075192433118967/permalink/1625299684774903/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150471
Approved by: https://github.com/jansel
This commit is contained in:
William Wen 2025-04-01 15:43:29 -07:00 committed by PyTorch MergeBot
parent a8f6b40e36
commit 85df0dc246
2 changed files with 68 additions and 22 deletions

View File

@ -36,6 +36,14 @@ make sure that there is a test for it.
"""
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
class GraphBreakMessagesTest(LoggingTestCase):
def test_dynamic_shape_operator(self):
def fn():
@ -569,19 +577,12 @@ from user code:
)
def test_generic_ctx_mgr_graph_break(self):
class CtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def fn():
with CtxMgr():
with CtxMgr():
with GenericCtxMgr():
with GenericCtxMgr():
pass
with CtxMgr():
with CtxMgr():
with GenericCtxMgr():
with GenericCtxMgr():
pass
torch._dynamo.graph_break()
@ -596,7 +597,7 @@ Graph break under GenericContextWrappingVariable
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(CtxMgr), GenericContextWrappingVariable(CtxMgr)]
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)]
from user code:
@ -834,6 +835,43 @@ User code traceback:
""",
)
@make_logging_test(graph_breaks=True)
def test_assert_failure_in_generic_ctx_mgr(self, records):
def fn(x):
with GenericCtxMgr():
assert x is None
with self.assertRaises(AssertionError):
torch.compile(fn, backend="eager")(torch.randn(3))
# only 1 graph break message
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: skip: from user code at:
File "test_error_messages.py", line N, in fn
assert x is None
""",
)
self.assertExpectedInline(
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
"""\
Data-dependent assertion failed (cannot compile partial graph)
Explanation: Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch._assert()` to raise a hard AssertionError when the check fails. This error will propagate back the user code that called the compiled function (i.e. Dynamo wil not trace any exception handling).
Hint: Remove the assert statement.
Hint: Move the assert statement outside of any context managers in order to graph break with partial graph compilation (if fullgraph=False).
Developer debug context: value: ConstantVariable(bool: False)
from user code:
File "test_error_messages.py", line N, in fn
assert x is None""",
)
def test_no_internal_compiler_stacktrace(self):
def fn():
gn()

View File

@ -590,15 +590,7 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
hints=_hints,
),
)
if not self.should_compile_partial_graph():
unimplemented_v2(
gb_type="Should not compile partial graph (data-dependent branching)",
context="",
explanation="Dynamo has determined when encountering data-dependent "
"branching (e.g. `if my_tensor.item() > 0:`) that it should not "
"compile the partial graph.",
hints=[],
)
assert self.should_compile_partial_graph()
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = (
@ -642,8 +634,24 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
if value.is_python_constant():
if bool(value.as_python_constant()):
return self.jump(inst)
else:
elif self.should_compile_partial_graph():
jump_graph_break(self, inst, value)
else:
unimplemented_v2(
gb_type="Data-dependent assertion failed (cannot compile partial graph)",
context=f"value: {value}",
explanation="Dynamo has determined when encountering a data-dependent assert failure "
"that it should not compile the partial graph.",
hints=[
*graph_break_hints.FUNDAMENTAL,
"Use `torch._assert()` to raise a hard AssertionError when the check fails. "
"This error will propagate back the user code "
"that called the compiled function (i.e. Dynamo wil not trace any exception handling).",
"Remove the assert statement.",
"Move the assert statement outside of any context managers in order to graph break with "
"partial graph compilation (if fullgraph=False).",
],
)
# TODO maybe should respect DtoH sync intention of users later??
# Manually insert torch._assert_async instead of python assert and jump over