[ca] Allow using compiled autograd context managers during backward runtime (#156120)

Added an invariant that nested compiled autograd context managers must exit before their parent context manager exits (LIFO ordering). This lets us defer the thread check from _enable() time to the point in Engine::execute where compiled autograd actually runs, so the context managers can be entered and exited while a backward call is in flight.
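
For illustration, a minimal sketch of the new invariant, using the private torch._dynamo.compiled_autograd._enable context manager the same way the new test below does (torch.compile as the compiler_fn is taken from that test; this is not a public API):

import torch
from torch._dynamo import compiled_autograd

outer = compiled_autograd._enable(torch.compile)
inner = compiled_autograd._enable(torch.compile)

# OK: the nested context exits before its parent (LIFO order).
outer.__enter__()
inner.__enter__()
inner.__exit__(None, None, None)
outer.__exit__(None, None, None)

# Not OK: exiting the parent while a nested context is still active now
# fails with AssertionError("Nested Compiled Autograd Contexts must
# return before their parent context").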

FIXES https://github.com/pytorch/pytorch/issues/152219

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156120
Approved by: https://github.com/jansel
ghstack dependencies: #155521, #155480
Author: Simon Fan
Date: 2025-06-17 12:18:36 -07:00 (committed by PyTorch MergeBot)
Commit: 17b38b850e (parent: 10d41c7d20)
3 changed files with 101 additions and 7 deletions


@@ -785,6 +785,88 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 0)
 
+    @config.patch(compiled_autograd=True)
+    def test_nested_context_manager(self):
+        def ctx():
+            return compiled_autograd._enable(torch.compile)
+
+        # ok
+        outer = ctx()
+        inner = ctx()
+        outer.__enter__()
+        inner.__enter__()
+        inner.__exit__(None, None, None)
+        outer.__exit__(None, None, None)
+
+        # not ok
+        outer = ctx()
+        inner = ctx()
+        outer.__enter__()
+        inner.__enter__()
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Nested Compiled Autograd Contexts must return before their parent context",
+        ):
+            outer.__exit__(None, None, None)
+
+    @config.patch(compiled_autograd=True)
+    def test_nested_compile(self):
+        with torch.library._scoped_library("testlib", "FRAGMENT") as lib:
+            lib.define("square(Tensor x) -> Tensor")
+
+            @torch.library.impl("testlib::square", "CPU")
+            def square_impl(x: torch.Tensor) -> torch.Tensor:
+                # nested inference graph compile
+                @torch.compile(backend="eager")
+                def fn(x):
+                    return x**2
+
+                return fn(x)
+
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    return x
+
+                @staticmethod
+                def backward(ctx, x):
+                    return torch.ops.testlib.square(x)
+
+            x = torch.tensor([2.0, 3.0], requires_grad=True)
+
+            @torch.compile
+            def fn(x):
+                return MyFn.apply(x)
+
+            fn(x).sum().backward()
+
+    @config.patch(compiled_autograd=True)
+    def test_no_nested_compiled_autograd(self):
+        # We disable CA before entering the CA graph
+        # So re-entrants should be running with the eager autograd engine
+        def unrelated_autograd_call():
+            x = torch.randn(20, 20, requires_grad=True)
+            y = torch.randn(20, 20, requires_grad=True)
+            loss = torch.matmul(x, y).sum()
+            loss.backward()
+
+        class MyFn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, gO):
+                unrelated_autograd_call()
+                return gO
+
+        x = torch.randn(10, 10, requires_grad=True)
+        loss = MyFn.apply(x).sum()
+        torch.compile(lambda: loss.backward(create_graph=True))()
+        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+
     def test_multiple_torch_compile(self):
         model = torch.nn.Sequential(
             torch.nn.Linear(4, 4),


@@ -1382,6 +1382,8 @@ in_compiled_autograd_region = False
 active_disable_ctx = False
+
+depth = 0
 
 
 @contextlib.contextmanager
 def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
@@ -1437,6 +1439,9 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
            torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
        global compiled_autograd_enabled
        compiled_autograd_enabled = True
+        global depth
+        prior_depth = depth
+        depth += 1
        try:
            with torch.autograd.set_multithreading_enabled(False):
                yield
@@ -1446,6 +1451,10 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
            torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                prior_compiler, prior_dynamic
            )
+            depth -= 1
+            assert depth == prior_depth, (
+                "Nested Compiled Autograd Contexts must return before their parent context"
+            )
 
 
 @contextlib.contextmanager


@@ -75,17 +75,17 @@ inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
 
 std::atomic<Engine::compiled_autograd_fn> the_compiled_autograd = nullptr;
 #define COMPILED_AUTOGRAD_POISON \
   reinterpret_cast<Engine::compiled_autograd_fn>(1)
-std::atomic<int32_t> num_threads_in_backwards;
+std::atomic<int32_t> num_threads_in_compiled_autograd;
 struct CompiledAutogradThreadingDebugCheck {
   CompiledAutogradThreadingDebugCheck() {
-    num_threads_in_backwards++;
+    num_threads_in_compiled_autograd++;
   }
   ~CompiledAutogradThreadingDebugCheck() {
     release();
   }
   void release() {
     if (std::exchange(incremented, false)) {
-      num_threads_in_backwards--;
+      num_threads_in_compiled_autograd--;
     }
   }
@@ -1299,8 +1299,6 @@ auto Engine::execute(
         "your parameters to None after use to break the cycle and avoid the leak.");
   }
 
-  // Allows us to assert no other threads are in backwards
-  CompiledAutogradThreadingDebugCheck _thread_check;
   auto compiled_autograd = the_compiled_autograd.load();
   TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON);
@@ -1347,6 +1345,11 @@
   }
 
   if (compiled_autograd != nullptr) {
+    TORCH_CHECK(
+        num_threads_in_compiled_autograd.load() == 0,
+        "Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call.");
+    // Allows us to assert no other threads are in backwards
+    CompiledAutogradThreadingDebugCheck _thread_check;
     // see [Note: Compiled Autograd]
     _thread_check.release();
     GraphTaskGuard guard(graph_task);
@@ -1495,8 +1498,8 @@ void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) {
   }
   auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON);
   TORCH_CHECK(
-      num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON,
-      "compiled_autograd._enable() requires no threads in backwards()");
+      prior != COMPILED_AUTOGRAD_POISON,
+      "compiled_autograd._enable() does not support multiple Python threads");
   the_compiled_autograd.store(fn);
 }