mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[dynamo] handle fullgraph toggle using nested torch.compile (#155166)"
This reverts commit 24dc33b37b.
Reverted https://github.com/pytorch/pytorch/pull/155166 on behalf of https://github.com/ezyang due to All of this is responsible for regression, see https://github.com/pytorch/pytorch/pull/156561 ([comment](https://github.com/pytorch/pytorch/pull/154283#issuecomment-2994242583))
This commit is contained in:
parent
f1331f3f1b
commit
ee3d9969cc
|
|
@ -1040,11 +1040,11 @@ If the above doesn't work, please subtmit an issue to GitHub.
|
|||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
cnts.clear()
|
||||
torch._dynamo.reset()
|
||||
fn3(torch.randn(4, 5))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
try:
|
||||
fn3(torch.randn(4, 5))
|
||||
self.assertFalse(True)
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e))
|
||||
|
||||
def test_disable_optimize(self):
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
|
@ -1903,34 +1903,6 @@ If the above doesn't work, please subtmit an issue to GitHub.
|
|||
with self.assertRaises(Unsupported):
|
||||
f2(inp)
|
||||
|
||||
def test_nested_compile_fullgraph(self):
|
||||
inp = torch.ones(3)
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def inner_f1(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=False)
|
||||
def f1(x):
|
||||
return inner_f1(x)
|
||||
|
||||
with self.assertRaises(Unsupported):
|
||||
f1(inp)
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=False)
|
||||
def inner_f2(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return x + 2
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f2(x):
|
||||
return inner_f2(x)
|
||||
|
||||
self.assertEqual(f2(inp), inp + 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -3050,7 +3050,6 @@ def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
|
|||
|
||||
|
||||
class TestUnbacked(TestCase):
|
||||
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@parametrize("backend", ["inductor", "eager"])
|
||||
def test_deferred_neq_assert(self, backend):
|
||||
|
|
@ -3098,7 +3097,6 @@ class TestUnbacked(TestCase):
|
|||
with self.assertRaises(RuntimeError):
|
||||
func(torch.rand(2, 50), torch.tensor([51]))
|
||||
|
||||
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@parametrize("backend", ["inductor", "eager"])
|
||||
def test_deferred_sym_or_assert(self, backend):
|
||||
|
|
@ -3120,7 +3118,6 @@ class TestUnbacked(TestCase):
|
|||
self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
|
||||
self.assertTrue(has_free_symbols(sympy.sympify("a+b")))
|
||||
|
||||
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@parametrize("backend", ["inductor", "eager"])
|
||||
def test_deferred_sym_eq_assert(self, backend):
|
||||
|
|
|
|||
|
|
@ -528,7 +528,7 @@ class _TorchDynamoContext:
|
|||
patch_fn=nothing,
|
||||
first_ctx=False,
|
||||
*,
|
||||
error_on_graph_break=False,
|
||||
error_on_graph_break=None,
|
||||
export=False,
|
||||
dynamic=None,
|
||||
compiler_config=None,
|
||||
|
|
@ -737,9 +737,7 @@ class _TorchDynamoContext:
|
|||
_maybe_set_eval_frame(prior)
|
||||
|
||||
# hooks to properly handle inlining
|
||||
compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined]
|
||||
external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break)
|
||||
)
|
||||
compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined]
|
||||
|
||||
# Save the function pointer to find the original callable while nesting
|
||||
# of decorators.
|
||||
|
|
@ -799,7 +797,7 @@ class OptimizeContext(_TorchDynamoContext):
|
|||
backend_ctx_ctor,
|
||||
first_ctx=False,
|
||||
*,
|
||||
error_on_graph_break=False,
|
||||
error_on_graph_break=None,
|
||||
export=False,
|
||||
dynamic=None,
|
||||
compiler_config=None,
|
||||
|
|
@ -937,7 +935,7 @@ def _optimize_catch_errors(
|
|||
compile_fn,
|
||||
hooks: Hooks,
|
||||
backend_ctx_ctor=null_context,
|
||||
error_on_graph_break=False,
|
||||
error_on_graph_break=None,
|
||||
export=False,
|
||||
dynamic=None,
|
||||
compiler_config=None,
|
||||
|
|
|
|||
|
|
@ -227,25 +227,3 @@ def call_accumulate_grad(
|
|||
[grad], variable, variable.grad, has_post_hooks
|
||||
)
|
||||
variable.grad = updated_grad[0]
|
||||
|
||||
|
||||
def wrap_inline_with_set_fullgraph(
|
||||
fn: Callable[_P, _R], fullgraph: bool
|
||||
) -> Callable[_P, _R]:
|
||||
# NB: need multiple definitions in order to prevent `fullgraph` from
|
||||
# being a freevar of wrapper
|
||||
if fullgraph:
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
with torch._dynamo.set_fullgraph(True):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
with torch._dynamo.set_fullgraph(False):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user