From a80eb84a5f70b1c9ea8432ebc1c558a32a5517b9 Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Mon, 12 May 2025 23:15:33 -0700
Subject: [PATCH] [ca] support higher order gradients (create_graph=True)
 (#153222)

Adds create_graph support when you don't compile, or when you compile
only with torch.compile(backend="eager").

Backends that go through AOTDispatch produce a post-dispatch AOT
backward whose double backward would be silently incorrect if the
forward trace involved any ops that are not composite implicit, so
higher-order gradients remain unsupported (and raise a RuntimeError)
in that case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153222
Approved by: https://github.com/jansel
ghstack dependencies: #153193
---
 test/inductor/test_compiled_autograd.py      | 91 ++++++++++++++-----
 test/test_autograd.py                        | 14 ++-
 torch/_dynamo/compiled_autograd.py           |  7 ++
 torch/_dynamo/polyfills/__init__.py          |  3 +
 torch/csrc/autograd/engine.cpp               |  6 +-
 .../autograd/functions/accumulate_grad.cpp   |  1 +
 6 files changed, 90 insertions(+), 32 deletions(-)

diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index b3ea378d307..0bfdf34efae 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -4108,6 +4108,57 @@ class CompiledAutograd1(torch.nn.Module):
         ):
             fn()
 
+    def test_higher_order_gradients(self):
+        def f(x):
+            return x**3
+
+        def fn(fwd_compiler, ca_compiler):
+            torch.manual_seed(123)
+            x = torch.tensor(2.0, requires_grad=True)
+            first, second, third, fourth = None, None, None, None
+            try:
+                with compiled_autograd._enable(ca_compiler):
+                    first = torch.autograd.grad(
+                        fwd_compiler(f)(x), x, create_graph=True
+                    )[0]
+                    second = torch.autograd.grad(first, x, create_graph=True)[0]
+                    third = torch.autograd.grad(second, x, create_graph=True)[0]
+                    fourth = torch.autograd.grad(third, x, create_graph=True)[0]
+            except RuntimeError as e:
+                assert "does not currently support higher order gradients" in str(e)
+                return (first, second, third, fourth)
+
+            return (first, second, third, fourth)
+
+        def eager():
+            return torch.compile(backend="eager")
+
+        def aot_eager():
+            return torch.compile(backend="aot_eager")
+
+        # Without AOTAutograd, no problem
+        first, second, third, fourth = fn(eager(), eager())
+        self.assertEqual(counters["compiled_autograd"]["captures"], 4)
+        self.assertEqual(first, 12)  # 3x^2
+        self.assertEqual(second, 12)  # 6x
+        self.assertEqual(third, 6)  # 6
+        self.assertEqual(fourth, 0)
+        # and should cache hit
+        counters.clear()
+        _ = fn(eager(), eager())
+        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
+        torch._dynamo.reset()
+
+        # With AOTAutograd, can't create_graph
+        first, second, third, fourth = fn(aot_eager(), aot_eager())
+        self.assertIsNone(second)
+
+        first, second, third, fourth = fn(aot_eager(), eager())
+        self.assertIsNone(second)
+
+        first, second, third, fourth = fn(eager(), aot_eager())
+        self.assertIsNone(third)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -4227,6 +4278,10 @@ known_graph_breaks_tests = {
     "test_prehook_ordering",  # retains_grad_hooks
     "test_will_engine_execute_node",  # retains_grad_hooks
     "test_backward_to_node",  # retains_grad_hooks
+    "test_backward_with_nonleaf_inputs",  # retains_grad_hook on non-leaf input
+    "test_create_graph_and_full_backward_hook_cycle",  # _pack_with_none
+    "test_full_backward_hook_double_backward",  # _pack_with_none
+    "test_grad_mode_restored_reentrant",  # assertTrue
 }
 
 test_contexts = {
@@ -4246,7 +4301,6 @@ skipped_tests = {
 
 known_failing_tests = {
     # Category: Compiled autograd
"test_grad_mode_restored_reentrant", # create_graph "test_reentrant_with_callbacks_both_depths", # queue_callback "test_reentrant_with_callbacks_depth_0", # queue_callback "test_reentrant_with_callbacks_depth_1", # queue_callback @@ -4254,34 +4308,13 @@ known_failing_tests = { "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd "test_post_accumulate_grad_hook_ordering", # accuracy error - "test_accumulate_grad", # create_graph - "test_anomaly_assign_parent_cleanup", # create_graph - "test_backward_create_graph_warns", # create_graph - "test_backward_with_nonleaf_inputs", # create_graph - "test_create_graph_and_full_backward_hook_cycle", # create_graph "test_current_graph_task_id", # autograd state already cleared once dynamo is called - "test_custom_autograd_repeated_grad_grad", # create_graph "test_custom_function_forward_mode_forward_is_no_op", # forward AD "test_custom_function_forward_mode_inplace_checks", # forward AD "test_custom_function_forward_mode_view_checks", # forward AD "test_custom_function_forward_mode_wrong_formula", # forward AD - "test_default_saved_tensors_hooks_double_backward", # create_graph "test_node_post_hook_registered_during_unpack_hook", # 'NoneType' object has no attribute 'register_hook' - "test_full_backward_hook_double_backward", # create_graph - "test_function", # create_graph - "test_grad", # create_graph - "test_grad_materialize_grads", # create_graph - "test_grad_nonleaf", # create_graph - "test_grad_nonleaf_many_outputs", # create_graph - "test_hessian_vector", # create_graph - "test_inplace_on_view_backward", # create_graph "test_multi_grad_any_hooks", # register_multi_grad_hook - "test_nested_anomaly_detect_nan", # create_graph - "test_nested_anomaly_printstack_cleanup", # create_graph - "test_once_differentiable", # create_graph - "test_saved_variable_packing_unpacking_saved_original_with_hooks", # create_graph - "test_select_sum", # create_graph, also needs graph breaks - "test_custom_autograd_no_early_free", # create_graph "test_custom_function_error", # vjp "test_custom_function_save_for_forward", # vjp "test_dont_materialize_grads", # undefined grad @@ -4290,10 +4323,16 @@ known_failing_tests = { "test_node_ordering_when_none_returned", # torch._dynamo.exc.Unsupported: TypeError ( /* keep_graph */ keep_graph, - /* create_graph */ create_graph, - /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1, + /* grad_mode */ create_graph, + /* reentrant_depth */ not_reentrant_backward_call ? 0 : total_depth + 1, /* cpu_ready_queue */ local_ready_queue, /* graph_roots */ std::move(temp_roots)); @@ -1348,8 +1348,6 @@ auto Engine::execute( if (compiled_autograd != nullptr) { // see [Note: Compiled Autograd] - TORCH_CHECK( - !create_graph, "compiled_autograd does not support create_graph"); _thread_check.release(); GraphTaskGuard guard(graph_task); CheckpointValidGuard cpvguard(graph_task); diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp index fb1eda43f55..c415d7131b3 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.cpp +++ b/torch/csrc/autograd/functions/accumulate_grad.cpp @@ -71,6 +71,7 @@ void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const { args.collect(variable); args.collect(variable.grad()); } + args.collect(GradMode::is_enabled()); const auto& hook = tensor_post_acc_grad_hooks(); if (hook != nullptr) { hook->compiled_args(args);