Add early_stop kwarg to torch.utils.checkpoint (#160781)

We already have a context manager "set_checkpoint_early_stop". This PR adds a kwarg that toggles the same setting.

It is also useful to have a kwarg version of the setting in addition to the context manager, because it is awkward to apply a context manager when activation checkpointing (AC) is applied via CheckpointWrapper.

Similar to the "debug" kwarg and its corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and, when non-None, overrides the local kwarg setting.
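
A minimal usage sketch of the resulting semantics (mirroring the tests added below; "fn" is just a stand-in for any checkpointed function):

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

    def fn(x):
        # Stand-in for a region of the model being checkpointed.
        return x.sin().cos()

    a = torch.tensor(1.0, requires_grad=True)

    # Per-call control via the new kwarg: disable early stopping for this call only.
    out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
    out.backward()

    # Global override via the context manager: a non-None value takes precedence
    # over the kwarg, so early stopping remains enabled for this call.
    a = torch.tensor(1.0, requires_grad=True)
    with set_checkpoint_early_stop(True):
        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
    out.backward()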

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
soulitzer 2025-08-20 11:57:15 -07:00 committed by PyTorch MergeBot
parent 4d078cfc4e
commit 1e4dfeeb06
3 changed files with 65 additions and 6 deletions

View File

@@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
             # early stop is enabled.
             return clone(x.sin().cos())

+        # Test default
         # Early stopping is enabled by default
         a = torch.tensor(1.0, requires_grad=True)
         out = checkpoint(fn, a, use_reentrant=False)
         out.backward()
         self.assertEqual(counter[0], 1)

-        # Try using the context manager to set early stopping to False.
+        # Test local setting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager
         # Expect early stopping to be disabled for all checkpoints ran under
         # the context manager, even though context manager is no longer active
         # when backward/recomputation is performed.
@@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
         a = torch.tensor(1.0, requires_grad=True)
         with torch.utils.checkpoint.set_checkpoint_early_stop(False):
             out = checkpoint(fn, a, use_reentrant=False)
         out.backward()
         self.assertEqual(counter[0], 2)

+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager nesting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+                out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test precedence
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
     def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
         # Case 1: We have one tensor saved and its the input

View File

@@ -79,6 +79,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
     user_context_fns = kwargs.pop("context_fn", None)
     determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
     debug = kwargs.pop("debug", False)
+    early_stop = kwargs.pop("early_stop", True)

     if kwargs:
         raise ValueError(
@@ -103,6 +104,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
             context_fns,
             determinism_check,
             debug,
+            early_stop,
             *args,
             **kwargs,
         )
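
The hunks above thread the new kwarg through a module-level checkpoint API. For the CheckpointWrapper case mentioned in the description, a hypothetical sketch, assuming (as in current PyTorch) that checkpoint_wrapper forwards extra keyword arguments such as early_stop to the underlying torch.utils.checkpoint.checkpoint call:

    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        checkpoint_wrapper,
    )

    # Wrap a submodule with non-reentrant activation checkpointing and disable
    # early stopping for it, without wrapping backward in a context manager.
    block = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
    wrapped = checkpoint_wrapper(
        block,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        early_stop=False,  # assumed to be forwarded to torch.utils.checkpoint.checkpoint
    )

    x = torch.randn(4, 16, requires_grad=True)
    wrapped(x).sum().backward()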

View File

@@ -347,6 +347,7 @@ def checkpoint(
     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
     determinism_check: str = _DEFAULT_DETERMINISM_MODE,
     debug: bool = False,
+    early_stop: bool = True,
     **kwargs
 ):
     r"""Checkpoint a model or part of the model.
@@ -425,6 +426,9 @@ def checkpoint(
             passed as the tuple. For example, in LSTM, if user passes
             ``(activation, hidden)``, :attr:`function` should correctly use the
             first input as ``activation`` and the second input as ``hidden``
+        args: tuple containing inputs to the :attr:`function`
+
+    Keyword args:
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint. Note that under torch.compile,
             this flag doesn't take effect and we always preserve RNG state.
@@ -455,7 +459,11 @@ def checkpoint(
             a trace of the operators ran during the original forward computation
             as well as the recomputation. This argument is only supported if
             ``use_reentrant=False``.
-        args: tuple containing inputs to the :attr:`function`
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. This
+            argument is ignored if ``use_reentrant=True``. Can be overridden
+            globally using :func:`set_checkpoint_early_stop` context manager.
+            Default: ``True``.

     Returns:
         Output of running :attr:`function` on :attr:`*args`
@@ -488,7 +496,7 @@ def checkpoint(
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         gen = _checkpoint_without_reentrant_generator(
-            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
+            function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
         )
         # Runs pre-forward logic
         next(gen)
@@ -731,7 +739,7 @@ def _internal_assert(cond):
 # by holder=None. We skip over them. We still save x at (4) (since its holder
 # is still alive.)

-_enable_checkpoint_early_stop = True
+_enable_checkpoint_early_stop: Optional[bool] = None


 @contextlib.contextmanager
@@ -1448,6 +1456,7 @@ def _checkpoint_without_reentrant_generator(
     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
     determinism_check: str = _DEFAULT_DETERMINISM_MODE,
     debug: bool = False,
+    early_stop: bool = True,
     *args,
     **kwargs
 ):
@@ -1475,6 +1484,10 @@ def _checkpoint_without_reentrant_generator(
         debug(bool, optional): If ``True``, error messages will also include
             a trace of the operators ran during the original forward computation
             as well as the recomputation.
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. Can be
+            overridden globally using :func:`set_checkpoint_early_stop` context
+            manager. Default: ``True``.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
@@ -1543,7 +1556,7 @@ def _checkpoint_without_reentrant_generator(
     new_frame = _CheckpointFrame(
         recompute_fn,
-        _enable_checkpoint_early_stop,
+        _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
         unpack_error_cb,
         metadata_fn
     )
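
For reference, the precedence implemented by the final hunk can be summarized as a small helper (an illustrative sketch only; resolve_early_stop is a hypothetical name, not code from this PR):

    from typing import Optional

    def resolve_early_stop(global_setting: Optional[bool], local_kwarg: bool) -> bool:
        # The context-manager value wins whenever it is not None; otherwise the
        # per-call early_stop kwarg decides. The kwarg defaults to True.
        return global_setting if global_setting is not None else local_kwarg

    assert resolve_early_stop(None, True) is True     # default: early stop enabled
    assert resolve_early_stop(None, False) is False   # kwarg-only control
    assert resolve_early_stop(False, True) is False   # context manager overrides kwarg
    assert resolve_early_stop(True, False) is True    # context manager overrides kwarg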