[ca] support higher order gradients (create_graph=True) (#153222)
Adds create_graph support when not compiling, or when compiling only with torch.compile(backend="eager"). Backends that go through AOTDispatch produce a post-dispatch AOT backward, and its double backward will be silently incorrect if the forward trace involved any ops that are not composite implicit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153222
Approved by: https://github.com/jansel
ghstack dependencies: #153193
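A minimal sketch of the now-supported path, based on the test added in this commit; it uses the internal compiled_autograd._enable API that the test suite uses, so treat it as illustrative rather than a stable public interface:

    import torch
    from torch._dynamo import compiled_autograd

    def f(x):
        return x**3

    x = torch.tensor(2.0, requires_grad=True)

    # Both the forward and the compiled autograd graph use the "eager" backend,
    # so no AOTDispatch backward is involved and create_graph=True is allowed.
    with compiled_autograd._enable(torch.compile(backend="eager")):
        first = torch.autograd.grad(
            torch.compile(backend="eager")(f)(x), x, create_graph=True
        )[0]
        second = torch.autograd.grad(first, x, create_graph=True)[0]

    print(first, second)  # 3*x**2 = 12.0, 6*x = 12.0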
This commit is contained in:
parent
37efaf4af9
commit
a80eb84a5f
@@ -4108,6 +4108,57 @@ class CompiledAutograd1(torch.nn.Module):
            ):
                fn()

    def test_higher_order_gradients(self):
        def f(x):
            return x**3

        def fn(fwd_compiler, ca_compiler):
            torch.manual_seed(123)
            x = torch.tensor(2.0, requires_grad=True)
            first, second, third, fourth = None, None, None, None
            try:
                with compiled_autograd._enable(ca_compiler):
                    first = torch.autograd.grad(
                        fwd_compiler(f)(x), x, create_graph=True
                    )[0]
                    second = torch.autograd.grad(first, x, create_graph=True)[0]
                    third = torch.autograd.grad(second, x, create_graph=True)[0]
                    fourth = torch.autograd.grad(third, x, create_graph=True)[0]
            except RuntimeError as e:
                assert "does not currently support higher order gradients" in str(e)
                return (first, second, third, fourth)

            return (first, second, third, fourth)

        def eager():
            return torch.compile(backend="eager")

        def aot_eager():
            return torch.compile(backend="aot_eager")

        # Without AOTAutograd, no problem
        first, second, third, fourth = fn(eager(), eager())
        self.assertEqual(counters["compiled_autograd"]["captures"], 4)
        self.assertEqual(first, 12)  # 3x^2
        self.assertEqual(second, 12)  # 6x
        self.assertEqual(third, 6)  # 6
        self.assertEqual(fourth, 0)
        # and should cache hit
        counters.clear()
        _ = fn(eager(), eager())
        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
        torch._dynamo.reset()

        # With AOTAutograd, can't create_graph
        first, second, third, fourth = fn(aot_eager(), aot_eager())
        self.assertIsNone(second)

        first, second, third, fourth = fn(aot_eager(), eager())
        self.assertIsNone(second)

        first, second, third, fourth = fn(eager(), aot_eager())
        self.assertIsNone(third)


def load_test_module(name):
    testdir = Path(__file__).absolute().parent.parent
@@ -4227,6 +4278,10 @@ known_graph_breaks_tests = {
    "test_prehook_ordering",  # retains_grad_hooks
    "test_will_engine_execute_node",  # retains_grad_hooks
    "test_backward_to_node",  # retains_grad_hooks
    "test_backward_with_nonleaf_inputs",  # retains_grad_hook on non-leaf input
    "test_create_graph_and_full_backward_hook_cycle",  # _pack_with_none
    "test_full_backward_hook_double_backward",  # _pack_with_none
    "test_grad_mode_restored_reentrant",  # assertTrue
}

test_contexts = {
@@ -4246,7 +4301,6 @@ skipped_tests = {

known_failing_tests = {
    # Category: Compiled autograd
    "test_grad_mode_restored_reentrant",  # create_graph
    "test_reentrant_with_callbacks_both_depths",  # queue_callback
    "test_reentrant_with_callbacks_depth_0",  # queue_callback
    "test_reentrant_with_callbacks_depth_1",  # queue_callback
@@ -4254,34 +4308,13 @@ known_failing_tests = {
    "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
    "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
    "test_post_accumulate_grad_hook_ordering",  # accuracy error
    "test_accumulate_grad",  # create_graph
    "test_anomaly_assign_parent_cleanup",  # create_graph
    "test_backward_create_graph_warns",  # create_graph
    "test_backward_with_nonleaf_inputs",  # create_graph
    "test_create_graph_and_full_backward_hook_cycle",  # create_graph
    "test_current_graph_task_id",  # autograd state already cleared once dynamo is called
    "test_custom_autograd_repeated_grad_grad",  # create_graph
    "test_custom_function_forward_mode_forward_is_no_op",  # forward AD
    "test_custom_function_forward_mode_inplace_checks",  # forward AD
    "test_custom_function_forward_mode_view_checks",  # forward AD
    "test_custom_function_forward_mode_wrong_formula",  # forward AD
    "test_default_saved_tensors_hooks_double_backward",  # create_graph
    "test_node_post_hook_registered_during_unpack_hook",  # 'NoneType' object has no attribute 'register_hook'
    "test_full_backward_hook_double_backward",  # create_graph
    "test_function",  # create_graph
    "test_grad",  # create_graph
    "test_grad_materialize_grads",  # create_graph
    "test_grad_nonleaf",  # create_graph
    "test_grad_nonleaf_many_outputs",  # create_graph
    "test_hessian_vector",  # create_graph
    "test_inplace_on_view_backward",  # create_graph
    "test_multi_grad_any_hooks",  # register_multi_grad_hook
    "test_nested_anomaly_detect_nan",  # create_graph
    "test_nested_anomaly_printstack_cleanup",  # create_graph
    "test_once_differentiable",  # create_graph
    "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # create_graph
    "test_select_sum",  # create_graph, also needs graph breaks
    "test_custom_autograd_no_early_free",  # create_graph
    "test_custom_function_error",  # vjp
    "test_custom_function_save_for_forward",  # vjp
    "test_dont_materialize_grads",  # undefined grad
@@ -4290,10 +4323,16 @@ known_failing_tests = {
    "test_node_ordering_when_none_returned",  # torch._dynamo.exc.Unsupported: TypeError <built-in method clone
    "test_save_output_nr",  # output_nr grad passed as None
    "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
    "test_lobpcg",  # create_graph
    # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
    "test_grad_nonleaf_register_hook",
    "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
    # Category: Higher Order Gradients
    "test_default_saved_tensors_hooks_double_backward",  # wrong when pack hook returns non-leaf
    "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # wrong when pack hook returns non-leaf
    "test_nested_anomaly_detect_nan",  # nested anomaly
    "test_select_sum",  # batched gradients
    "test_custom_autograd_no_early_free",  # batched gradients
    "test_lobpcg",  # NaNs
    # Category: Dynamo (pass when directly running CA graph)
    "test_accumulate_grad_tensor_reference",  # Out of bounds: frame_state_entry.stride[i] is None
    "test_custom_function_exception",  # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
@@ -4339,8 +4378,14 @@ known_failing_tests = {
    "test_anomaly_mode_no_check_nan",  # different error messages
    "test_anomaly_grad_warnings",  # different error messages
    "test_anomaly_detect_nan",  # fake tensor errors on NaN
    "test_once_differentiable",  # different node name: CompiledFunctionBackward
    "test_function",  # different node name: CompiledFunctionBackward
    "test_inplace_on_view_backward",  # different node name: CompiledFunctionBackward
    "test_nested_anomaly_printstack_cleanup",  # anomaly NaN error message different
    # Uncategorized
    "test_not_implemented_grad",  # Dynamo changes the types of exceptions
    "test_grad",  # AOT backward higher order gradients
    "test_grad_materialize_grads",  # AOT backward higher order gradients
}

if not HAS_CUDA:
@@ -830,7 +830,7 @@ class TestAutograd(TestCase):
        x_grad, x_grad_clone = compute_grad(create_graph=False)
        self.assertEqual(x_grad, x_grad_clone * 2)

        # Accumulate out-of-place when create_graph is False
        # Accumulate out-of-place when create_graph is True
        x_grad, x_grad_clone = compute_grad(create_graph=True)
        self.assertEqual(x_grad, x_grad_clone)
@@ -9376,10 +9376,14 @@ for shape in [(1,), ()]:
        with set_warn_always_context(True):
            with warnings.catch_warnings(record=True) as w:
                tmp.exp().sum().backward(create_graph=True)
                self.assertTrue(len(w) == 1)
                self.assertTrue(
                    "Using backward() with create_graph=True" in str(w[0].message)
                )
                self.assertTrue(w)
                found = 0
                for warning in w:
                    if "Using backward() with create_graph=True" in str(
                        warning.message
                    ):
                        found += 1
                self.assertEqual(found, 1)

        # Remove the backward + create_graph=True cycle
        a.grad = None
@@ -423,6 +423,13 @@ class AutogradCompilerInstance:
        aot_id = CompiledFunction._aot_id
        del CompiledFunction

        if torch.is_grad_enabled():
            for output_alias_info in metadata.output_info:
                if output_alias_info.requires_grad:
                    raise RuntimeError(
                        "torch.compile does not currently support higher order gradients."
                    )

        @torch._dynamo.allow_in_graph  # type: ignore[misc]
        def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
            out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(
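For illustration, a hedged sketch (mirroring the new test in this commit) of how this check surfaces to users: compiling the forward with an AOTDispatch backend such as aot_eager and then requesting higher-order gradients under compiled autograd now raises this RuntimeError at some point in the chain of grad calls, instead of silently computing a wrong double backward.

    import torch
    from torch._dynamo import compiled_autograd

    def f(x):
        return x**3

    x = torch.tensor(2.0, requires_grad=True)
    try:
        with compiled_autograd._enable(torch.compile(backend="eager")):
            # Forward compiled with an AOTDispatch backend (aot_eager).
            first = torch.autograd.grad(
                torch.compile(backend="aot_eager")(f)(x), x, create_graph=True
            )[0]
            second = torch.autograd.grad(first, x, create_graph=True)[0]
    except RuntimeError as e:
        # The guard above rejects the AOT backward rather than letting a
        # silently incorrect higher-order gradient through.
        assert "does not currently support higher order gradients" in str(e)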
@@ -75,12 +75,15 @@ def radians(x):


def accumulate_grad(x, new_grad):
    # polyfills according to the Gradient Layout Contract
    if new_grad is None:
        return
    new_grad_strided = torch.empty_like(x)
    new_grad_strided.copy_(new_grad)
    if x.grad is None:
        x.grad = new_grad_strided
    elif torch.is_grad_enabled():
        x.grad = x.grad + new_grad_strided
    else:
        x.grad.add_(new_grad_strided)
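For context, a small sketch of the eager accumulation semantics this polyfill mirrors (plain eager autograd, no compiled autograd involved): without create_graph the existing .grad tensor is mutated in place, while with grad mode enabled (create_graph=True) a new tensor is assigned so the accumulation itself stays differentiable. This is the same distinction the updated TestAutograd comments above refer to.

    import torch

    x = torch.randn(3, requires_grad=True)

    x.pow(2).sum().backward()
    g_ref = x.grad                       # alias of the accumulated grad tensor
    g_snapshot = x.grad.clone()

    x.pow(2).sum().backward()            # grad mode off during accumulation: in place
    assert torch.equal(g_ref, g_snapshot * 2)   # the aliased tensor was mutated

    x.pow(2).sum().backward(create_graph=True)  # grad mode on: out of place
    assert x.grad is not g_ref                  # .grad was rebound to a new tensor

    x.grad = None  # break the reference cycle created by create_graph=True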
@@ -1326,8 +1326,8 @@ auto Engine::execute(

  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* grad_mode */ create_graph,
      /* reentrant_depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue,
      /* graph_roots */ std::move(temp_roots));
@@ -1348,8 +1348,6 @@ auto Engine::execute(

  if (compiled_autograd != nullptr) {
    // see [Note: Compiled Autograd]
    TORCH_CHECK(
        !create_graph, "compiled_autograd does not support create_graph");
    _thread_check.release();
    GraphTaskGuard guard(graph_task);
    CheckpointValidGuard cpvguard(graph_task);
@@ -71,6 +71,7 @@ void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
    args.collect(variable);
    args.collect(variable.grad());
  }
  args.collect(GradMode::is_enabled());
  const auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    hook->compiled_args(args);