[dynamo, nested graph breaks] add disable_nested_graph_breaks decorator/context manager (#166477)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166477
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
ghstack dependencies: #166476
This commit is contained in:
William Wen 2025-10-31 00:50:46 +00:00 committed by PyTorch MergeBot
parent 797cd80b26
commit 1dec8a67a8
3 changed files with 56 additions and 0 deletions

View File

@ -841,6 +841,39 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
)
)
def test_disable_nested_graph_breaks(self):
global f1, f2, f3, f4, f5
def f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def f2(x):
return f1(x + 4) + 8
# NOTE since the disable_nested_graph_breaks decorator is implemented as a
# context manager, we don't need to separately test context manager usage.
@torch._dynamo.disable_nested_graph_breaks
def f3(x):
return f2(x + 16) + 32
def f4(x):
return f3(x + 64) + 128
def f5(x):
return f4(x + 256) + 512
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
x = torch.zeros(3)
res = f5(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# 2 frames from each of f5+f4, f3, f2, f1
self.assertEqual(cnts.frame_count, 8)
self.assertEqual(cnts.op_count, 10)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -26,6 +26,7 @@ from .decorators import (
allow_in_graph,
assume_constant_result,
disable,
disable_nested_graph_breaks,
disallow_in_graph,
dont_skip_tracing,
error_on_graph_break,
@ -78,6 +79,7 @@ __all__ = [
"assume_constant_result",
"config",
"disable",
"disable_nested_graph_breaks",
"disallow_in_graph",
"dont_skip_tracing",
"export",

View File

@ -890,6 +890,7 @@ _allowed_config_patches = (
"allow_unspec_int_on_nn_module",
"skip_torchrec",
"dont_skip_tracing",
"nested_graph_breaks",
)
from . import config
@ -965,6 +966,26 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any:
return ctx
@overload
def disable_nested_graph_breaks(fn: None = None) -> DynamoConfigPatchProxy: ...
@overload
def disable_nested_graph_breaks(fn: Callable[_P, _R]) -> Callable[_P, _R]: ...
def disable_nested_graph_breaks(fn: Optional[Any] = None) -> Any:
"""
Context manager/decorator to disable nested graph breaks when tracing
this function and any nested functions. Used when nested graph breaks
is causing problems.
"""
ctx = patch_dynamo_config(nested_graph_breaks=False)
if fn:
return ctx(fn)
return ctx
class ErrorOnGraphBreakDecoratorContextManager:
def __init__(self, error_on_graph_break: bool) -> None:
self.error_on_graph_break = error_on_graph_break