diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 3e4fb849494..cbe6a268c52 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2686,7 +2686,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { inp = torch.rand(10, 10, requires_grad=True) out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True) - with torch._dynamo.compiled_autograd.enable(torch.compile): + with self.assertRaisesRegex( + RuntimeError, + r"\(e.g. reentrant checkpointing\), this is not supported yet\.", + ), torch._dynamo.compiled_autograd.enable(torch.compile): out.backward() diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 675f99e46fb..e7c5d2414f6 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -525,10 +525,6 @@ def disable(): torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) -def maybe_disable_compiled_autograd(): - return disable() if in_compiled_autograd_region else contextlib.nullcontext() - - # return to starting state of a new process def reset() -> None: compiled_autograd_enable = False diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 991b00e8b50..4792ced32b2 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -822,10 +822,9 @@ def _engine_run_backward( if attach_logging_hooks: unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) try: - with torch._dynamo.compiled_autograd.maybe_disable_compiled_autograd(): - return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass + return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass finally: if attach_logging_hooks: unregister_hooks() # type: ignore[possibly-undefined]