SAC: fix recompute tag propagation for ops with list[tensor] inputs (#152195)

There's an "are we compiling" check in SAC, which we rely on to know when to propagate recompute tags during tracing.

This check was a bit brittle and missed cases where an op takes a list of tensors as input. I updated it to check whether a `FunctionalTensorMode` is active, which should be a 100% reliable way to know that AOTDispatcher is in the middle of running.

There is a long-standing follow-up here around unifying `torch.compiler.is_compiling()` so that it works in all cases. We should probably just update it to always check whether FakeMode/FunctionalMode are active and use it here. That has a bit of BC risk, though, so I opted for the more local fix to SAC.
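
To make the failure mode concrete, here is a small sketch contrasting the two checks (illustration only; the `_old`/`_new` names are mine, the bodies mirror the diff below). For an op like `aten.cat`, `args` is `([t1, t2],)`, a Python list of tensors rather than a bare `Tensor`, so the old `isinstance` scan never fired and recompute tags were silently dropped during tracing:

```python
import torch


def _is_compiling_old(func, args, kwargs):
    # Old check (simplified): look for a functional tensor directly in args.
    # This works for ops like aten.sin(Tensor), but for aten.cat the only
    # positional arg is a list of tensors, so the isinstance check never
    # fires and we wrongly report "not compiling".
    from torch._functorch._aot_autograd.functional_utils import is_fun

    for arg in args:
        if isinstance(arg, torch.Tensor) and is_fun(arg):
            return True
    return False


def _is_compiling_new(func, args, kwargs):
    # New check: AOTDispatcher traces with FunctionalTensorMode pushed, so an
    # active FUNCTIONAL dispatch mode is a reliable signal regardless of how
    # the op packages its tensor arguments.
    return (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    )
```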

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152195
Approved by: https://github.com/soulitzer
Commit 5abe74857a (parent 7c96dd8f0c)
Authored by Brian Hirsh on 2025-05-02 08:15:40 -07:00; committed by PyTorch MergeBot
2 changed files with 45 additions and 7 deletions

@@ -967,6 +967,49 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         self._validate(fn, backend, x, y)
         self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
 
+    @requires_cuda
+    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
+    def test_compile_selective_checkpoint_list_ops(self, device):
+        def selective_checkpointing_context_fn():
+            # recompute everything
+            no_recompute_list = []
+            return create_selective_checkpoint_contexts(
+                _get_custom_policy(no_recompute_list=no_recompute_list)
+            )
+
+        def gn(x, y):
+            return torch.cat([x, y]).sin()
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn,
+                x,
+                y,
+                use_reentrant=False,
+                context_fn=selective_checkpointing_context_fn,
+            )
+
+        x = torch.randn(4, 4, requires_grad=True, device=device)
+        y = torch.randn(4, 4, requires_grad=True, device=device)
+
+        fw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        bw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        backend = aot_autograd(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=min_cut_rematerialization_partition,
+        )
+        self._validate(fn, backend, x, y)
+        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
+
     @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
     @unittest.skip(
         "In-place op support in selective checkpointing + torch.compile "

@@ -11,7 +11,6 @@ from weakref import ReferenceType
 
 import torch
 import torch.fx.traceback as fx_traceback
-from torch._functorch._aot_autograd.functional_utils import is_fun
 from torch.utils._pytree import tree_map
 from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -1153,12 +1152,8 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
 
 def _is_compiling(func, args, kwargs):
     # Check if we are under AOTAutograd tracing
-    # There should probably be a better way to do this...
-    # TODO: unify _is_compiling across all compile stacks
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and is_fun(arg):
-            return True
-    return False
+    # Checking that a functional mode is active should always do what we want
+    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
 
 
 class _VersionWrapper: