mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo, nested graph breaks] add disable_nested_graph_breaks decorator/context manager (#166477)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166477 Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007 ghstack dependencies: #166476
This commit is contained in:
parent
797cd80b26
commit
1dec8a67a8
|
|
@ -841,6 +841,39 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_disable_nested_graph_breaks(self):
|
||||||
|
global f1, f2, f3, f4, f5
|
||||||
|
|
||||||
|
def f1(x):
|
||||||
|
x = x + 1
|
||||||
|
torch._dynamo.graph_break()
|
||||||
|
return x + 2
|
||||||
|
|
||||||
|
def f2(x):
|
||||||
|
return f1(x + 4) + 8
|
||||||
|
|
||||||
|
# NOTE since the disable_nested_graph_breaks decorator is implemented as a
|
||||||
|
# context manager, we don't need to separately test context manager usage.
|
||||||
|
@torch._dynamo.disable_nested_graph_breaks
|
||||||
|
def f3(x):
|
||||||
|
return f2(x + 16) + 32
|
||||||
|
|
||||||
|
def f4(x):
|
||||||
|
return f3(x + 64) + 128
|
||||||
|
|
||||||
|
def f5(x):
|
||||||
|
return f4(x + 256) + 512
|
||||||
|
|
||||||
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
|
opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
|
||||||
|
x = torch.zeros(3)
|
||||||
|
res = f5(x)
|
||||||
|
ref = opt_fn(x)
|
||||||
|
self.assertEqual(ref, res)
|
||||||
|
# 2 frames from each of f5+f4, f3, f2, f1
|
||||||
|
self.assertEqual(cnts.frame_count, 8)
|
||||||
|
self.assertEqual(cnts.op_count, 10)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from .decorators import (
|
||||||
allow_in_graph,
|
allow_in_graph,
|
||||||
assume_constant_result,
|
assume_constant_result,
|
||||||
disable,
|
disable,
|
||||||
|
disable_nested_graph_breaks,
|
||||||
disallow_in_graph,
|
disallow_in_graph,
|
||||||
dont_skip_tracing,
|
dont_skip_tracing,
|
||||||
error_on_graph_break,
|
error_on_graph_break,
|
||||||
|
|
@ -78,6 +79,7 @@ __all__ = [
|
||||||
"assume_constant_result",
|
"assume_constant_result",
|
||||||
"config",
|
"config",
|
||||||
"disable",
|
"disable",
|
||||||
|
"disable_nested_graph_breaks",
|
||||||
"disallow_in_graph",
|
"disallow_in_graph",
|
||||||
"dont_skip_tracing",
|
"dont_skip_tracing",
|
||||||
"export",
|
"export",
|
||||||
|
|
|
||||||
|
|
@ -890,6 +890,7 @@ _allowed_config_patches = (
|
||||||
"allow_unspec_int_on_nn_module",
|
"allow_unspec_int_on_nn_module",
|
||||||
"skip_torchrec",
|
"skip_torchrec",
|
||||||
"dont_skip_tracing",
|
"dont_skip_tracing",
|
||||||
|
"nested_graph_breaks",
|
||||||
)
|
)
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
|
|
@ -965,6 +966,26 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any:
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def disable_nested_graph_breaks(fn: None = None) -> DynamoConfigPatchProxy: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def disable_nested_graph_breaks(fn: Callable[_P, _R]) -> Callable[_P, _R]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def disable_nested_graph_breaks(fn: Optional[Any] = None) -> Any:
|
||||||
|
"""
|
||||||
|
Context manager/decorator to disable nested graph breaks when tracing
|
||||||
|
this function and any nested functions. Used when nested graph breaks
|
||||||
|
is causing problems.
|
||||||
|
"""
|
||||||
|
ctx = patch_dynamo_config(nested_graph_breaks=False)
|
||||||
|
if fn:
|
||||||
|
return ctx(fn)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
class ErrorOnGraphBreakDecoratorContextManager:
|
class ErrorOnGraphBreakDecoratorContextManager:
|
||||||
def __init__(self, error_on_graph_break: bool) -> None:
|
def __init__(self, error_on_graph_break: bool) -> None:
|
||||||
self.error_on_graph_break = error_on_graph_break
|
self.error_on_graph_break = error_on_graph_break
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user