Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
SAC: fix recompute tag propagation for ops with list[tensor] inputs (#152195)
There's an "are we compiling" check in SAC, which we rely on to know when to propagate recompute tags during tracing. This check was a bit brittle and missed cases where an op takes a list of tensors as input. I updated it to check whether a `FunctionalTensorMode` is active, which should be a 100% reliable way to know if AOTDispatcher is in the middle of running.

There is a long-standing follow-up here around unifying `torch.compiler.is_compiling()` to work in all cases: we should probably just update it to always check whether FakeMode/FunctionalMode are active and use it here. That has a bit of BC risk, though, so I opted for the more local fix to SAC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152195
Approved by: https://github.com/soulitzer
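For context, the failure mode is easiest to see with an op like torch.ops.aten.cat.default, whose tensors arrive as a single list[Tensor] argument, so a per-argument tensor check never matches during tracing. Below is a minimal sketch of that failure mode, simplified from the old _is_compiling (the is_fun() functionalization check is elided here):

import torch

def old_style_is_compiling(args):
    # Pre-fix heuristic: look for a (functionalized) Tensor directly among args.
    # The real check also required is_fun(arg); omitted for brevity.
    for arg in args:
        if isinstance(arg, torch.Tensor):
            return True
    return False

# aten.cat passes its inputs as one list, so no element of `args` is itself
# a Tensor and the heuristic reports "not compiling" even mid-trace.
tensors = [torch.randn(2), torch.randn(2)]
print(old_style_is_compiling((tensors,)))  # False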
This commit is contained in:
parent 7c96dd8f0c
commit 5abe74857a
@@ -967,6 +967,49 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         self._validate(fn, backend, x, y)
         self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
 
+    @requires_cuda
+    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
+    def test_compile_selective_checkpoint_list_ops(self, device):
+        def selective_checkpointing_context_fn():
+            # recompute everything
+            no_recompute_list = []
+            return create_selective_checkpoint_contexts(
+                _get_custom_policy(no_recompute_list=no_recompute_list)
+            )
+
+        def gn(x, y):
+            return torch.cat([x, y]).sin()
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                gn,
+                x,
+                y,
+                use_reentrant=False,
+                context_fn=selective_checkpointing_context_fn,
+            )
+
+        x = torch.randn(4, 4, requires_grad=True, device=device)
+        y = torch.randn(4, 4, requires_grad=True, device=device)
+
+        fw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        bw_compiler = functools.partial(
+            count_ops,
+            freqs=[1],
+            ops=[torch.ops.aten.cat.default],
+        )
+        backend = aot_autograd(
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
+            partition_fn=min_cut_rematerialization_partition,
+        )
+        self._validate(fn, backend, x, y)
+        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
+
     @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
     @unittest.skip(
         "In-place op support in selective checkpointing + torch.compile "
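Outside the test harness, roughly the same pattern can be written against the public SAC API. This is a usage sketch (not part of the PR) that saves aten.cat instead of recomputing it, assuming CheckpointPolicy and create_selective_checkpoint_contexts from torch.utils.checkpoint:

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Keep aten.cat's output from the forward pass; recompute everything else.
    if op == torch.ops.aten.cat.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def context_fn():
    return create_selective_checkpoint_contexts(policy_fn)

def gn(a, b):
    return torch.cat([a, b]).sin()

def fn(a, b):
    return checkpoint(gn, a, b, use_reentrant=False, context_fn=context_fn)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
torch.compile(fn)(x, y).sum().backward()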
@@ -11,7 +11,6 @@ from weakref import ReferenceType
 
 import torch
 import torch.fx.traceback as fx_traceback
-from torch._functorch._aot_autograd.functional_utils import is_fun
 from torch.utils._pytree import tree_map
 from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
 from torch.utils._python_dispatch import TorchDispatchMode

@@ -1153,12 +1152,8 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
 
 def _is_compiling(func, args, kwargs):
     # Check if we are under AOTAutograd tracing
-    # There should probably be a better way to do this...
-    # TODO: unify _is_compiling across all compile stacks
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and is_fun(arg):
-            return True
-    return False
+    # Checking that a functional mode is active should always do what we want
+    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
 
 
 class _VersionWrapper:
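To see what the new check keys off of, here is a small sketch assuming the private FunctionalTensorMode from torch._subclasses.functional_tensor, which AOTDispatcher pushes onto the dispatch-mode stack while tracing:

import torch
from torch._subclasses.functional_tensor import FunctionalTensorMode

def is_functional_mode_active():
    # The same query the patched _is_compiling performs: is a functional
    # dispatch mode currently on the dispatch-mode stack?
    return (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    )

print(is_functional_mode_active())      # False outside tracing
with FunctionalTensorMode():
    print(is_functional_mode_active())  # True while a functional mode is active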