Revert "[ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)"

This reverts commit 4863e5c843.

Reverted https://github.com/pytorch/pytorch/pull/153300 on behalf of https://github.com/malfet due to Looks like it breaks rocm, see fa8543454a/1 ([comment](https://github.com/pytorch/pytorch/pull/153300#issuecomment-2884489459))
PyTorch MergeBot 2025-05-15 16:58:52 +00:00
parent 2327c9eedc
commit 236b08cbf8
16 changed files with 18 additions and 45 deletions
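
For context, a minimal sketch (not code from the PR) of the scenario #153300 targeted and this commit reverts: a checkpointed region whose forward runs in eager, with the backward later traced under compiled autograd. The `compiled_autograd._enable` context manager mirrors its use in the test diff below; the toy `region` function and the use of `torch.compile` as the compiler function are illustrative assumptions.

import torch
from torch._dynamo import compiled_autograd
from torch.utils.checkpoint import checkpoint

def region(x):
    # toy checkpointed region, for illustration only
    return torch.sin(x).cos()

x = torch.randn(8, requires_grad=True)
# Forward runs in eager; saved activations are dropped and recomputed in backward.
y = checkpoint(region, x, use_reentrant=False)

# Backward traced by compiled autograd: the checkpoint unpack hook (and with it
# the recomputation) fires while the backward graph is being built, which is the
# interaction #153300 changed and this revert restores to its previous behavior.
with compiled_autograd._enable(torch.compile):
    y.sum().backward()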

View File

@@ -4209,13 +4209,10 @@ def wrap_test_class(orig_cls):
):
dct[name] = unittest.expectedFailure
elif name.startswith("test_"):
backend = lookup_backend(name)
if not HAS_CUDA and backend == "inductor":
continue
ctxs = [
compiled_autograd._enable(
make_compiler_fn(
backend=backend,
backend=lookup_backend(name),
fullgraph=name not in known_graph_breaks_tests,
)
),
@@ -4308,21 +4305,6 @@ known_graph_breaks_tests = {
"test_full_backward_hook_double_backward", # _pack_with_none
"test_grad_mode_restored_reentrant", # assertTrue
"test_multi_grad_any_hooks", # register_multi_grad_hook
"test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks
"test_graph_save_on_cpu", # dynamo disabled
"test_nested_checkpoint_early_stop_False", # dynamo disable
"test_nested_checkpoint_early_stop_True", # dynamo disable
"test_nested_checkpoint_kwargs_early_stop_False", # dynamo disable
"test_nested_checkpoint_kwargs_early_stop_True", # dynamo disable
"test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False", # dynamo disable
"test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True", # dynamo disable
"test_nested_checkpoint_reentrant_backwards_early_stop_False", # dynamo disable
"test_nested_checkpoint_reentrant_backwards_early_stop_True", # dynamo disable
"test_nested_checkpoint_same_graph_early_stop_False", # dynamo disable
"test_nested_checkpoint_same_graph_early_stop_True", # dynamo disable
"test_nested_checkpoint_set_early_stop", # dynamo disable
"test_nested_checkpoint_two_children_early_stop_False", # dynamo disable
"test_nested_checkpoint_two_children_early_stop_True", # dynamo disable
}
test_contexts = {
@@ -4347,7 +4329,6 @@ xfail_by_backend = {
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_current_node", # TorchDispatchMode not yet implemented for compiled autograd
"test_nested_checkpoint_set_early_stop_no_recompution_needed", # TorchDispatchMode not yet implemented
"test_post_accumulate_grad_hook_ordering", # accuracy error
"test_current_graph_task_id", # autograd state already cleared once dynamo is called
"test_custom_function_forward_mode_forward_is_no_op", # forward AD
@@ -4381,10 +4362,6 @@ xfail_by_backend = {
"test_return_duplicate_inplace", # batched gradients
"test_naughty_autograd_function_stashing_ctx", # error not raised
"test_unrelated_inputs", # batched gradients
"test_nested_checkpoint_early_stop_False", # unpack hook grad_fn semantics
"test_nested_checkpoint_early_stop_True", # unpack hook grad_fn semantics
"test_nested_checkpoint_two_children_early_stop_False", # unpack hook grad_fn semantics
"test_nested_checkpoint_two_children_early_stop_True", # unpack hook grad_fn semantics
},
"eager": { # will be run without torch.compiling the CA graph
"test_setup_context_when_forward_has_default_args", # autograd.Function with class methods
@@ -4393,14 +4370,25 @@ xfail_by_backend = {
"test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None
"test_custom_function_non_tensor_inputs_outputs", # gradient batching rule not implemented for aten::sym_size.int
"test_setitem", # CopySlices accuracy error
"test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565
"test_checkpoint_detects_non_determinism", # different error
"test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute
"test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute
"test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193
"test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times
"test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised
"test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute
"test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115
"test_checkpointing", # takes very very long
"test_checkpointing_without_reentrant_input_requires_grad_False", # takes very very long
"test_checkpointing_without_reentrant_input_requires_grad_True", # takes very very long
"test_checkpointing_without_reentrant_memory_savings", # takes very very long
"test_dtensor_different_gradient_placement", # Dynamo failed to run FX node with fake tensors
"test_dtensor_noncontiguous_output", # Dynamo failed to run FX node with fake tensors
"test_dtensor_partial_placement_graph_output", # Dynamo failed to run FX node with fake tensors
"test_unwrap_async_collective_tensor_tangent", # AttributeError: 'PlainTensorMeta' object has no attribute 'attrs'
"test_graph_save_on_cpu", # torch.save should no-op and be recorded in the graph
"test_saving_variable_to_disk", # torch.save should no-op and be recorded in the graph
"test_nested_checkpoint_early_stop_False", # AOT backward higher order gradients
},
"aot_eager": { # will be run with torch.compile(backend="eager")
# Category: FakeTensor
@@ -4442,9 +4430,6 @@ test_autograd = load_test_module("test_autograd")
test_custom_ops = load_test_module("test_custom_ops")
TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd)
TestNestedCheckpointWithCompiledAutograd = wrap_test_class(
test_autograd.TestNestedCheckpoint
)
TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp)
if torch.distributed.is_available() and HAS_CUDA:
test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile")

View File

@@ -78,6 +78,7 @@ from torch.testing._internal.common_utils import (
skipIfWindows,
slowTest,
TestCase,
xfailIfTorchDynamo,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
@@ -7429,6 +7430,8 @@ for shape in [(1,), ()]:
self.assertEqual(b_grad, c_grad)
self.assertEqual(b_grad, d_grad)
# PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115")
def test_checkpointing_without_reentrant_dataparallel(self):
"""
Verifies gradient correctness when checkpoint without reentrant autograd
@@ -7486,6 +7489,8 @@ for shape in [(1,), ()]:
# should only call hook once
self.assertEqual(count, 1)
# https://github.com/pytorch/pytorch/issues/127115
@xfailIfTorchDynamo
def test_checkpointing_without_reentrant_arbitrary_input_output(self):
"""
Ensures checkpointing without reentrant autograd works with functions

View File

@@ -328,7 +328,6 @@ class CheckpointFunction(torch.autograd.Function):
def noop_context_fn():
return contextlib.nullcontext(), contextlib.nullcontext()
# Note: [torch.compile and checkpoint]
# TorchDynamo does not step inside utils.checkpoint function. The flow
# looks like this
# 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
@@ -1492,8 +1491,6 @@ def _checkpoint_without_reentrant_generator(
had_device_in_fwd = True
fwd_devices, fwd_device_states = get_device_states(*args)
# See Note: [compiled autograd and checkpoint unpack hook]
@torch._disable_dynamo
def recompute_fn(*inputs):
kwargs, *args = inputs
# This will be called later during recomputation. This wrapping enables
@@ -1544,17 +1541,3 @@ def _checkpoint_without_reentrant_generator(
)
return
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
# If the forward had ran under compile, it would have been wrapped in a
# higher order op. See Note: [torch.compile and checkpoint].
#
# Since we run the recomputation hook under an enable_grad context,
# AOTDispatch will trace a joint graph for this hook, and may
# save different activations than in eager. This conflicts with the
# strict activation count checks in `frame.check_recomputed_tensors_match`.
# So, we disable this hook to force it to recompute eager checkpointed regions
# in eager. This could be removed if we can disable the partitioner for this
# graph segment.
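
For reference, a minimal hedged sketch of what the `@torch._disable_dynamo` decorator removed above does on its own: the decorated function is skipped by Dynamo and runs in eager even when reached from compiled code. The toy functions are illustrative assumptions, not the checkpoint internals.

import torch

@torch._disable_dynamo
def eager_only(x):
    # Dynamo graph-breaks around this call; the body always executes eagerly.
    return x.sin().cos()

@torch.compile
def f(x):
    return eager_only(x) + 1

out = f(torch.randn(4))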