mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] emit only 1 graph break message on unrecoverable data-dependent assert fail (#150471)
Addresses https://fb.workplace.com/groups/1075192433118967/permalink/1625299684774903/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/150471 Approved by: https://github.com/jansel
This commit is contained in:
parent
a8f6b40e36
commit
85df0dc246
|
|
@ -36,6 +36,14 @@ make sure that there is a test for it.
|
|||
"""
|
||||
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
|
||||
class GraphBreakMessagesTest(LoggingTestCase):
|
||||
def test_dynamic_shape_operator(self):
|
||||
def fn():
|
||||
|
|
@ -569,19 +577,12 @@ from user code:
|
|||
)
|
||||
|
||||
def test_generic_ctx_mgr_graph_break(self):
|
||||
class CtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def fn():
|
||||
with CtxMgr():
|
||||
with CtxMgr():
|
||||
with GenericCtxMgr():
|
||||
with GenericCtxMgr():
|
||||
pass
|
||||
with CtxMgr():
|
||||
with CtxMgr():
|
||||
with GenericCtxMgr():
|
||||
with GenericCtxMgr():
|
||||
pass
|
||||
torch._dynamo.graph_break()
|
||||
|
||||
|
|
@ -596,7 +597,7 @@ Graph break under GenericContextWrappingVariable
|
|||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(CtxMgr), GenericContextWrappingVariable(CtxMgr)]
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
|
||||
from user code:
|
||||
|
|
@ -834,6 +835,43 @@ User code traceback:
|
|||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_assert_failure_in_generic_ctx_mgr(self, records):
|
||||
def fn(x):
|
||||
with GenericCtxMgr():
|
||||
assert x is None
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
|
||||
# only 1 graph break message
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: skip: from user code at:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
assert x is None
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Data-dependent assertion failed (cannot compile partial graph)
|
||||
Explanation: Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph.
|
||||
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
|
||||
Hint: Use `torch._assert()` to raise a hard AssertionError when the check fails. This error will propagate back the user code that called the compiled function (i.e. Dynamo wil not trace any exception handling).
|
||||
Hint: Remove the assert statement.
|
||||
Hint: Move the assert statement outside of any context managers in order to graph break with partial graph compilation (if fullgraph=False).
|
||||
|
||||
Developer debug context: value: ConstantVariable(bool: False)
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
assert x is None""",
|
||||
)
|
||||
|
||||
def test_no_internal_compiler_stacktrace(self):
|
||||
def fn():
|
||||
gn()
|
||||
|
|
|
|||
|
|
@ -590,15 +590,7 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
|||
hints=_hints,
|
||||
),
|
||||
)
|
||||
if not self.should_compile_partial_graph():
|
||||
unimplemented_v2(
|
||||
gb_type="Should not compile partial graph (data-dependent branching)",
|
||||
context="",
|
||||
explanation="Dynamo has determined when encountering data-dependent "
|
||||
"branching (e.g. `if my_tensor.item() > 0:`) that it should not "
|
||||
"compile the partial graph.",
|
||||
hints=[],
|
||||
)
|
||||
assert self.should_compile_partial_graph()
|
||||
# compile a partial subgraph prefix then jump into user code
|
||||
if self.maybe_has_backedge():
|
||||
msg = (
|
||||
|
|
@ -642,8 +634,24 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
|||
if value.is_python_constant():
|
||||
if bool(value.as_python_constant()):
|
||||
return self.jump(inst)
|
||||
else:
|
||||
elif self.should_compile_partial_graph():
|
||||
jump_graph_break(self, inst, value)
|
||||
else:
|
||||
unimplemented_v2(
|
||||
gb_type="Data-dependent assertion failed (cannot compile partial graph)",
|
||||
context=f"value: {value}",
|
||||
explanation="Dynamo has determined when encountering a data-dependent assert failure "
|
||||
"that it should not compile the partial graph.",
|
||||
hints=[
|
||||
*graph_break_hints.FUNDAMENTAL,
|
||||
"Use `torch._assert()` to raise a hard AssertionError when the check fails. "
|
||||
"This error will propagate back the user code "
|
||||
"that called the compiled function (i.e. Dynamo wil not trace any exception handling).",
|
||||
"Remove the assert statement.",
|
||||
"Move the assert statement outside of any context managers in order to graph break with "
|
||||
"partial graph compilation (if fullgraph=False).",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO maybe should respect DtoH sync intention of users later??
|
||||
# Manually insert torch._assert_async instead of python assert and jump over
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user