Expose API to specify custom context manager for checkpoint (#96783)
Per the [design](https://docs.google.com/document/d/1v-yqRqiWA6dIUOw5OpqFs2PqSQIbDEkwRPGk9FcYnxg/edit), we want (1) to allow the user to pass in a function that returns two context managers, (2) to expose a per-call API only for now, and (3) not to upstream selective checkpoint in the short term.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96783
Approved by: https://github.com/albanD
This commit is contained in:
parent ac7329b323
commit f3db2a6341
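For context, a minimal usage sketch of the per-call API this commit exposes (not part of the diff itself): a user-supplied context_fn returns two context managers, and checkpoint runs the original forward under the first and the recomputation under the second. The log_phase helper below is hypothetical and purely illustrative.

    import contextlib
    import torch
    from torch.utils.checkpoint import checkpoint

    @contextlib.contextmanager
    def log_phase(name):
        # Hypothetical helper, not part of this commit: prints when a phase starts/ends.
        print(f"entering {name}")
        try:
            yield
        finally:
            print(f"leaving {name}")

    def context_fn():
        # First manager wraps the original forward; second wraps the recomputation.
        return log_phase("forward"), log_phase("recompute")

    x = torch.randn(4, requires_grad=True)
    # context_fn is only accepted together with use_reentrant=False.
    out = checkpoint(lambda t: t.sin(), x, use_reentrant=False, context_fn=context_fn)
    out.sum().backward()  # recomputation during backward runs under the second manager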
test/test_autograd.py
@@ -5608,6 +5608,39 @@ for shape in [(1,), ()]:
         out = checkpoint(foo, x, y, z, use_reentrant=False)
         out.sum().backward()

+    def test_checkpointing_without_reentrant_with_context_fn(self):
+        class VerboseTorchDispatchMode(TorchDispatchMode):
+            def __init__(self):
+                self.operators = []
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                self.operators.append(func.__name__)
+                return func(*args, **kwargs)
+
+        x = torch.tensor(1., requires_grad=True)
+        verbose_mode = VerboseTorchDispatchMode()
+
+        def context_fn():
+            return verbose_mode, contextlib.nullcontext()
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        self.assertEqual(verbose_mode.operators, ['sin.default'])
+
+        verbose_mode.operators = []
+
+        def context_fn():
+            return contextlib.nullcontext(), verbose_mode
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        out.backward()
+        self.assertEqual(
+            verbose_mode.operators,
+            ['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
+        )
+
+        with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
+            out = checkpoint(lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn)
+
     def test_access_saved_tensor_twice_without_recomputation_works(self):
         count = [0]

torch/utils/checkpoint.py
@@ -2,7 +2,7 @@ import torch
 import warnings
 import weakref
 from weakref import ReferenceType
-from typing import Any, Iterable, List, Tuple, Dict, Optional, DefaultDict
+from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
 from collections import defaultdict
 import uuid
 import contextlib
@@ -10,7 +10,7 @@ import contextlib
 __all__ = [
     "checkpoint", "checkpoint_sequential", "CheckpointFunction",
     "check_backward_validity", "detach_variable", "get_device_states",
-    "set_device_states",
+    "set_device_states", "noop_context_fn"
 ]

 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -165,7 +165,17 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads


-def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
+def noop_context_fn():
+    return contextlib.nullcontext(), contextlib.nullcontext()
+
+
+def checkpoint(
+    function,
+    *args,
+    use_reentrant: bool = True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    **kwargs
+):
     r"""Checkpoint a model or part of the model

     Checkpointing works by trading compute for memory. Rather than storing all
@@ -239,6 +249,10 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
             keyword arguments input into the checkpointed function. Note that future
             versions of PyTorch will default to ``use_reentrant=False``.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
+            This argument is only supported if ``use_reentrant=False``.
         args: tuple containing inputs to the :attr:`function`

     Returns:
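A sketch of the behavior the docstring above describes, modeled on the new test: with a recording TorchDispatchMode in each slot, the first mode sees only the original forward ops while the second sees the ops replayed during recomputation. RecordingMode is an illustrative stand-in for the test's VerboseTorchDispatchMode.

    import torch
    from torch.utils.checkpoint import checkpoint
    from torch.utils._python_dispatch import TorchDispatchMode

    class RecordingMode(TorchDispatchMode):
        # Records the name of every op dispatched while the mode is active.
        def __init__(self):
            self.operators = []

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            self.operators.append(func.__name__)
            return func(*args, **kwargs)

    forward_mode, recompute_mode = RecordingMode(), RecordingMode()

    def context_fn():
        # First manager wraps the original forward; second wraps the recomputation.
        return forward_mode, recompute_mode

    x = torch.tensor(1., requires_grad=True)
    out = checkpoint(lambda t: t.sin(), x, use_reentrant=False, context_fn=context_fn)
    print(forward_mode.operators)    # ['sin.default'] per the test above
    out.backward()
    print(recompute_mode.operators)  # ends with 'sin.default'; detach calls precede it per the test above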
@@ -250,11 +264,14 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

     if use_reentrant:
+        if context_fn is not noop_context_fn:
+            raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         return _checkpoint_without_reentrant(
             function,
             preserve,
+            context_fn,
             *args,
             **kwargs,
         )
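A short sketch of the validation added above: any non-default context_fn combined with use_reentrant=True raises a ValueError, which the new test checks via assertRaisesRegex.

    import contextlib
    import torch
    from torch.utils.checkpoint import checkpoint

    x = torch.tensor(1., requires_grad=True)

    def context_fn():
        return contextlib.nullcontext(), contextlib.nullcontext()

    try:
        # Rejected: context_fn is only accepted with use_reentrant=False.
        checkpoint(lambda t: t.sin(), x, use_reentrant=True, context_fn=context_fn)
    except ValueError as err:
        print(err)  # "Passing context_fn is only supported when use_reentrant=False."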
@@ -626,7 +643,13 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):

 # NB: this helper wraps fn before calling checkpoint_impl. kwargs and
 # saving/restoring of global state is handled here.
-def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
+def _checkpoint_without_reentrant(
+    fn,
+    preserve_rng_state=True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    *args,
+    **kwargs
+):
     """Checkpointining without re-entrant autograd
     Args:
         function: describes what to run in the forward pass of the model or
@@ -637,9 +660,13 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
+    forward_context, recompute_context = context_fn()
     # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
     gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

@@ -669,7 +696,8 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
                     set_device_states(fwd_gpu_devices, fwd_gpu_states)

             with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
-                 torch.cpu.amp.autocast(**cpu_autocast_kwargs):
+                 torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
+                 recompute_context:
                 fn(*args, **kwargs)

     new_frame = _CheckpointFrame(recompute_fn)
@@ -680,7 +708,8 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
     if new_frame.input_saver.grad_fn is None:
         return fn(*args, **kwargs)

-    with _checkpoint_hook(new_frame):
+    with _checkpoint_hook(new_frame), \
+         forward_context:
         ret = fn(*args, **kwargs)

     if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd: