mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo, nested graph breaks] add disable_nested_graph_breaks decorator/context manager (#166477)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166477 Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007 ghstack dependencies: #166476
This commit is contained in:
parent
797cd80b26
commit
1dec8a67a8
|
|
@ -841,6 +841,39 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
|
|||
)
|
||||
)
|
||||
|
||||
def test_disable_nested_graph_breaks(self):
|
||||
global f1, f2, f3, f4, f5
|
||||
|
||||
def f1(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 4) + 8
|
||||
|
||||
# NOTE since the disable_nested_graph_breaks decorator is implemented as a
|
||||
# context manager, we don't need to separately test context manager usage.
|
||||
@torch._dynamo.disable_nested_graph_breaks
|
||||
def f3(x):
|
||||
return f2(x + 16) + 32
|
||||
|
||||
def f4(x):
|
||||
return f3(x + 64) + 128
|
||||
|
||||
def f5(x):
|
||||
return f4(x + 256) + 512
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
|
||||
x = torch.zeros(3)
|
||||
res = f5(x)
|
||||
ref = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
# 2 frames from each of f5+f4, f3, f2, f1
|
||||
self.assertEqual(cnts.frame_count, 8)
|
||||
self.assertEqual(cnts.op_count, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from .decorators import (
|
|||
allow_in_graph,
|
||||
assume_constant_result,
|
||||
disable,
|
||||
disable_nested_graph_breaks,
|
||||
disallow_in_graph,
|
||||
dont_skip_tracing,
|
||||
error_on_graph_break,
|
||||
|
|
@ -78,6 +79,7 @@ __all__ = [
|
|||
"assume_constant_result",
|
||||
"config",
|
||||
"disable",
|
||||
"disable_nested_graph_breaks",
|
||||
"disallow_in_graph",
|
||||
"dont_skip_tracing",
|
||||
"export",
|
||||
|
|
|
|||
|
|
@ -890,6 +890,7 @@ _allowed_config_patches = (
|
|||
"allow_unspec_int_on_nn_module",
|
||||
"skip_torchrec",
|
||||
"dont_skip_tracing",
|
||||
"nested_graph_breaks",
|
||||
)
|
||||
|
||||
from . import config
|
||||
|
|
@ -965,6 +966,26 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any:
|
|||
return ctx
|
||||
|
||||
|
||||
@overload
|
||||
def disable_nested_graph_breaks(fn: None = None) -> DynamoConfigPatchProxy: ...
|
||||
|
||||
|
||||
@overload
|
||||
def disable_nested_graph_breaks(fn: Callable[_P, _R]) -> Callable[_P, _R]: ...
|
||||
|
||||
|
||||
def disable_nested_graph_breaks(fn: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
Context manager/decorator to disable nested graph breaks when tracing
|
||||
this function and any nested functions. Used when nested graph breaks
|
||||
is causing problems.
|
||||
"""
|
||||
ctx = patch_dynamo_config(nested_graph_breaks=False)
|
||||
if fn:
|
||||
return ctx(fn)
|
||||
return ctx
|
||||
|
||||
|
||||
class ErrorOnGraphBreakDecoratorContextManager:
|
||||
def __init__(self, error_on_graph_break: bool) -> None:
|
||||
self.error_on_graph_break = error_on_graph_break
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user