[Reland][Autograd/Checkpoint] Checkpoint implementation without reentrant autograd (#69508)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69508

Original Phabricator Diff: D32704467 (e032dae329)

Reland; the fix is to not test the traditional (reentrant) checkpoint when the input does not require grad, as that case is unsupported, as documented.

Original PR body:

Resubmission of https://github.com/pytorch/pytorch/pull/62964 with the
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.

Adds a `use_reentrant` flag to the `checkpoint` function. When
`use_reentrant=False` is specified, checkpointing uses an implementation based
on SavedVariableHooks instead of re-entrant autograd. This makes it more
composable with features such as `autograd.grad`, as well as with DDP (thorough
distributed testing still needs to be added).
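
For illustration, a minimal usage sketch (not from this PR) of how the new flag is expected to compose with `torch.autograd.grad`; the module and shapes here are arbitrary:

```python
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Non-reentrant checkpointing recomputes the segment on demand via saved-tensor
# hooks, so torch.autograd.grad works; the reentrant implementation only
# supports torch.autograd.backward without its `inputs` argument.
out = checkpoint(lin, x, use_reentrant=False).sum()
(grad_x,) = torch.autograd.grad(out, (x,))
```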

As discussed in https://github.com/pytorch/pytorch/issues/65537, the tests that we need to add are:

- [x] Gradient hooks are called once
- [x] works when inputs do not require grads but Tensors that require grads are captured (like the first layer in an nn)
- [x] works for functions with arbitrary input/output objects
- [x] distributed tests (next PR)

Note that this is only for `torch.utils.checkpoint`; if this approach looks good overall, we will do something similar for `checkpoint_sequential`.
ghstack-source-id: 144948501

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32902634

fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
Authored by Rohan Varma on 2021-12-07 16:26:36 -08:00; committed by Facebook GitHub Bot
parent 3456c2cbc8
commit 049debd97d
2 changed files with 313 additions and 18 deletions

test/test_autograd.py

@@ -1,5 +1,6 @@
# Owner(s): ["module: autograd"]
import contextlib
import gc
import io
import math
@@ -31,7 +32,8 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                  slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
                                                  disable_gc, gradcheck, gradgradcheck)
                                                  disable_gc, gradcheck, gradgradcheck,
                                                  parametrize, instantiate_parametrized_tests)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
@@ -4308,6 +4310,50 @@ for shape in [(1,), ()]:
        mean_combined = torch.stack(feat_combined).mean()
        mean_combined.backward()

    @slowTest
    @parametrize("input_requires_grad", [True, False])
    def test_checkpointing_without_reentrant(self, input_requires_grad):
        """
        Basic test for checkpoint without reentrant autograd.
        """
        num_inp = 2000
        nz_inp = 10
        nz_out = 10
        nz_bottleneck = 1000

        # small proxy network for some complex reasoning we want to do per input
        module = nn.Sequential(
            nn.Linear(nz_inp, nz_bottleneck),
            nn.ReLU(),
            nn.Linear(nz_bottleneck, nz_inp)
        )

        # Run model with and without checkpointing and verify gradients are
        # equivalent, regardless of if inputs require grads or not.
        module_copy = deepcopy(module)

        feat_combined = []
        feat_combined_no_checkpoint = []
        for r in range(num_inp):
            data_r = torch.empty(1, nz_inp)
            data_r.uniform_()
            data_r.requires_grad = input_requires_grad
            data_r_copy = data_r.clone()
            feat_r = checkpoint(module, data_r, use_reentrant=False)
            feat_combined.append(feat_r)
            feat_r_no_checkpoint = module_copy(data_r)
            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)

        # compute mean as a proxy for some joint reasoning
        mean_combined = torch.stack(feat_combined).mean()
        mean_combined.backward()
        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
        mean_combined_no_checkpoint.backward()

        for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
            self.assertEqual(checkpoint_param.grad, param.grad)

    def test_checkpoint_valid_reset_on_error(self):
        a = torch.randn(2, 2, requires_grad=True)

@@ -4318,6 +4364,156 @@ for shape in [(1,), ()]:
        c = checkpoint(torch.exp, a).sum()
        c.backward()

    @parametrize("use_reentrant", [True, False])
    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
        class NoGradModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(2, 2, bias=False)
                self.lin2 = nn.Linear(2, 2, bias=False)

            def forward(self, x):
                with torch.no_grad():
                    return self.lin2(self.linear(x))

        module = NoGradModule()

        err_ctx = (
            self.assertRaisesRegex(
                RuntimeError,
                "none of output has requires_grad=True"
            )
            if use_reentrant
            else contextlib.suppress()
        )

        a = torch.randn(2, 2, requires_grad=True)
        for _ in range(3):
            with err_ctx:
                # out does not require grad
                out = checkpoint(module, a, use_reentrant=use_reentrant)
                # Make loss require grad, otherwise we would run into
                # "element 0 of tensors does not require grad and does not have a grad_fn"
                out += a
                out.sum().backward()

    def test_checkpointing_without_reentrant_correct_grad(self):
        """
        Verifies that correct gradients are calculated for checkpoint
        without reentrant autograd, for both backward() and autograd.grad().
        """
        a = torch.randn(2, 2, requires_grad=True)

        b = torch.exp(a).sum()
        b.backward()
        b_grad = a.grad

        a.grad = None
        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
        c.backward()
        c_grad = a.grad

        a.grad = None
        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
        d_grad, = torch.autograd.grad(d, (a,))

        self.assertEqual(b_grad, c_grad)
        self.assertEqual(b_grad, d_grad)

    def test_checkpointing_without_reentrant_dataparallel(self):
        """
        Verifies gradient correctness when checkpoint without reentrant autograd
        is used in conjunction with DataParallel.
        """
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(2, 2, bias=False)

            def forward(self, inp):
                return self.linear(inp)

        a = torch.randn(2, 2, requires_grad=True)
        if torch.cuda.is_available():
            a = a.cuda()

        model = LinearModule()
        if torch.cuda.is_available():
            model = model.cuda()

        b = deepcopy(model)(a).sum()
        b.backward()
        b_grad = a.grad

        a.grad = None

        module = torch.nn.DataParallel(deepcopy(model))
        c = checkpoint(module, a, use_reentrant=False).sum()
        c.backward()
        c_grad = a.grad

        self.assertEqual(b_grad, c_grad)

    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
        """
        Ensures that gradient hooks are only called once per tensor.
        """
        w = torch.randn(10, 10, requires_grad=True)
        count = 0

        def hook(grad):
            nonlocal count
            count += 1

        w.register_hook(hook)
        x = torch.rand(10, 10, requires_grad=True)
        h = w * x  # Using w outside the checkpoint
        out = checkpoint(lambda x: w * x, h, use_reentrant=False)  # Using w inside the checkpoint

        out.sum().backward()
        # should only call hook once
        self.assertEqual(count, 1)

    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
        """
        Ensures checkpointing without reentrant autograd works with functions
        with arbitrary input/output structures.
        """
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = torch.nn.Linear(5, 5, bias=False)

            def forward(self, dict_input):
                tensor = dict_input["tensor"]
                return {
                    "result": self.layer(tensor)
                }

        model_no_checkpoint = MyModel()
        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)

        inp = {
            "tensor": torch.randn(5, 5)
        }

        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()

        out_checkpoint = checkpoint(
            model_checkpoint_without_reentrant,
            inp,
            use_reentrant=False
        )["result"].sum()

        self.assertEqual(out_checkpoint, out_no_checkpoint)

        out_no_checkpoint.backward()
        out_checkpoint.backward()

        for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
            self.assertEqual(param.grad, checkpoint_param.grad)

    def test_callback_adds_callback(self):
        called = [0]

@@ -9108,5 +9304,7 @@ instantiate_device_type_tests(
    except_for=None
)

instantiate_parametrized_tests(TestAutograd)

if __name__ == '__main__':
    run_tests()

torch/utils/checkpoint.py

@@ -1,6 +1,6 @@
import torch
import warnings
from typing import Any, Iterable, List, Tuple
from typing import Any, Iterable, List, Tuple, Union
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -142,7 +142,7 @@ class CheckpointFunction(torch.autograd.Function):
        return (None, None) + grads

def checkpoint(function, *args, **kwargs):
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
@@ -165,10 +165,6 @@ def checkpoint(function, *args, **kwargs):
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        If :attr:`function` invocation during backward does anything different
@@ -177,18 +173,30 @@
        detected.

    .. warning::
        If checkpointed segment contains tensors detached from the computational
        graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
        error. This is because `checkpoint` makes all the outputs require
        gradients which causes issues when a tensor is defined to have no
        gradient in the model. To circumvent this, detach the tensors outside of
        the `checkpoint` function.
        If ``use_reentrant=True`` is specified, then if the checkpointed segment
        contains tensors detached from the computational graph by `detach()` or
        `torch.no_grad()`, the backward pass will raise an error. This is
        because `checkpoint` makes all the outputs require gradients which
        causes issues when a tensor is defined to have no gradient in the model.
        To circumvent this, detach the tensors outside of the `checkpoint`
        function. Note that the checkpointed segment can contain tensors
        detached from the computational graph if ``use_reentrant=False`` is
        specified.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients. At least one of the outputs needs to have
        :code:`requires_grad=True` as well.
        If ``use_reentrant=True`` is specified, at least one of the inputs needs
        to have :code:`requires_grad=True` if grads are needed for model inputs,
        otherwise the checkpointed part of the model won't have gradients. At
        least one of the outputs needs to have :code:`requires_grad=True` as
        well. Note that this does not apply if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, checkpointing currently only
        supports :func:`torch.autograd.backward` and only if its `inputs`
        argument is not passed. :func:`torch.autograd.grad`
        is not supported. If ``use_reentrant=False`` is specified, checkpointing
        will work with :func:`torch.autograd.grad`.

    Args:
        function: describes what to run in the forward pass of the model or
@@ -198,6 +206,13 @@ def checkpoint(function, *args, **kwargs):
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
            the RNG state during each checkpoint.
        use_reentrant(bool, optional, default=True): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad``. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
@@ -208,7 +223,14 @@ def checkpoint(function, *args, **kwargs):
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)
    if use_reentrant:
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            preserve,
            *args
        )


def checkpoint_sequential(functions, segments, input, **kwargs):
@@ -275,3 +297,78 @@ def checkpoint_sequential(functions, segments, input, **kwargs):
        input = checkpoint(run_function(start, end, functions), input,
                           preserve_rng_state=preserve)
    return run_function(end + 1, len(functions) - 1, functions)(input)


def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args):
    """Checkpointing without re-entrant autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
            the RNG state during each checkpoint.
        *args: Arguments to pass in to the given ``function``.
    """
    had_autocast_in_fwd = torch.is_autocast_enabled()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    storage: List[Union[torch.Tensor, None]] = []
    counter = 0

    def pack(x):
        nonlocal counter
        counter += 1
        # TODO(varal7): Instead of returning indices, we could return the tensors'
        # metadata (such as size, device, ...) to catch certain cases of
        # non-deterministic behavior of the forward pass.
        return counter - 1

    def unpack(x):
        if len(storage) == 0:

            def inner_pack(inner):
                storage.append(inner)
                return None

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
                with torch.enable_grad(), torch.cuda.amp.autocast(had_autocast_in_fwd):
                    with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                        _unused = function(*args)

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args)
        if torch.cuda._initialized and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
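
For context, the mechanism the new path relies on is `torch.autograd.graph.saved_tensors_hooks`: at forward time the `pack` hook replaces each tensor autograd would save with a cheap handle, and at backward time the `unpack` hook resolves that handle, which is where `_checkpoint_without_reentrant` reruns `function` to repopulate `storage`. Below is a minimal standalone sketch of the hook API (not part of this diff; it caches instead of recomputing, purely to show the pack/unpack flow):

```python
import torch

cache = {}

def pack(tensor):
    # Called when autograd wants to save `tensor` for backward.
    # Return a lightweight handle; a real checkpoint would drop the tensor
    # here and recompute it in unpack() instead of caching it.
    idx = len(cache)
    cache[idx] = tensor
    return idx

def unpack(idx):
    # Called during backward when the saved tensor is actually needed.
    return cache[idx]

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()  # autograd saves x for backward via pack()

y.backward()           # unpack() supplies x to the backward of the multiply
print(x.grad)          # equals 2 * x
```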