Add early_stop kwarg to torch.utils.checkpoint (#160781)
We already have a context manager, "set_checkpoint_early_stop". This PR adds a kwarg that toggles the same setting. A kwarg version of the setting is useful in addition to the context manager because it is awkward to apply a context manager when activation checkpointing is applied via CheckpointWrapper. As with the "debug" kwarg and the corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and overrides the local setting when it is non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
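With this change, early stopping can be controlled per call without wrapping the call site in a context manager. A minimal usage sketch based on the diff and tests below (the checkpointed function is illustrative):

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

    def fn(x):
        return x.sin().cos()

    a = torch.tensor(1.0, requires_grad=True)

    # Per-call kwarg: disable early stopping for this checkpoint only.
    out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
    out.backward()

    # The context manager still takes precedence while it is active (non-None).
    with set_checkpoint_early_stop(False):
        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)  # runs without early stop
    out.backward()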
parent 4d078cfc4e
commit 1e4dfeeb06
@@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
             # early stop is enabled.
             return clone(x.sin().cos())
 
+        # Test default
         # Early stopping is enabled by default
         a = torch.tensor(1.0, requires_grad=True)
         out = checkpoint(fn, a, use_reentrant=False)
         out.backward()
         self.assertEqual(counter[0], 1)
 
-        # Try using the context manager to set early stopping to False.
+        # Test local setting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager
         # Expect early stopping to be disabled for all checkpoints ran under
         # the context manager, even though context manager is no longer active
         # when backward/recomputation is performed.
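The counter asserted throughout this test is not shown in the hunk; presumably it is incremented by the clone helper defined earlier in the test, roughly along these lines (a reconstruction, not part of the diff):

    counter = [0]

    def clone(x):
        # Each call during forward or recomputation bumps the counter, so
        # counter[0] == 1 means clone was not recomputed (early stop active)
        # and counter[0] == 2 means it ran again during recomputation.
        counter[0] += 1
        return x.clone()

    def fn(x):
        # Since clone saves nothing for backward, it is skipped during
        # recomputation iff early stop is enabled.
        return clone(x.sin().cos())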
@@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
         a = torch.tensor(1.0, requires_grad=True)
         with torch.utils.checkpoint.set_checkpoint_early_stop(False):
             out = checkpoint(fn, a, use_reentrant=False)
 
         out.backward()
         self.assertEqual(counter[0], 2)
 
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager nesting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+                out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test precedence
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
     def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
         # Case 1: We have one tensor saved and its the input
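Taken together, the "Test context manager", "Test context manager nesting", and "Test precedence" blocks pin down the resolution order: an active set_checkpoint_early_stop value (the innermost one) wins over the early_stop kwarg, and the kwarg applies only when no context manager is active. A hypothetical pure function capturing what these asserts check (names are illustrative, not from the PR):

    from typing import Optional

    def effective_early_stop(ctx: Optional[bool], kwarg: bool = True) -> bool:
        # The context-manager value, when set, overrides the per-call kwarg.
        return ctx if ctx is not None else kwarg

    assert effective_early_stop(None) is True                 # default -> counter == 1
    assert effective_early_stop(None, kwarg=False) is False   # local setting -> counter == 2
    assert effective_early_stop(False, kwarg=True) is False   # context manager wins -> counter == 2
    assert effective_early_stop(True, kwarg=False) is True    # context manager wins -> counter == 1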
@@ -79,6 +79,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
     user_context_fns = kwargs.pop("context_fn", None)
     determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
     debug = kwargs.pop("debug", False)
+    early_stop = kwargs.pop("early_stop", True)
 
     if kwargs:
         raise ValueError(
@@ -103,6 +104,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
         context_fns,
         determinism_check,
         debug,
+        early_stop,
         *args,
         **kwargs,
     )
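The two hunks above are from a module-level activation-checkpointing entry point that pops keyword arguments and forwards them to the core implementation, so the new kwarg can be passed where AC is attached to a module rather than wrapped around a call. A minimal sketch, assuming this is the composable checkpoint API (torch.distributed._composable.checkpoint; the file path is not visible in this view):

    import torch
    import torch.nn as nn
    # Assumption: the hunks above belong to the composable checkpoint API.
    from torch.distributed._composable import checkpoint as composable_checkpoint

    module = nn.Linear(4, 4)
    # Attach activation checkpointing to the module, with early stopping disabled.
    composable_checkpoint(module, early_stop=False)
    loss = module(torch.randn(2, 4, requires_grad=True)).sum()
    loss.backward()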
@@ -347,6 +347,7 @@ def checkpoint(
     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
     determinism_check: str = _DEFAULT_DETERMINISM_MODE,
     debug: bool = False,
+    early_stop: bool = True,
     **kwargs
 ):
     r"""Checkpoint a model or part of the model.
@@ -425,6 +426,9 @@ def checkpoint(
             passed as the tuple. For example, in LSTM, if user passes
             ``(activation, hidden)``, :attr:`function` should correctly use the
             first input as ``activation`` and the second input as ``hidden``
+        args: tuple containing inputs to the :attr:`function`
+
+    Keyword args:
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint. Note that under torch.compile,
             this flag doesn't take effect and we always preserve RNG state.
@@ -455,7 +459,11 @@ def checkpoint(
             a trace of the operators ran during the original forward computation
             as well as the recomputation. This argument is only supported if
             ``use_reentrant=False``.
-        args: tuple containing inputs to the :attr:`function`
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. This
+            argument is ignored if ``use_reentrant=True``. Can be overridden
+            globally using :func:`set_checkpoint_early_stop` context manager.
+            Default: ``True``.
 
     Returns:
         Output of running :attr:`function` on :attr:`*args`
@@ -488,7 +496,7 @@ def checkpoint(
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         gen = _checkpoint_without_reentrant_generator(
-            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
+            function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
         )
         # Runs pre-forward logic
         next(gen)
@@ -731,7 +739,7 @@ def _internal_assert(cond):
 # by holder=None. We skip over them. We still save x at (4) (since its holder
 # is still alive.)
 
-_enable_checkpoint_early_stop = True
+_enable_checkpoint_early_stop: Optional[bool] = None
 
 
 @contextlib.contextmanager
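The module-level flag becomes a tri-state: None means "not set by any context manager", which is now the default so that the per-call kwarg can take effect. For orientation, the existing set_checkpoint_early_stop context manager (defined just below this hunk) sets and restores this global roughly as follows (a paraphrase, not part of the diff):

    import contextlib

    @contextlib.contextmanager
    def set_checkpoint_early_stop(enable: bool):
        # Temporarily pin the global setting, restoring the previous value on exit.
        global _enable_checkpoint_early_stop
        try:
            prev = _enable_checkpoint_early_stop
            _enable_checkpoint_early_stop = enable
            yield
        finally:
            _enable_checkpoint_early_stop = prev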
@@ -1448,6 +1456,7 @@ def _checkpoint_without_reentrant_generator(
     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
     determinism_check: str = _DEFAULT_DETERMINISM_MODE,
     debug: bool = False,
+    early_stop: bool = True,
     *args,
     **kwargs
 ):
@@ -1475,6 +1484,10 @@ def _checkpoint_without_reentrant_generator(
         debug(bool, optional): If ``True``, error messages will also include
             a trace of the operators ran during the original forward computation
             as well as the recomputation.
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. Can be
+            overridden globally using :func:`set_checkpoint_early_stop` context
+            manager. Default: ``True``.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
@@ -1543,7 +1556,7 @@ def _checkpoint_without_reentrant_generator(
 
     new_frame = _CheckpointFrame(
         recompute_fn,
-        _enable_checkpoint_early_stop,
+        _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
         unpack_error_cb,
         metadata_fn
     )
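The dense one-liner above is where the two settings meet; restated as a hypothetical helper purely for readability (the PR keeps the expression inline):

    def _resolve_early_stop(early_stop_kwarg: bool) -> bool:
        # The global value set via set_checkpoint_early_stop wins when present;
        # otherwise fall back to the per-call kwarg, which defaults to True.
        if _enable_checkpoint_early_stop is not None:
            return _enable_checkpoint_early_stop
        return early_stop_kwarg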