[dynamo] fix set_fullgraph for nested calls (#154782)

- Make the fullgraph argument of set_fullgraph a positional argument
- Fix behavior on nested calls by updating `tracer.error_on_graph_break` in more places. In particular, a tracer's error_on_graph_break is set to the inlined tracer's error_on_graph_break upon the latter's exit. We also track error_on_graph_break in the speculation log now, since if we encounter a nested graph break, we will restart analysis and we need to somehow remember the error_on_graph_break setting after attempting to run the nested function (but we don't actually trace into it in the restart analysis).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154782
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289
This commit is contained in:
William Wen 2025-06-18 17:06:54 -07:00 committed by PyTorch MergeBot
parent 2c372a0502
commit 537b0877a8
4 changed files with 200 additions and 35 deletions

View File

@ -1700,7 +1700,7 @@ If the above doesn't work, please subtmit an issue to GitHub.
@torch.compile(backend=cnts, fullgraph=True)
def f1(x):
x = x + 1
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.set_fullgraph(False):
torch._dynamo.graph_break()
return x + 2
@ -1711,7 +1711,7 @@ If the above doesn't work, please subtmit an issue to GitHub.
@torch.compile(backend=cnts)
def f2(x):
x = x + 1
with torch._dynamo.set_fullgraph(fullgraph=True):
with torch._dynamo.set_fullgraph(True):
torch._dynamo.graph_break()
return x + 2
@ -1721,7 +1721,7 @@ If the above doesn't work, please subtmit an issue to GitHub.
@torch.compile(backend=cnts, fullgraph=True)
def f3(x):
x = x + 1
with torch._dynamo.set_fullgraph(fullgraph=False):
with torch._dynamo.set_fullgraph(False):
torch._dynamo.graph_break()
x = x + 2
torch._dynamo.graph_break()
@ -1739,18 +1739,170 @@ If the above doesn't work, please subtmit an issue to GitHub.
@torch.compile(backend=cnts, fullgraph=True)
def f4(x):
x = x + 1
with torch._dynamo.set_fullgraph(fullgraph=False):
# cause a skipped frame
try:
torch._dynamo.graph_break()
except Exception:
pass
with torch._dynamo.set_fullgraph(False):
torch._dynamo.skip_frame()
return inner_f4(x)
cnts.clear()
self.assertEqual(f4(inp), inp + 7)
self.assertEqual(cnts.frame_count, 2)
def test_set_fullgraph_nested(self):
# set_fullgraph in a nested frame
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
def inner_f5(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
def f5(x):
x = x + 1
return inner_f5(x)
inp = torch.ones(3)
self.assertEqual(f5(inp), inp + 7)
self.assertEqual(cnts.frame_count, 4)
def inner_f6(x):
x = x + 2
with torch._dynamo.set_fullgraph(False):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
def f6(x):
x = x + 1
return inner_f6(x)
cnts.clear()
self.assertEqual(f6(inp), inp + 7)
self.assertEqual(cnts.frame_count, 3)
def inner_f7(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=False)
def f7(x):
x = x + 1
return inner_f7(x)
with self.assertRaises(Unsupported):
f7(inp)
def test_set_fullgraph_nested_with_skip(self):
# set_fullgraph in a nested frame with a skipped frame in between
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
def inner2_f8(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
def inner1_f8(x):
with torch._dynamo.set_fullgraph(False):
torch._dynamo.skip_frame()
return inner2_f8(x)
@torch.compile(backend=cnts, fullgraph=True)
def f8(x):
x = x + 1
return inner1_f8(x)
inp = torch.ones(3)
self.assertEqual(f8(inp), inp + 7)
self.assertEqual(cnts.frame_count, 4)
def inner2_f9(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
torch._dynamo.graph_break()
return x + 4
@torch._dynamo.disable(recursive=False)
def inner1_f9(x):
return inner2_f9(x)
@torch.compile(backend=cnts, fullgraph=False)
def f9(x):
x = x + 1
return inner1_f9(x)
with self.assertRaises(Unsupported):
f9(inp)
# test export with set_fullgraph(False) still errors
def test_set_fullgraph_export(self):
@torch._dynamo.set_fullgraph(False)
def inner(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
def f(x):
x = x + 1
return inner(x)
with self.assertRaises(Unsupported):
torch._dynamo.export(f)(torch.ones(3))
def test_set_fullgraph_nested_deep(self):
cnts = torch._dynamo.testing.CompileCounter()
def inner1_f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def inner2_f1(x):
return inner1_f1(x)
def inner3_f1(x):
with torch._dynamo.set_fullgraph(False):
return inner2_f1(x)
def inner4_f1(x):
return inner3_f1(x)
@torch.compile(backend=cnts, fullgraph=True)
def f1(x):
x = x + 4
return inner4_f1(x)
inp = torch.ones(3)
self.assertEqual(f1(inp), inp + 7)
self.assertEqual(cnts.frame_count, 4)
def inner1_f2(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def inner2_f2(x):
return inner1_f2(x)
def inner3_f2(x):
with torch._dynamo.set_fullgraph(True):
return inner2_f2(x)
def inner4_f2(x):
return inner3_f2(x)
@torch.compile(backend=cnts, fullgraph=False)
def f2(x):
x = x + 4
return inner4_f2(x)
with self.assertRaises(Unsupported):
f2(inp)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1315,6 +1315,11 @@ class ConvertFrame:
)
assert error_on_graph_break is not None
if self._inner_convert._box.error_on_graph_break:
# NOTE we _might_ have to wrap the current in a custom exception
# in order to correctly bubble up to the top-level compile wrapper in
# eval_frame.py. But re-raising seems to work for now because exceptions from tracing
# a nested call that results in a top-level frame compile will be handled by the caller
# as an observed exception - we don't expect that exception to be suppressed.
raise
# These two exception types are "soft" failure, in the sense that

View File

@ -872,6 +872,14 @@ def patch_dynamo_config(
return DynamoConfigPatchProxy(config_patch)
@overload
def dont_skip_tracing(fn: None = None) -> DynamoConfigPatchProxy: ...
@overload
def dont_skip_tracing(fn: Callable[_P, _R]) -> Callable[_P, _R]: ...
def dont_skip_tracing(fn=None):
"""
Context manager/decorator to trace into functions intentionally marked by developers to be skipped
@ -885,23 +893,11 @@ def dont_skip_tracing(fn=None):
return ctx
@overload
def set_fullgraph(
fn: None = None, fullgraph: bool = True
) -> DynamoConfigPatchProxy: ...
@overload
def set_fullgraph(fn: Callable[_P, _R], fullgraph: bool = True) -> Callable[_P, _R]: ...
def set_fullgraph(
fn: Optional[Callable[_P, _R]] = None, fullgraph: bool = True
) -> Union[Callable[_P, _R], DynamoConfigPatchProxy]:
def set_fullgraph(fullgraph: bool) -> DynamoConfigPatchProxy:
"""
Context manager/decorator to toggle fullgraph setting.
More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
"""
ctx = patch_dynamo_config(error_on_graph_break=fullgraph)
if fn:
return ctx(fn)
return ctx
return patch_dynamo_config(error_on_graph_break=fullgraph)

View File

@ -208,20 +208,29 @@ class SpeculationEntry:
lineno: int
instruction_pointer: int
inst: Instruction # for debugging only
failed: bool = False
_failed: bool = False
error_on_graph_break: Optional[bool] = None
reason: Optional[GraphCompileReason] = None
def fail_and_restart_analysis(self):
def fail_and_restart_analysis(self, error_on_graph_break: bool):
"""
Start tracing of the current frame over again, and don't take this branch.
"""
self.failed = True
self._failed = True
self.error_on_graph_break = error_on_graph_break
if self.reason is not None:
restart_reason = self.reason.reason
else:
restart_reason = "Unknown fail_and_restart_analysis"
raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
def failed(self, tx):
if self._failed:
assert self.error_on_graph_break is not None
tx.error_on_graph_break = self.error_on_graph_break
return True
return False
@dataclasses.dataclass
class SpeculationLog:
@ -827,7 +836,7 @@ def break_graph_if_unsupported(*, push):
@functools.wraps(inner_fn)
def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
speculation = self.speculate()
if speculation.failed:
if speculation.failed(self):
assert speculation.reason is not None
return handle_graph_break(self, inst, speculation.reason)
try:
@ -872,7 +881,7 @@ def break_graph_if_unsupported(*, push):
excp.remove_from_stats()
excp.add_to_stats("graph_break")
speculation.reason = GraphCompileReason(excp.msg, excp.real_stack)
speculation.fail_and_restart_analysis()
speculation.fail_and_restart_analysis(self.error_on_graph_break)
def handle_graph_break(
self: "InstructionTranslatorBase",
@ -1255,7 +1264,7 @@ class InstructionTranslatorBase(
and self.is_non_empty_graph()
):
self.current_speculation = self.speculate()
if self.current_speculation.failed:
if self.current_speculation.failed(self):
return self.step_graph_break(inst)
if self.is_trace_bytecode_log_enabled:
@ -1281,7 +1290,7 @@ class InstructionTranslatorBase(
raise
log.debug("step triggered compile", exc_info=True)
self.current_speculation.fail_and_restart_analysis()
self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break)
if sys.version_info >= (3, 11):
@ -2297,7 +2306,7 @@ class InstructionTranslatorBase(
def STORE_ATTR(self, inst):
speculation = self.speculate()
if speculation.failed:
if speculation.failed(self):
return self.store_attr_graph_break(inst)
val, obj = self.popn(2)
@ -2321,7 +2330,7 @@ class InstructionTranslatorBase(
log.debug("STORE_ATTR triggered compile", exc_info=True)
e.remove_from_stats()
e.add_to_stats("graph_break")
speculation.fail_and_restart_analysis()
speculation.fail_and_restart_analysis(self.error_on_graph_break)
def store_attr_graph_break(self, inst):
log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break")
@ -3922,6 +3931,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
except Exception:
log.debug("FAILED INLINING %s", code)
raise
finally:
parent.error_on_graph_break = self.error_on_graph_break
assert self.symbolic_result is not None
if self.f_globals is parent.f_globals: