Revert "[dynamo] handle fullgraph toggle using nested torch.compile (#155166)"

This reverts commit 24dc33b37b.

Reverted https://github.com/pytorch/pytorch/pull/155166 on behalf of https://github.com/ezyang due to All of this is responsible for regression, see https://github.com/pytorch/pytorch/pull/156561 ([comment](https://github.com/pytorch/pytorch/pull/154283#issuecomment-2994242583))
This commit is contained in:
PyTorch MergeBot 2025-06-22 14:22:07 +00:00
parent f1331f3f1b
commit ee3d9969cc
4 changed files with 9 additions and 64 deletions

View File

@ -1040,11 +1040,11 @@ If the above doesn't work, please subtmit an issue to GitHub.
self.assertEqual(cnts.frame_count, 2) self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4) self.assertEqual(cnts.op_count, 4)
cnts.clear() try:
torch._dynamo.reset()
fn3(torch.randn(4, 5)) fn3(torch.randn(4, 5))
self.assertEqual(cnts.frame_count, 2) self.assertFalse(True)
self.assertEqual(cnts.op_count, 4) except torch._dynamo.exc.Unsupported as e:
self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e))
def test_disable_optimize(self): def test_disable_optimize(self):
cnt = torch._dynamo.testing.CompileCounter() cnt = torch._dynamo.testing.CompileCounter()
@ -1903,34 +1903,6 @@ If the above doesn't work, please subtmit an issue to GitHub.
with self.assertRaises(Unsupported): with self.assertRaises(Unsupported):
f2(inp) f2(inp)
def test_nested_compile_fullgraph(self):
inp = torch.ones(3)
@torch.compile(backend="eager", fullgraph=True)
def inner_f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(backend="eager", fullgraph=False)
def f1(x):
return inner_f1(x)
with self.assertRaises(Unsupported):
f1(inp)
@torch.compile(backend="eager", fullgraph=False)
def inner_f2(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(backend="eager", fullgraph=True)
def f2(x):
return inner_f2(x)
self.assertEqual(f2(inp), inp + 3)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -3050,7 +3050,6 @@ def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
class TestUnbacked(TestCase): class TestUnbacked(TestCase):
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"]) @parametrize("backend", ["inductor", "eager"])
def test_deferred_neq_assert(self, backend): def test_deferred_neq_assert(self, backend):
@ -3098,7 +3097,6 @@ class TestUnbacked(TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51])) func(torch.rand(2, 50), torch.tensor([51]))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"]) @parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_or_assert(self, backend): def test_deferred_sym_or_assert(self, backend):
@ -3120,7 +3118,6 @@ class TestUnbacked(TestCase):
self.assertTrue(has_free_symbols(sympy.sympify("a*2"))) self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
self.assertTrue(has_free_symbols(sympy.sympify("a+b"))) self.assertTrue(has_free_symbols(sympy.sympify("a+b")))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"]) @parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_eq_assert(self, backend): def test_deferred_sym_eq_assert(self, backend):

View File

@ -528,7 +528,7 @@ class _TorchDynamoContext:
patch_fn=nothing, patch_fn=nothing,
first_ctx=False, first_ctx=False,
*, *,
error_on_graph_break=False, error_on_graph_break=None,
export=False, export=False,
dynamic=None, dynamic=None,
compiler_config=None, compiler_config=None,
@ -737,9 +737,7 @@ class _TorchDynamoContext:
_maybe_set_eval_frame(prior) _maybe_set_eval_frame(prior)
# hooks to properly handle inlining # hooks to properly handle inlining
compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined] compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined]
external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break)
)
# Save the function pointer to find the original callable while nesting # Save the function pointer to find the original callable while nesting
# of decorators. # of decorators.
@ -799,7 +797,7 @@ class OptimizeContext(_TorchDynamoContext):
backend_ctx_ctor, backend_ctx_ctor,
first_ctx=False, first_ctx=False,
*, *,
error_on_graph_break=False, error_on_graph_break=None,
export=False, export=False,
dynamic=None, dynamic=None,
compiler_config=None, compiler_config=None,
@ -937,7 +935,7 @@ def _optimize_catch_errors(
compile_fn, compile_fn,
hooks: Hooks, hooks: Hooks,
backend_ctx_ctor=null_context, backend_ctx_ctor=null_context,
error_on_graph_break=False, error_on_graph_break=None,
export=False, export=False,
dynamic=None, dynamic=None,
compiler_config=None, compiler_config=None,

View File

@ -227,25 +227,3 @@ def call_accumulate_grad(
[grad], variable, variable.grad, has_post_hooks [grad], variable, variable.grad, has_post_hooks
) )
variable.grad = updated_grad[0] variable.grad = updated_grad[0]
def wrap_inline_with_set_fullgraph(
fn: Callable[_P, _R], fullgraph: bool
) -> Callable[_P, _R]:
# NB: need multiple definitions in order to prevent `fullgraph` from
# being a freevar of wrapper
if fullgraph:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.set_fullgraph(True):
return fn(*args, **kwargs)
else:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.set_fullgraph(False):
return fn(*args, **kwargs)
return wrapper