mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ca] Allow using compiled autograd context managers during backward runtime (#156120)
Added an invariant that nested compiled autograd context managers must exit before their parent context manager. This allows us to defer the thread check. FIXES https://github.com/pytorch/pytorch/issues/152219 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156120 Approved by: https://github.com/jansel ghstack dependencies: #155521, #155480
This commit is contained in:
parent
10d41c7d20
commit
17b38b850e
|
|
@ -785,6 +785,88 @@ main()
|
|||
self.assertEqual(expected, actual)
|
||||
self.assertEqual(counters["compiled_autograd"]["captures"], 0)
|
||||
|
||||
@config.patch(compiled_autograd=True)
|
||||
def test_nested_context_manager(self):
|
||||
def ctx():
|
||||
return compiled_autograd._enable(torch.compile)
|
||||
|
||||
# ok
|
||||
outer = ctx()
|
||||
inner = ctx()
|
||||
outer.__enter__()
|
||||
inner.__enter__()
|
||||
inner.__exit__(None, None, None)
|
||||
outer.__exit__(None, None, None)
|
||||
|
||||
# not ok
|
||||
outer = ctx()
|
||||
inner = ctx()
|
||||
outer.__enter__()
|
||||
inner.__enter__()
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
"Nested Compiled Autograd Contexts must return before their parent context",
|
||||
):
|
||||
outer.__exit__(None, None, None)
|
||||
|
||||
@config.patch(compiled_autograd=True)
|
||||
def test_nested_compile(self):
|
||||
with torch.library._scoped_library("testlib", "FRAGMENT") as lib:
|
||||
lib.define("square(Tensor x) -> Tensor")
|
||||
|
||||
@torch.library.impl("testlib::square", "CPU")
|
||||
def square_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
# nested inference graph compile
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
return x**2
|
||||
|
||||
return fn(x)
|
||||
|
||||
class MyFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x):
|
||||
return torch.ops.testlib.square(x)
|
||||
|
||||
x = torch.tensor([2.0, 3.0], requires_grad=True)
|
||||
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
return MyFn.apply(x)
|
||||
|
||||
fn(x).sum().backward()
|
||||
|
||||
@config.patch(compiled_autograd=True)
|
||||
def test_no_nested_compiled_autograd(self):
|
||||
# We disable CA before entering the CA graph
|
||||
# So re-entrants should be running with the eager autograd engine
|
||||
|
||||
def unrelated_autograd_call():
|
||||
x = torch.randn(20, 20, requires_grad=True)
|
||||
y = torch.randn(20, 20, requires_grad=True)
|
||||
loss = torch.matmul(x, y).sum()
|
||||
loss.backward()
|
||||
|
||||
class MyFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gO):
|
||||
unrelated_autograd_call()
|
||||
return gO
|
||||
|
||||
x = torch.randn(10, 10, requires_grad=True)
|
||||
loss = MyFn.apply(x).sum()
|
||||
|
||||
torch.compile(lambda: loss.backward(create_graph=True))()
|
||||
self.assertEqual(counters["compiled_autograd"]["captures"], 1)
|
||||
|
||||
def test_multiple_torch_compile(self):
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(4, 4),
|
||||
|
|
|
|||
|
|
@ -1382,6 +1382,8 @@ in_compiled_autograd_region = False
|
|||
|
||||
active_disable_ctx = False
|
||||
|
||||
depth = 0
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
|
||||
|
|
@ -1437,6 +1439,9 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
|
|||
torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log) # type:ignore[arg-type]
|
||||
global compiled_autograd_enabled
|
||||
compiled_autograd_enabled = True
|
||||
global depth
|
||||
prior_depth = depth
|
||||
depth += 1
|
||||
try:
|
||||
with torch.autograd.set_multithreading_enabled(False):
|
||||
yield
|
||||
|
|
@ -1446,6 +1451,10 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
|
|||
torch._C._dynamo.compiled_autograd.set_autograd_compiler(
|
||||
prior_compiler, prior_dynamic
|
||||
)
|
||||
depth -= 1
|
||||
assert depth == prior_depth, (
|
||||
"Nested Compiled Autograd Contexts must return before their parent context"
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
|
|||
|
|
@ -75,17 +75,17 @@ inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
|
|||
std::atomic<Engine::compiled_autograd_fn> the_compiled_autograd = nullptr;
|
||||
#define COMPILED_AUTOGRAD_POISON \
|
||||
reinterpret_cast<Engine::compiled_autograd_fn>(1)
|
||||
std::atomic<int32_t> num_threads_in_backwards;
|
||||
std::atomic<int32_t> num_threads_in_compiled_autograd;
|
||||
struct CompiledAutogradThreadingDebugCheck {
|
||||
CompiledAutogradThreadingDebugCheck() {
|
||||
num_threads_in_backwards++;
|
||||
num_threads_in_compiled_autograd++;
|
||||
}
|
||||
~CompiledAutogradThreadingDebugCheck() {
|
||||
release();
|
||||
}
|
||||
void release() {
|
||||
if (std::exchange(incremented, false)) {
|
||||
num_threads_in_backwards--;
|
||||
num_threads_in_compiled_autograd--;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1299,8 +1299,6 @@ auto Engine::execute(
|
|||
"your parameters to None after use to break the cycle and avoid the leak.");
|
||||
}
|
||||
|
||||
// Allows us to assert no other threads are in backwards
|
||||
CompiledAutogradThreadingDebugCheck _thread_check;
|
||||
auto compiled_autograd = the_compiled_autograd.load();
|
||||
TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON);
|
||||
|
||||
|
|
@ -1347,6 +1345,11 @@ auto Engine::execute(
|
|||
}
|
||||
|
||||
if (compiled_autograd != nullptr) {
|
||||
TORCH_CHECK(
|
||||
num_threads_in_compiled_autograd.load() == 0,
|
||||
"Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call.");
|
||||
// Allows us to assert no other threads are in backwards
|
||||
CompiledAutogradThreadingDebugCheck _thread_check;
|
||||
// see [Note: Compiled Autograd]
|
||||
_thread_check.release();
|
||||
GraphTaskGuard guard(graph_task);
|
||||
|
|
@ -1495,8 +1498,8 @@ void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) {
|
|||
}
|
||||
auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON);
|
||||
TORCH_CHECK(
|
||||
num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON,
|
||||
"compiled_autograd._enable() requires no threads in backwards()");
|
||||
prior != COMPILED_AUTOGRAD_POISON,
|
||||
"compiled_autograd._enable() does not support multiple Python threads");
|
||||
the_compiled_autograd.store(fn);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user