[ca] support higher order gradients (create_graph=True) (#153222)

Adds create_graph support to compiled autograd when the forward is either not compiled or compiled only with torch.compile(backend="eager").

Backends that go through AOTDispatch produce a post-dispatch AOT backward whose double backward would be silently incorrect if the forward trace involves any ops that are not composite implicit, so compiled autograd raises a RuntimeError in that case instead of computing wrong gradients.
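
A rough usage sketch, mirroring the test added below (compiled_autograd._enable is the private helper the test suite uses, so treat the exact entry point as illustrative):

import torch
from torch._dynamo import compiled_autograd

def f(x):
    return x ** 3

x = torch.tensor(2.0, requires_grad=True)

# Works: nothing goes through AOTAutograd, so the captured backward stays
# differentiable. AOTDispatch-based backends (e.g. "aot_eager", "inductor")
# raise a RuntimeError instead of returning silently wrong higher order grads.
with compiled_autograd._enable(torch.compile(backend="eager")):
    out = torch.compile(f, backend="eager")(x)
    first = torch.autograd.grad(out, x, create_graph=True)[0]
    second = torch.autograd.grad(first, x, create_graph=True)[0]
print(first, second)  # 3x^2 = 12, 6x = 12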

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153222
Approved by: https://github.com/jansel
ghstack dependencies: #153193
Authored by Simon Fan on 2025-05-12 23:15:33 -07:00; committed by PyTorch MergeBot
parent 37efaf4af9
commit a80eb84a5f
6 changed files with 90 additions and 32 deletions

View File

@ -4108,6 +4108,57 @@ class CompiledAutograd1(torch.nn.Module):
        ):
            fn()

    def test_higher_order_gradients(self):
        def f(x):
            return x**3

        def fn(fwd_compiler, ca_compiler):
            torch.manual_seed(123)
            x = torch.tensor(2.0, requires_grad=True)
            first, second, third, fourth = None, None, None, None
            try:
                with compiled_autograd._enable(ca_compiler):
                    first = torch.autograd.grad(
                        fwd_compiler(f)(x), x, create_graph=True
                    )[0]
                    second = torch.autograd.grad(first, x, create_graph=True)[0]
                    third = torch.autograd.grad(second, x, create_graph=True)[0]
                    fourth = torch.autograd.grad(third, x, create_graph=True)[0]
            except RuntimeError as e:
                assert "does not currently support higher order gradients" in str(e)
                return (first, second, third, fourth)

            return (first, second, third, fourth)

        def eager():
            return torch.compile(backend="eager")

        def aot_eager():
            return torch.compile(backend="aot_eager")

        # Without AOTAutograd, no problem
        first, second, third, fourth = fn(eager(), eager())
        self.assertEqual(counters["compiled_autograd"]["captures"], 4)
        self.assertEqual(first, 12)  # 3x^2
        self.assertEqual(second, 12)  # 6x
        self.assertEqual(third, 6)  # 6
        self.assertEqual(fourth, 0)
        # and should cache hit
        counters.clear()
        _ = fn(eager(), eager())
        self.assertEqual(counters["compiled_autograd"]["captures"], 0)

        torch._dynamo.reset()

        # With AOTAutograd, can't create_graph
        first, second, third, fourth = fn(aot_eager(), aot_eager())
        self.assertIsNone(second)
        first, second, third, fourth = fn(aot_eager(), eager())
        self.assertIsNone(second)
        first, second, third, fourth = fn(eager(), aot_eager())
        self.assertIsNone(third)


def load_test_module(name):
    testdir = Path(__file__).absolute().parent.parent
@ -4227,6 +4278,10 @@ known_graph_breaks_tests = {
"test_prehook_ordering", # retains_grad_hooks
"test_will_engine_execute_node", # retains_grad_hooks
"test_backward_to_node", # retains_grad_hooks
"test_backward_with_nonleaf_inputs", # retains_grad_hook on non-leaf input
"test_create_graph_and_full_backward_hook_cycle", # _pack_with_none
"test_full_backward_hook_double_backward", # _pack_with_none
"test_grad_mode_restored_reentrant", # assertTrue
}
test_contexts = {
@ -4246,7 +4301,6 @@ skipped_tests = {
known_failing_tests = {
# Category: Compiled autograd
"test_grad_mode_restored_reentrant", # create_graph
"test_reentrant_with_callbacks_both_depths", # queue_callback
"test_reentrant_with_callbacks_depth_0", # queue_callback
"test_reentrant_with_callbacks_depth_1", # queue_callback
@ -4254,34 +4308,13 @@ known_failing_tests = {
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_current_node", # TorchDispatchMode not yet implemented for compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error
"test_accumulate_grad", # create_graph
"test_anomaly_assign_parent_cleanup", # create_graph
"test_backward_create_graph_warns", # create_graph
"test_backward_with_nonleaf_inputs", # create_graph
"test_create_graph_and_full_backward_hook_cycle", # create_graph
"test_current_graph_task_id", # autograd state already cleared once dynamo is called
"test_custom_autograd_repeated_grad_grad", # create_graph
"test_custom_function_forward_mode_forward_is_no_op", # forward AD
"test_custom_function_forward_mode_inplace_checks", # forward AD
"test_custom_function_forward_mode_view_checks", # forward AD
"test_custom_function_forward_mode_wrong_formula", # forward AD
"test_default_saved_tensors_hooks_double_backward", # create_graph
"test_node_post_hook_registered_during_unpack_hook", # 'NoneType' object has no attribute 'register_hook'
"test_full_backward_hook_double_backward", # create_graph
"test_function", # create_graph
"test_grad", # create_graph
"test_grad_materialize_grads", # create_graph
"test_grad_nonleaf", # create_graph
"test_grad_nonleaf_many_outputs", # create_graph
"test_hessian_vector", # create_graph
"test_inplace_on_view_backward", # create_graph
"test_multi_grad_any_hooks", # register_multi_grad_hook
"test_nested_anomaly_detect_nan", # create_graph
"test_nested_anomaly_printstack_cleanup", # create_graph
"test_once_differentiable", # create_graph
"test_saved_variable_packing_unpacking_saved_original_with_hooks", # create_graph
"test_select_sum", # create_graph, also needs graph breaks
"test_custom_autograd_no_early_free", # create_graph
"test_custom_function_error", # vjp
"test_custom_function_save_for_forward", # vjp
"test_dont_materialize_grads", # undefined grad
@ -4290,10 +4323,16 @@ known_failing_tests = {
"test_node_ordering_when_none_returned", # torch._dynamo.exc.Unsupported: TypeError <built-in method clone
"test_save_output_nr", # output_nr grad passed as None
"test_setup_context_when_forward_has_default_args", # autograd.Function with class methods
"test_lobpcg", # create_graph
# IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
"test_grad_nonleaf_register_hook",
"test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938
# Category: Higher Order Gradients
"test_default_saved_tensors_hooks_double_backward", # wrong when pack hook returns non-leaf
"test_saved_variable_packing_unpacking_saved_original_with_hooks", # wrong when pack hook returns non-leaf
"test_nested_anomaly_detect_nan", # nested anomaly
"test_select_sum", # batched gradients
"test_custom_autograd_no_early_free", # batched gradients
"test_lobpcg", # NaNs
# Category: Dynamo (pass when directly running CA graph)
"test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None
"test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
@ -4339,8 +4378,14 @@ known_failing_tests = {
"test_anomaly_mode_no_check_nan", # different error messages
"test_anomaly_grad_warnings", # different error messages
"test_anomaly_detect_nan", # fake tensor errors on NaN
"test_once_differentiable", # different node name: CompiledFunctionBackward
"test_function", # different node name: CompiledFunctionBackward
"test_inplace_on_view_backward", # different node name: CompiledFunctionBackward
"test_nested_anomaly_printstack_cleanup", # anomaly NaN error message different
# Uncategorized
"test_not_implemented_grad", # Dynamo changes the types of exceptions
"test_grad", # AOT backward higher order gradients
"test_grad_materialize_grads", # AOT backward higher order gradients
}
if not HAS_CUDA:

View File

@ -830,7 +830,7 @@ class TestAutograd(TestCase):
        x_grad, x_grad_clone = compute_grad(create_graph=False)
        self.assertEqual(x_grad, x_grad_clone * 2)

        # Accumulate out-of-place when create_graph is False
        # Accumulate out-of-place when create_graph is True
        x_grad, x_grad_clone = compute_grad(create_graph=True)
        self.assertEqual(x_grad, x_grad_clone)
@ -9376,10 +9376,14 @@ for shape in [(1,), ()]:
        with set_warn_always_context(True):
            with warnings.catch_warnings(record=True) as w:
                tmp.exp().sum().backward(create_graph=True)
                self.assertTrue(len(w) == 1)
                self.assertTrue(
                    "Using backward() with create_graph=True" in str(w[0].message)
                )
                self.assertTrue(w)
                found = 0
                for warning in w:
                    if "Using backward() with create_graph=True" in str(
                        warning.message
                    ):
                        found += 1
                self.assertEqual(found, 1)

        # Remove the backward + create_graph=True cycle
        a.grad = None

View File

@ -423,6 +423,13 @@ class AutogradCompilerInstance:
        aot_id = CompiledFunction._aot_id
        del CompiledFunction

        if torch.is_grad_enabled():
            # We are capturing under grad mode (create_graph=True); double
            # backward through the post-dispatch AOT backward would be silently
            # wrong for non composite implicit ops, so refuse loudly instead.
            for output_alias_info in metadata.output_info:
                if output_alias_info.requires_grad:
                    raise RuntimeError(
                        "torch.compile does not currently support higher order gradients."
                    )

        @torch._dynamo.allow_in_graph  # type: ignore[misc]
        def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
            out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(

View File

@ -75,12 +75,15 @@ def radians(x):
def accumulate_grad(x, new_grad):
    # polyfills according to the Gradient Layout Contract
    if new_grad is None:
        return
    new_grad_strided = torch.empty_like(x)
    new_grad_strided.copy_(new_grad)
    if x.grad is None:
        x.grad = new_grad_strided
    elif torch.is_grad_enabled():
        # mirrors AccumulateGrad under create_graph: accumulate out of place
        # so the addition itself stays differentiable for double backward
        x.grad = x.grad + new_grad_strided
    else:
        x.grad.add_(new_grad_strided)
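
For reference, a small eager-mode sketch (plain autograd, no compiled autograd involved) of the accumulation behavior the torch.is_grad_enabled() branch above mirrors:

import torch

x = torch.tensor(2.0, requires_grad=True)

(x ** 3).backward(create_graph=True)
g1 = x.grad                          # 3x^2 = 12, kept differentiable

(x ** 3).backward(create_graph=True)
# Grad mode is on inside this backward, so accumulation happens out of place:
# x.grad is a fresh tensor (g1 + 3x^2) with its own grad_fn, not g1 mutated
# in place via add_().
assert x.grad is not g1
assert x.grad.grad_fn is not None

# ...which is what makes a further (higher order) gradient possible:
ggrad = torch.autograd.grad(x.grad, x)[0]   # d(6x^2)/dx = 12x = 24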

View File

@ -1326,8 +1326,8 @@ auto Engine::execute(
  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* grad_mode */ create_graph,
      /* reentrant_depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue,
      /* graph_roots */ std::move(temp_roots));
@ -1348,8 +1348,6 @@ auto Engine::execute(
  if (compiled_autograd != nullptr) {
    // see [Note: Compiled Autograd]
    TORCH_CHECK(
        !create_graph, "compiled_autograd does not support create_graph");
    _thread_check.release();
    GraphTaskGuard guard(graph_task);
    CheckpointValidGuard cpvguard(graph_task);
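
For orientation, the grad_mode passed above is what create_graph controls: a backward launched with create_graph=True runs with grad mode enabled, which is the condition both the Python-side guard and the accumulate_grad polyfill key off. A quick eager-mode check of that behavior:

import torch

seen = []

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
y.register_hook(lambda grad: seen.append(torch.is_grad_enabled()) or grad)
y.backward(create_graph=True)

z = x ** 3
z.register_hook(lambda grad: seen.append(torch.is_grad_enabled()) or grad)
z.backward()

print(seen)  # [True, False]: grad mode tracks create_graph inside the engine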

View File

@ -71,6 +71,7 @@ void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
    args.collect(variable);
    args.collect(variable.grad());
  }
  // grad mode participates in the compiled autograd cache key, so switching
  // create_graph triggers a re-capture instead of reusing a cached graph
  args.collect(GradMode::is_enabled());
  const auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->compiled_args(args);
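
A hedged sketch of what collecting GradMode::is_enabled() buys: grad mode becomes part of the compiled autograd cache key, so flipping create_graph should trigger a fresh capture rather than reusing a graph specialized on the other mode. This reuses the private helpers (compiled_autograd._enable, torch._dynamo.utils.counters) from the test above, so treat it as illustrative:

import torch
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import counters

def run(create_graph):
    x = torch.tensor(2.0, requires_grad=True)
    with compiled_autograd._enable(torch.compile(backend="eager")):
        (x ** 3).backward(create_graph=create_graph)
    return counters["compiled_autograd"]["captures"]

counters.clear()
run(create_graph=False)            # first capture
hit = run(create_graph=False)      # same cache key: expect no new capture (still 1)
miss = run(create_graph=True)      # grad mode differs: expect a fresh capture (2)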