Expose API to specify custom context manager for checkpoint (#96783)

Per the [design](https://docs.google.com/document/d/1v-yqRqiWA6dIUOw5OpqFs2PqSQIbDEkwRPGk9FcYnxg/edit), we want to (1) allow the user to pass in a function that returns two context managers, (2) expose a per-call API only for now, and (3) not upstream selective checkpointing in the short term.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96783
Approved by: https://github.com/albanD
soulitzer 2023-03-14 19:15:42 -04:00 committed by PyTorch MergeBot
parent ac7329b323
commit f3db2a6341
2 changed files with 68 additions and 6 deletions
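
For illustration, a minimal usage sketch of the new per-call `context_fn` argument, mirroring the test added in this commit; the `OpLogger` mode and variable names below are illustrative only, not part of the PR:

import torch
from torch.utils.checkpoint import checkpoint
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    # Records the name of every op dispatched while this mode is active.
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(func.__name__)
        return func(*args, **(kwargs or {}))

forward_logger, recompute_logger = OpLogger(), OpLogger()

def context_fn():
    # The first context manager wraps the original forward pass;
    # the second wraps the recomputation that runs during backward.
    return forward_logger, recompute_logger

x = torch.tensor(1., requires_grad=True)
out = checkpoint(lambda t: t.sin(), x, use_reentrant=False, context_fn=context_fn)
out.backward()
print(forward_logger.ops)    # ops seen in the checkpointed forward, e.g. ['sin.default']
print(recompute_logger.ops)  # ops replayed during recomputation in backward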

View File

@@ -5608,6 +5608,39 @@ for shape in [(1,), ()]:
         out = checkpoint(foo, x, y, z, use_reentrant=False)
         out.sum().backward()
 
+    def test_checkpointing_without_reentrant_with_context_fn(self):
+        class VerboseTorchDispatchMode(TorchDispatchMode):
+            def __init__(self):
+                self.operators = []
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                self.operators.append(func.__name__)
+                return func(*args, **kwargs)
+
+        x = torch.tensor(1., requires_grad=True)
+        verbose_mode = VerboseTorchDispatchMode()
+
+        def context_fn():
+            return verbose_mode, contextlib.nullcontext()
+
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        self.assertEqual(verbose_mode.operators, ['sin.default'])
+
+        verbose_mode.operators = []
+
+        def context_fn():
+            return contextlib.nullcontext(), verbose_mode
+
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        out.backward()
+        self.assertEqual(
+            verbose_mode.operators,
+            ['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
+        )
+
+        with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
+            out = checkpoint(lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn)
+
     def test_access_saved_tensor_twice_without_recomputation_works(self):
         count = [0]

View File

@@ -2,7 +2,7 @@ import torch
 import warnings
 import weakref
 from weakref import ReferenceType
-from typing import Any, Iterable, List, Tuple, Dict, Optional, DefaultDict
+from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
 from collections import defaultdict
 import uuid
 import contextlib
@@ -10,7 +10,7 @@ import contextlib
 __all__ = [
     "checkpoint", "checkpoint_sequential", "CheckpointFunction",
     "check_backward_validity", "detach_variable", "get_device_states",
-    "set_device_states",
+    "set_device_states", "noop_context_fn"
 ]
 
 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -165,7 +165,17 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads
 
 
-def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
+def noop_context_fn():
+    return contextlib.nullcontext(), contextlib.nullcontext()
+
+
+def checkpoint(
+    function,
+    *args,
+    use_reentrant: bool = True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    **kwargs
+):
     r"""Checkpoint a model or part of the model
 
     Checkpointing works by trading compute for memory. Rather than storing all
@@ -239,6 +249,10 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
             keyword arguments input into the checkpointed function. Note that future
             versions of PyTorch will default to ``use_reentrant=False``.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
+            This argument is only supported if ``use_reentrant=False``.
         args: tuple containing inputs to the :attr:`function`
 
     Returns:
@@ -250,11 +264,14 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
 
     if use_reentrant:
+        if context_fn is not noop_context_fn:
+            raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         return _checkpoint_without_reentrant(
             function,
             preserve,
+            context_fn,
             *args,
             **kwargs,
         )
@@ -626,7 +643,13 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
 
 # NB: this helper wraps fn before calling checkpoint_impl. kwargs and
 # saving/restoring of global state is handled here.
-def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
+def _checkpoint_without_reentrant(
+    fn,
+    preserve_rng_state=True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    *args,
+    **kwargs
+):
     """Checkpointing without re-entrant autograd
     Args:
         function: describes what to run in the forward pass of the model or
@@ -637,9 +660,13 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
+    forward_context, recompute_context = context_fn()
     # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
     gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()
@@ -669,7 +696,8 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
                     set_device_states(fwd_gpu_devices, fwd_gpu_states)
 
             with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
-                 torch.cpu.amp.autocast(**cpu_autocast_kwargs):
+                 torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
+                 recompute_context:
                 fn(*args, **kwargs)
 
     new_frame = _CheckpointFrame(recompute_fn)
@@ -680,7 +708,8 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
     if new_frame.input_saver.grad_fn is None:
         return fn(*args, **kwargs)
 
-    with _checkpoint_hook(new_frame):
+    with _checkpoint_hook(new_frame), \
+            forward_context:
         ret = fn(*args, **kwargs)
 
     if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd: