[ca] Allow using compiled autograd context managers during backward runtime (#156120)
Added an invariant that nested compiled autograd context managers must exit before their parent context manager. This allows us to defer the thread check.

Fixes https://github.com/pytorch/pytorch/issues/152219

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156120
Approved by: https://github.com/jansel
ghstack dependencies: #155521, #155480
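For illustration, a minimal sketch of the invariant (mirroring the test added below, and assuming the private torch._dynamo.compiled_autograd._enable entry point): nested contexts are allowed as long as they exit in LIFO order; exiting the parent while a child is still active now raises an AssertionError.

    import torch
    from torch._dynamo import compiled_autograd

    # ok: contexts exit in LIFO order (inner before outer)
    outer = compiled_autograd._enable(torch.compile)
    inner = compiled_autograd._enable(torch.compile)
    outer.__enter__()
    inner.__enter__()
    inner.__exit__(None, None, None)
    outer.__exit__(None, None, None)

    # not ok: exiting the parent while the nested context is still active
    outer = compiled_autograd._enable(torch.compile)
    inner = compiled_autograd._enable(torch.compile)
    outer.__enter__()
    inner.__enter__()
    outer.__exit__(None, None, None)  # raises AssertionError:
    # "Nested Compiled Autograd Contexts must return before their parent context"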
parent 10d41c7d20
commit 17b38b850e
@@ -785,6 +785,88 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 0)
 
+    @config.patch(compiled_autograd=True)
+    def test_nested_context_manager(self):
+        def ctx():
+            return compiled_autograd._enable(torch.compile)
+
+        # ok
+        outer = ctx()
+        inner = ctx()
+        outer.__enter__()
+        inner.__enter__()
+        inner.__exit__(None, None, None)
+        outer.__exit__(None, None, None)
+
+        # not ok
+        outer = ctx()
+        inner = ctx()
+        outer.__enter__()
+        inner.__enter__()
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Nested Compiled Autograd Contexts must return before their parent context",
+        ):
+            outer.__exit__(None, None, None)
+
+    @config.patch(compiled_autograd=True)
+    def test_nested_compile(self):
+        with torch.library._scoped_library("testlib", "FRAGMENT") as lib:
+            lib.define("square(Tensor x) -> Tensor")
+
+            @torch.library.impl("testlib::square", "CPU")
+            def square_impl(x: torch.Tensor) -> torch.Tensor:
+                # nested inference graph compile
+                @torch.compile(backend="eager")
+                def fn(x):
+                    return x**2
+
+                return fn(x)
+
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x):
+                    return x
+
+                @staticmethod
+                def backward(ctx, x):
+                    return torch.ops.testlib.square(x)
+
+            x = torch.tensor([2.0, 3.0], requires_grad=True)
+
+            @torch.compile
+            def fn(x):
+                return MyFn.apply(x)
+
+            fn(x).sum().backward()
+
+    @config.patch(compiled_autograd=True)
+    def test_no_nested_compiled_autograd(self):
+        # We disable CA before entering the CA graph
+        # So re-entrants should be running with the eager autograd engine
+
+        def unrelated_autograd_call():
+            x = torch.randn(20, 20, requires_grad=True)
+            y = torch.randn(20, 20, requires_grad=True)
+            loss = torch.matmul(x, y).sum()
+            loss.backward()
+
+        class MyFn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, gO):
+                unrelated_autograd_call()
+                return gO
+
+        x = torch.randn(10, 10, requires_grad=True)
+        loss = MyFn.apply(x).sum()
+
+        torch.compile(lambda: loss.backward(create_graph=True))()
+        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+
     def test_multiple_torch_compile(self):
         model = torch.nn.Sequential(
             torch.nn.Linear(4, 4),
@@ -1382,6 +1382,8 @@ in_compiled_autograd_region = False
 
 active_disable_ctx = False
 
+depth = 0
+
 
 @contextlib.contextmanager
 def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
@@ -1437,6 +1439,9 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
             torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
         global compiled_autograd_enabled
         compiled_autograd_enabled = True
+        global depth
+        prior_depth = depth
+        depth += 1
         try:
             with torch.autograd.set_multithreading_enabled(False):
                 yield
@@ -1446,6 +1451,10 @@ def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
             torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                 prior_compiler, prior_dynamic
             )
+            depth -= 1
+            assert depth == prior_depth, (
+                "Nested Compiled Autograd Contexts must return before their parent context"
+            )
 
 
 @contextlib.contextmanager
@@ -75,17 +75,17 @@ inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
 std::atomic<Engine::compiled_autograd_fn> the_compiled_autograd = nullptr;
 #define COMPILED_AUTOGRAD_POISON \
   reinterpret_cast<Engine::compiled_autograd_fn>(1)
-std::atomic<int32_t> num_threads_in_backwards;
+std::atomic<int32_t> num_threads_in_compiled_autograd;
 struct CompiledAutogradThreadingDebugCheck {
   CompiledAutogradThreadingDebugCheck() {
-    num_threads_in_backwards++;
+    num_threads_in_compiled_autograd++;
   }
   ~CompiledAutogradThreadingDebugCheck() {
     release();
   }
   void release() {
     if (std::exchange(incremented, false)) {
-      num_threads_in_backwards--;
+      num_threads_in_compiled_autograd--;
     }
   }
 
@@ -1299,8 +1299,6 @@ auto Engine::execute(
         "your parameters to None after use to break the cycle and avoid the leak.");
   }
 
-  // Allows us to assert no other threads are in backwards
-  CompiledAutogradThreadingDebugCheck _thread_check;
   auto compiled_autograd = the_compiled_autograd.load();
   TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON);
 
@@ -1347,6 +1345,11 @@ auto Engine::execute(
   }
 
   if (compiled_autograd != nullptr) {
+    TORCH_CHECK(
+        num_threads_in_compiled_autograd.load() == 0,
+        "Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call.");
+    // Allows us to assert no other threads are in backwards
+    CompiledAutogradThreadingDebugCheck _thread_check;
     // see [Note: Compiled Autograd]
     _thread_check.release();
     GraphTaskGuard guard(graph_task);
@@ -1495,8 +1498,8 @@ void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) {
   }
   auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON);
   TORCH_CHECK(
-      num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON,
-      "compiled_autograd._enable() requires no threads in backwards()");
+      prior != COMPILED_AUTOGRAD_POISON,
+      "compiled_autograd._enable() does not support multiple Python threads");
   the_compiled_autograd.store(fn);
 }
 