diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index a9bac24b2b8..78210822406 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -785,6 +785,88 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 0)
 
+    @config.patch(compiled_autograd=True)
+    def test_nested_context_manager(self):
+        def ctx():
+            return compiled_autograd._enable(torch.compile)
+
+        # ok
+        outer = ctx()
+        inner = ctx()
+        outer.__enter__()
+        inner.__enter__()
+        inner.__exit__(None, None, None)
+        outer.__exit__(None, None, None)
+
+        # not ok
+        outer = ctx()
+        inner = ctx()
+        outer.__enter__()
+        inner.__enter__()
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Nested Compiled Autograd Contexts must return before their parent context",
+        ):
+            outer.__exit__(None, None, None)
+
+    @config.patch(compiled_autograd=True)
+    def test_nested_compile(self):
+        with torch.library._scoped_library("testlib", "FRAGMENT") as lib:
+            lib.define("square(Tensor x) -> Tensor")
+
+            @torch.library.impl("testlib::square", "CPU")
+            def square_impl(x: torch.Tensor) -> torch.Tensor:
+                # nested inference graph compile
+                @torch.compile(backend="eager")
+                def fn(x):
+                    return x**2
+
+                return fn(x)
+
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    return x
+
+                @staticmethod
+                def backward(ctx, x):
+                    return torch.ops.testlib.square(x)
+
+            x = torch.tensor([2.0, 3.0], requires_grad=True)
+
+            @torch.compile
+            def fn(x):
+                return MyFn.apply(x)
+
+            fn(x).sum().backward()
+
+    @config.patch(compiled_autograd=True)
+    def test_no_nested_compiled_autograd(self):
+        # We disable CA before entering the CA graph,
+        # so re-entrant calls should run with the eager autograd engine.
+
+        def unrelated_autograd_call():
+            x = torch.randn(20, 20, requires_grad=True)
+            y = torch.randn(20, 20, requires_grad=True)
+            loss = torch.matmul(x, y).sum()
+            loss.backward()
+
+        class MyFn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, gO):
+                unrelated_autograd_call()
+                return gO
+
+        x = torch.randn(10, 10, requires_grad=True)
+        loss = MyFn.apply(x).sum()
+
+        torch.compile(lambda: loss.backward(create_graph=True))()
+        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+
     def test_multiple_torch_compile(self):
         model = torch.nn.Sequential(
             torch.nn.Linear(4, 4),
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index d293311128e..15a865766a1 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -1382,6 +1382,8 @@ in_compiled_autograd_region = False
 active_disable_ctx = False
 
+depth = 0
+
 
 @contextlib.contextmanager
 def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
@@ -1437,6 +1439,9 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
         torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
     global compiled_autograd_enabled
     compiled_autograd_enabled = True
+    global depth
+    prior_depth = depth
+    depth += 1
     try:
         with torch.autograd.set_multithreading_enabled(False):
             yield
@@ -1446,6 +1451,10 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(
             prior_compiler, prior_dynamic
         )
+        depth -= 1
+        assert depth == prior_depth, (
+            "Nested Compiled Autograd Contexts must return before their parent context"
+        )
 
 
 @contextlib.contextmanager
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 6f5df7de110..50c6e5682c9 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -75,17 +75,17 @@ inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
 std::atomic<Engine::compiled_autograd_fn> the_compiled_autograd = nullptr;
 #define COMPILED_AUTOGRAD_POISON \
   reinterpret_cast<Engine::compiled_autograd_fn>(1)
-std::atomic<int32_t> num_threads_in_backwards;
+std::atomic<int32_t> num_threads_in_compiled_autograd;
 
 struct CompiledAutogradThreadingDebugCheck {
   CompiledAutogradThreadingDebugCheck() {
-    num_threads_in_backwards++;
+    num_threads_in_compiled_autograd++;
   }
   ~CompiledAutogradThreadingDebugCheck() {
     release();
   }
   void release() {
     if (std::exchange(incremented, false)) {
-      num_threads_in_backwards--;
+      num_threads_in_compiled_autograd--;
     }
   }
@@ -1299,8 +1299,6 @@ auto Engine::execute(
       "your parameters to None after use to break the cycle and avoid the leak.");
   }
 
-  // Allows us to assert no other threads are in backwards
-  CompiledAutogradThreadingDebugCheck _thread_check;
   auto compiled_autograd = the_compiled_autograd.load();
   TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON);
 
@@ -1347,6 +1345,11 @@ auto Engine::execute(
   }
 
   if (compiled_autograd != nullptr) {
+    TORCH_CHECK(
+        num_threads_in_compiled_autograd.load() == 0,
+        "Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call.");
+    // Allows us to assert no other threads are in backwards
+    CompiledAutogradThreadingDebugCheck _thread_check;
     // see [Note: Compiled Autograd]
     _thread_check.release();
     GraphTaskGuard guard(graph_task);
@@ -1495,8 +1498,8 @@ void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) {
   }
   auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON);
   TORCH_CHECK(
-      num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON,
-      "compiled_autograd._enable() requires no threads in backwards()");
+      prior != COMPILED_AUTOGRAD_POISON,
+      "compiled_autograd._enable() does not support multiple Python threads");
   the_compiled_autograd.store(fn);
 }