mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] fix set_fullgraph for nested calls (#154782)
- Make the fullgraph argument of set_fullgraph a positional argument - Fix behavior on nested calls by updating `tracer.error_on_graph_break` in more places. In particular, a tracer's error_on_graph_break is set to the inlined tracer's error_on_graph_break upon the latter's exit. We also track error_on_graph_break in the speculation log now, since if we encounter a nested graph break, we will restart analysis and we need to somehow remember the error_on_graph_break setting after attempting to run the nested function (but we don't actually trace into it in the restart analysis). Pull Request resolved: https://github.com/pytorch/pytorch/pull/154782 Approved by: https://github.com/jansel ghstack dependencies: #154283, #154289
This commit is contained in:
parent
2c372a0502
commit
537b0877a8
|
|
@ -1700,7 +1700,7 @@ If the above doesn't work, please subtmit an issue to GitHub.
|
|||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f1(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
|
|
@ -1711,7 +1711,7 @@ If the above doesn't work, please subtmit an issue to GitHub.
|
|||
@torch.compile(backend=cnts)
|
||||
def f2(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(fullgraph=True):
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
|
|
@ -1721,7 +1721,7 @@ If the above doesn't work, please subtmit an issue to GitHub.
|
|||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f3(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
torch._dynamo.graph_break()
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
|
|
@ -1739,18 +1739,170 @@ If the above doesn't work, please subtmit an issue to GitHub.
|
|||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f4(x):
|
||||
x = x + 1
|
||||
with torch._dynamo.set_fullgraph(fullgraph=False):
|
||||
# cause a skipped frame
|
||||
try:
|
||||
torch._dynamo.graph_break()
|
||||
except Exception:
|
||||
pass
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
torch._dynamo.skip_frame()
|
||||
return inner_f4(x)
|
||||
|
||||
cnts.clear()
|
||||
self.assertEqual(f4(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
def test_set_fullgraph_nested(self):
|
||||
# set_fullgraph in a nested frame
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.set_fullgraph(False)
|
||||
def inner_f5(x):
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f5(x):
|
||||
x = x + 1
|
||||
return inner_f5(x)
|
||||
|
||||
inp = torch.ones(3)
|
||||
self.assertEqual(f5(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 4)
|
||||
|
||||
def inner_f6(x):
|
||||
x = x + 2
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f6(x):
|
||||
x = x + 1
|
||||
return inner_f6(x)
|
||||
|
||||
cnts.clear()
|
||||
self.assertEqual(f6(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 3)
|
||||
|
||||
def inner_f7(x):
|
||||
x = x + 2
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=False)
|
||||
def f7(x):
|
||||
x = x + 1
|
||||
return inner_f7(x)
|
||||
|
||||
with self.assertRaises(Unsupported):
|
||||
f7(inp)
|
||||
|
||||
def test_set_fullgraph_nested_with_skip(self):
|
||||
# set_fullgraph in a nested frame with a skipped frame in between
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.set_fullgraph(False)
|
||||
def inner2_f8(x):
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
def inner1_f8(x):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
torch._dynamo.skip_frame()
|
||||
return inner2_f8(x)
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f8(x):
|
||||
x = x + 1
|
||||
return inner1_f8(x)
|
||||
|
||||
inp = torch.ones(3)
|
||||
self.assertEqual(f8(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 4)
|
||||
|
||||
def inner2_f9(x):
|
||||
x = x + 2
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
@torch._dynamo.disable(recursive=False)
|
||||
def inner1_f9(x):
|
||||
return inner2_f9(x)
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=False)
|
||||
def f9(x):
|
||||
x = x + 1
|
||||
return inner1_f9(x)
|
||||
|
||||
with self.assertRaises(Unsupported):
|
||||
f9(inp)
|
||||
|
||||
# test export with set_fullgraph(False) still errors
|
||||
|
||||
def test_set_fullgraph_export(self):
|
||||
@torch._dynamo.set_fullgraph(False)
|
||||
def inner(x):
|
||||
x = x + 2
|
||||
torch._dynamo.graph_break()
|
||||
return x + 4
|
||||
|
||||
def f(x):
|
||||
x = x + 1
|
||||
return inner(x)
|
||||
|
||||
with self.assertRaises(Unsupported):
|
||||
torch._dynamo.export(f)(torch.ones(3))
|
||||
|
||||
def test_set_fullgraph_nested_deep(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
def inner1_f1(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
def inner2_f1(x):
|
||||
return inner1_f1(x)
|
||||
|
||||
def inner3_f1(x):
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
return inner2_f1(x)
|
||||
|
||||
def inner4_f1(x):
|
||||
return inner3_f1(x)
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f1(x):
|
||||
x = x + 4
|
||||
return inner4_f1(x)
|
||||
|
||||
inp = torch.ones(3)
|
||||
self.assertEqual(f1(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 4)
|
||||
|
||||
def inner1_f2(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
def inner2_f2(x):
|
||||
return inner1_f2(x)
|
||||
|
||||
def inner3_f2(x):
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
return inner2_f2(x)
|
||||
|
||||
def inner4_f2(x):
|
||||
return inner3_f2(x)
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=False)
|
||||
def f2(x):
|
||||
x = x + 4
|
||||
return inner4_f2(x)
|
||||
|
||||
with self.assertRaises(Unsupported):
|
||||
f2(inp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -1315,6 +1315,11 @@ class ConvertFrame:
|
|||
)
|
||||
assert error_on_graph_break is not None
|
||||
if self._inner_convert._box.error_on_graph_break:
|
||||
# NOTE we _might_ have to wrap the current in a custom exception
|
||||
# in order to correctly bubble up to the top-level compile wrapper in
|
||||
# eval_frame.py. But re-raising seems to work for now because exceptions from tracing
|
||||
# a nested call that results in a top-level frame compile will be handled by the caller
|
||||
# as an observed exception - we don't expect that exception to be suppressed.
|
||||
raise
|
||||
|
||||
# These two exception types are "soft" failure, in the sense that
|
||||
|
|
|
|||
|
|
@ -872,6 +872,14 @@ def patch_dynamo_config(
|
|||
return DynamoConfigPatchProxy(config_patch)
|
||||
|
||||
|
||||
@overload
|
||||
def dont_skip_tracing(fn: None = None) -> DynamoConfigPatchProxy: ...
|
||||
|
||||
|
||||
@overload
|
||||
def dont_skip_tracing(fn: Callable[_P, _R]) -> Callable[_P, _R]: ...
|
||||
|
||||
|
||||
def dont_skip_tracing(fn=None):
|
||||
"""
|
||||
Context manager/decorator to trace into functions intentionally marked by developers to be skipped
|
||||
|
|
@ -885,23 +893,11 @@ def dont_skip_tracing(fn=None):
|
|||
return ctx
|
||||
|
||||
|
||||
@overload
|
||||
def set_fullgraph(
|
||||
fn: None = None, fullgraph: bool = True
|
||||
) -> DynamoConfigPatchProxy: ...
|
||||
|
||||
|
||||
@overload
|
||||
def set_fullgraph(fn: Callable[_P, _R], fullgraph: bool = True) -> Callable[_P, _R]: ...
|
||||
|
||||
|
||||
def set_fullgraph(
|
||||
fn: Optional[Callable[_P, _R]] = None, fullgraph: bool = True
|
||||
) -> Union[Callable[_P, _R], DynamoConfigPatchProxy]:
|
||||
def set_fullgraph(fullgraph: bool) -> DynamoConfigPatchProxy:
|
||||
"""
|
||||
Context manager/decorator to toggle fullgraph setting.
|
||||
|
||||
More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
|
||||
or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
|
||||
"""
|
||||
ctx = patch_dynamo_config(error_on_graph_break=fullgraph)
|
||||
if fn:
|
||||
return ctx(fn)
|
||||
return ctx
|
||||
return patch_dynamo_config(error_on_graph_break=fullgraph)
|
||||
|
|
|
|||
|
|
@ -208,20 +208,29 @@ class SpeculationEntry:
|
|||
lineno: int
|
||||
instruction_pointer: int
|
||||
inst: Instruction # for debugging only
|
||||
failed: bool = False
|
||||
_failed: bool = False
|
||||
error_on_graph_break: Optional[bool] = None
|
||||
reason: Optional[GraphCompileReason] = None
|
||||
|
||||
def fail_and_restart_analysis(self):
|
||||
def fail_and_restart_analysis(self, error_on_graph_break: bool):
|
||||
"""
|
||||
Start tracing of the current frame over again, and don't take this branch.
|
||||
"""
|
||||
self.failed = True
|
||||
self._failed = True
|
||||
self.error_on_graph_break = error_on_graph_break
|
||||
if self.reason is not None:
|
||||
restart_reason = self.reason.reason
|
||||
else:
|
||||
restart_reason = "Unknown fail_and_restart_analysis"
|
||||
raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
|
||||
|
||||
def failed(self, tx):
|
||||
if self._failed:
|
||||
assert self.error_on_graph_break is not None
|
||||
tx.error_on_graph_break = self.error_on_graph_break
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SpeculationLog:
|
||||
|
|
@ -827,7 +836,7 @@ def break_graph_if_unsupported(*, push):
|
|||
@functools.wraps(inner_fn)
|
||||
def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
|
||||
speculation = self.speculate()
|
||||
if speculation.failed:
|
||||
if speculation.failed(self):
|
||||
assert speculation.reason is not None
|
||||
return handle_graph_break(self, inst, speculation.reason)
|
||||
try:
|
||||
|
|
@ -872,7 +881,7 @@ def break_graph_if_unsupported(*, push):
|
|||
excp.remove_from_stats()
|
||||
excp.add_to_stats("graph_break")
|
||||
speculation.reason = GraphCompileReason(excp.msg, excp.real_stack)
|
||||
speculation.fail_and_restart_analysis()
|
||||
speculation.fail_and_restart_analysis(self.error_on_graph_break)
|
||||
|
||||
def handle_graph_break(
|
||||
self: "InstructionTranslatorBase",
|
||||
|
|
@ -1255,7 +1264,7 @@ class InstructionTranslatorBase(
|
|||
and self.is_non_empty_graph()
|
||||
):
|
||||
self.current_speculation = self.speculate()
|
||||
if self.current_speculation.failed:
|
||||
if self.current_speculation.failed(self):
|
||||
return self.step_graph_break(inst)
|
||||
|
||||
if self.is_trace_bytecode_log_enabled:
|
||||
|
|
@ -1281,7 +1290,7 @@ class InstructionTranslatorBase(
|
|||
raise
|
||||
log.debug("step triggered compile", exc_info=True)
|
||||
|
||||
self.current_speculation.fail_and_restart_analysis()
|
||||
self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
|
||||
|
|
@ -2297,7 +2306,7 @@ class InstructionTranslatorBase(
|
|||
|
||||
def STORE_ATTR(self, inst):
|
||||
speculation = self.speculate()
|
||||
if speculation.failed:
|
||||
if speculation.failed(self):
|
||||
return self.store_attr_graph_break(inst)
|
||||
val, obj = self.popn(2)
|
||||
|
||||
|
|
@ -2321,7 +2330,7 @@ class InstructionTranslatorBase(
|
|||
log.debug("STORE_ATTR triggered compile", exc_info=True)
|
||||
e.remove_from_stats()
|
||||
e.add_to_stats("graph_break")
|
||||
speculation.fail_and_restart_analysis()
|
||||
speculation.fail_and_restart_analysis(self.error_on_graph_break)
|
||||
|
||||
def store_attr_graph_break(self, inst):
|
||||
log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break")
|
||||
|
|
@ -3922,6 +3931,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
except Exception:
|
||||
log.debug("FAILED INLINING %s", code)
|
||||
raise
|
||||
finally:
|
||||
parent.error_on_graph_break = self.error_on_graph_break
|
||||
|
||||
assert self.symbolic_result is not None
|
||||
|
||||
if self.f_globals is parent.f_globals:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user