Revert "[dynamo] handle fullgraph toggle using nested torch.compile (#155166)"
This reverts commit 24dc33b37b.
Reverted https://github.com/pytorch/pytorch/pull/155166 on behalf of https://github.com/ezyang because this stack is responsible for a regression; see https://github.com/pytorch/pytorch/pull/156561 ([comment](https://github.com/pytorch/pytorch/pull/154283#issuecomment-2994242583))
This commit is contained in:
parent f1331f3f1b
commit ee3d9969cc
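For context, the reverted change (#155166) made an inlined function's own fullgraph flag take effect when it was called from another torch.compile'd function; the test_nested_compile_fullgraph test deleted below exercised exactly that. A condensed sketch adapted from that test (the eager backend and the expected output are illustrative, not asserted by this commit):

import torch

@torch.compile(backend="eager", fullgraph=True)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()  # fatal or allowed, depending on which fullgraph flag wins
    return x + 2

@torch.compile(backend="eager", fullgraph=False)
def outer(x):
    return inner(x)

# With #155166 applied, inner's fullgraph=True was honored while it was inlined into
# outer, so the graph break raised torch._dynamo.exc.Unsupported (the deleted test
# asserted this). After the revert, outer's fullgraph=False should govern the whole
# trace again and the graph break is simply split around.
print(outer(torch.ones(3)))  # expected: tensor([4., 4., 4.]) after the revert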
@@ -1040,11 +1040,11 @@ If the above doesn't work, please subtmit an issue to GitHub.
         self.assertEqual(cnts.frame_count, 2)
         self.assertEqual(cnts.op_count, 4)
 
-        cnts.clear()
-        torch._dynamo.reset()
-        fn3(torch.randn(4, 5))
-        self.assertEqual(cnts.frame_count, 2)
-        self.assertEqual(cnts.op_count, 4)
+        try:
+            fn3(torch.randn(4, 5))
+            self.assertFalse(True)
+        except torch._dynamo.exc.Unsupported as e:
+            self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e))
 
     def test_disable_optimize(self):
         cnt = torch._dynamo.testing.CompileCounter()
@@ -1903,34 +1903,6 @@ If the above doesn't work, please subtmit an issue to GitHub.
         with self.assertRaises(Unsupported):
             f2(inp)
 
-    def test_nested_compile_fullgraph(self):
-        inp = torch.ones(3)
-
-        @torch.compile(backend="eager", fullgraph=True)
-        def inner_f1(x):
-            x = x + 1
-            torch._dynamo.graph_break()
-            return x + 2
-
-        @torch.compile(backend="eager", fullgraph=False)
-        def f1(x):
-            return inner_f1(x)
-
-        with self.assertRaises(Unsupported):
-            f1(inp)
-
-        @torch.compile(backend="eager", fullgraph=False)
-        def inner_f2(x):
-            x = x + 1
-            torch._dynamo.graph_break()
-            return x + 2
-
-        @torch.compile(backend="eager", fullgraph=True)
-        def f2(x):
-            return inner_f2(x)
-
-        self.assertEqual(f2(inp), inp + 3)
-
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -3050,7 +3050,6 @@ def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
 
 
 class TestUnbacked(TestCase):
-    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     @parametrize("backend", ["inductor", "eager"])
     def test_deferred_neq_assert(self, backend):
@@ -3098,7 +3097,6 @@ class TestUnbacked(TestCase):
         with self.assertRaises(RuntimeError):
             func(torch.rand(2, 50), torch.tensor([51]))
 
-    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     @parametrize("backend", ["inductor", "eager"])
     def test_deferred_sym_or_assert(self, backend):
@@ -3120,7 +3118,6 @@ class TestUnbacked(TestCase):
         self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
         self.assertTrue(has_free_symbols(sympy.sympify("a+b")))
 
-    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     @parametrize("backend", ["inductor", "eager"])
     def test_deferred_sym_eq_assert(self, backend):
@@ -528,7 +528,7 @@ class _TorchDynamoContext:
         patch_fn=nothing,
         first_ctx=False,
         *,
-        error_on_graph_break=False,
+        error_on_graph_break=None,
         export=False,
         dynamic=None,
         compiler_config=None,
@@ -737,9 +737,7 @@ class _TorchDynamoContext:
                 _maybe_set_eval_frame(prior)
 
         # hooks to properly handle inlining
-        compile_wrapper._torchdynamo_inline = (  # type: ignore[attr-defined]
-            external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break)
-        )
+        compile_wrapper._torchdynamo_inline = fn  # type: ignore[attr-defined]
 
         # Save the function pointer to find the original callable while nesting
         # of decorators.
@@ -799,7 +797,7 @@ class OptimizeContext(_TorchDynamoContext):
         backend_ctx_ctor,
         first_ctx=False,
         *,
-        error_on_graph_break=False,
+        error_on_graph_break=None,
         export=False,
         dynamic=None,
         compiler_config=None,
@@ -937,7 +935,7 @@ def _optimize_catch_errors(
     compile_fn,
     hooks: Hooks,
     backend_ctx_ctor=null_context,
-    error_on_graph_break=False,
+    error_on_graph_break=None,
    export=False,
    dynamic=None,
    compiler_config=None,
@@ -227,25 +227,3 @@ def call_accumulate_grad(
         [grad], variable, variable.grad, has_post_hooks
     )
     variable.grad = updated_grad[0]
-
-
-def wrap_inline_with_set_fullgraph(
-    fn: Callable[_P, _R], fullgraph: bool
-) -> Callable[_P, _R]:
-    # NB: need multiple definitions in order to prevent `fullgraph` from
-    # being a freevar of wrapper
-    if fullgraph:
-
-        @functools.wraps(fn)
-        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-            with torch._dynamo.set_fullgraph(True):
-                return fn(*args, **kwargs)
-
-    else:
-
-        @functools.wraps(fn)
-        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-            with torch._dynamo.set_fullgraph(False):
-                return fn(*args, **kwargs)
-
-    return wrapper
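The deleted external_utils.wrap_inline_with_set_fullgraph helper (last hunk above) did nothing more than re-run the wrapped function under a fixed set_fullgraph value. A rough stand-alone equivalent, only to illustrate what the removed code did; the inner function here is hypothetical and not part of this commit:

import functools
import torch

def inner(x):
    return x + 1

# Approximates wrap_inline_with_set_fullgraph(inner, True) from the removed code:
# the callee always runs under fullgraph=True while Dynamo traces it.
@functools.wraps(inner)
def wrapped(*args, **kwargs):
    with torch._dynamo.set_fullgraph(True):
        return inner(*args, **kwargs)

print(wrapped(torch.ones(3)))  # tensor([2., 2., 2.])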