[Reland][Autograd/Checkpoint] Checkpoint implementation without reentrant autograd (#69508)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69508
Original Phabricator Diff: D32704467 (e032dae329)
Reland; the fix is to not test traditional checkpointing when the input does not require grad, as that is unsupported (as documented).
Original PR body:
Resubmission of https://github.com/pytorch/pytorch/pull/62964 with the
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.
Adds a `use_reentrant` flag to the `checkpoint` function. When
`use_reentrant=False` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad` as well as DDP (thorough
distributed testing still needs to be added).
As discussed in https://github.com/pytorch/pytorch/issues/65537, the tests that we need to add are:
- [x] Gradient hooks are called once
- [x] works when inputs do not require grads but Tensors that require grads are captured (like the first layer in a nn)
- [x] works for functions with arbitrary input/output objects
- [x] distributed tests (next PR)
Note that this is only for `torch.utils.checkpoint`, if this approach overall looks good, we will do something similar for `checkpoint_sequential`.
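For illustration, a minimal usage sketch (not part of this diff) of how the new flag composes with `torch.autograd.grad`; the model and shapes are arbitrary, the pattern mirrors the `correct_grad` test added below:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))
x = torch.randn(4, 8, requires_grad=True)

# Non-reentrant checkpointing recomputes the segment via saved-tensor hooks,
# so it can be used under torch.autograd.grad (the reentrant path cannot).
out = checkpoint(model, x, use_reentrant=False).sum()
grad_x, = torch.autograd.grad(out, (x,))
```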
ghstack-source-id: 144948501
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D32902634
fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
This commit is contained in:
parent 3456c2cbc8
commit 049debd97d

test/test_autograd.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: autograd"]

+import contextlib
 import gc
 import io
 import math
@@ -31,7 +32,8 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                    slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
-                                                   disable_gc, gradcheck, gradgradcheck)
+                                                   disable_gc, gradcheck, gradgradcheck,
+                                                   parametrize, instantiate_parametrized_tests)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
@@ -4308,6 +4310,50 @@ for shape in [(1,), ()]:
         mean_combined = torch.stack(feat_combined).mean()
         mean_combined.backward()

+    @slowTest
+    @parametrize("input_requires_grad", [True, False])
+    def test_checkpointing_without_reentrant(self, input_requires_grad):
+        """
+        Basic test for checkpoint without reentrant autograd.
+        """
+        num_inp = 2000
+        nz_inp = 10
+        nz_out = 10
+        nz_bottleneck = 1000
+
+        # small proxy network for some complex reasoning we want to do per input
+        module = nn.Sequential(
+            nn.Linear(nz_inp, nz_bottleneck),
+            nn.ReLU(),
+            nn.Linear(nz_bottleneck, nz_inp)
+        )
+
+        # Run model with and without checkpointing and verify gradients are
+        # equivalent, regardless of if inputs require grads or not.
+        module_copy = deepcopy(module)
+
+        feat_combined = []
+        feat_combined_no_checkpoint = []
+        for r in range(num_inp):
+            data_r = torch.empty(1, nz_inp)
+            data_r.uniform_()
+            data_r.requires_grad = input_requires_grad
+            data_r_copy = data_r.clone()
+            feat_r = checkpoint(module, data_r, use_reentrant=False)
+            feat_combined.append(feat_r)
+            feat_r_no_checkpoint = module_copy(data_r)
+            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
+
+
+        # compute mean as a proxy for some joint reasoning
+        mean_combined = torch.stack(feat_combined).mean()
+        mean_combined.backward()
+        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
+        mean_combined_no_checkpoint.backward()
+
+        for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
+            self.assertEqual(checkpoint_param.grad, param.grad)
+
     def test_checkpoint_valid_reset_on_error(self):
         a = torch.randn(2, 2, requires_grad=True)

@@ -4318,6 +4364,156 @@ for shape in [(1,), ()]:
         c = checkpoint(torch.exp, a).sum()
         c.backward()

+    @parametrize("use_reentrant", [True, False])
+    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
+        class NoGradModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+                self.lin2 = nn.Linear(2, 2, bias=False)
+
+            def forward(self, x):
+                with torch.no_grad():
+                    return self.lin2(self.linear(x))
+
+        module = NoGradModule()
+
+        err_ctx = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "none of output has requires_grad=True"
+            )
+            if use_reentrant
+            else contextlib.suppress()
+        )
+
+        a = torch.randn(2, 2, requires_grad=True)
+        for _ in range(3):
+            with err_ctx:
+                # out does not require grad
+                out = checkpoint(module, a, use_reentrant=use_reentrant)
+                # Make loss require grad, otherwise we would run into
+                # "element 0 of tensors does not require grad and does not have a grad_fn"
+                out += a
+                out.sum().backward()
+
+    def test_checkpointing_without_reentrant_correct_grad(self):
+        """
+        Verifies that correct gradients are calculated for checkpoint
+        without reentrant autograd, for both backward() and autograd.grad().
+        """
+        a = torch.randn(2, 2, requires_grad=True)
+
+        b = torch.exp(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        a.grad = None
+        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        d_grad, = torch.autograd.grad(d, (a,))
+
+        self.assertEqual(b_grad, c_grad)
+        self.assertEqual(b_grad, d_grad)
+
+    def test_checkpointing_without_reentrant_dataparallel(self):
+        """
+        Verifies gradient correctness when checkpoint without reentrant autograd
+        is used in conjunction with DataParallel.
+        """
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+
+            def forward(self, inp):
+                return self.linear(inp)
+
+        a = torch.randn(2, 2, requires_grad=True)
+        if torch.cuda.is_available():
+            a = a.cuda()
+
+        model = LinearModule()
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        b = deepcopy(model)(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+
+        module = torch.nn.DataParallel(deepcopy(model))
+        c = checkpoint(module, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        self.assertEqual(b_grad, c_grad)
+
+    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
+        """
+        Ensures that gradient hooks are only called once per tensor.
+        """
+        w = torch.randn(10, 10, requires_grad=True)
+        count = 0
+
+        def hook(grad):
+            nonlocal count
+            count += 1
+
+        w.register_hook(hook)
+        x = torch.rand(10, 10, requires_grad=True)
+        h = w * x  # Using w outside the checkpoint
+        out = checkpoint(lambda x: w * x, h, use_reentrant=False)  # Using w inside the checkpoint
+
+        out.sum().backward()
+        # should only call hook once
+        self.assertEqual(count, 1)
+
+    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
+        """
+        Ensures checkpointing without reentrant autograd works with functions
+        with arbitrary input/output structures.
+        """
+
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = torch.nn.Linear(5, 5, bias=False)
+
+            def forward(self, dict_input):
+                tensor = dict_input["tensor"]
+                return {
+                    "result": self.layer(tensor)
+                }
+
+        model_no_checkpoint = MyModel()
+        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
+
+        inp = {
+            "tensor": torch.randn(5, 5)
+        }
+
+        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
+
+        out_checkpoint = checkpoint(
+            model_checkpoint_without_reentrant,
+            inp,
+            use_reentrant=False
+        )["result"].sum()
+
+        self.assertEqual(out_checkpoint, out_no_checkpoint)
+
+        out_no_checkpoint.backward()
+        out_checkpoint.backward()
+
+        for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
+            self.assertEqual(param.grad, checkpoint_param.grad)
+
     def test_callback_adds_callback(self):
         called = [0]

@@ -9108,5 +9304,7 @@ instantiate_device_type_tests(
     except_for=None
 )

+instantiate_parametrized_tests(TestAutograd)
+
 if __name__ == '__main__':
     run_tests()
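For readers unfamiliar with the parametrization helpers imported above, here is a minimal, hedged sketch (not from this PR) of how `parametrize` and `instantiate_parametrized_tests` are typically wired together; the test class and body are illustrative only:

```python
from torch.testing._internal.common_utils import (
    TestCase, run_tests, parametrize, instantiate_parametrized_tests)

class MyTests(TestCase):
    @parametrize("flag", [True, False])
    def test_something(self, flag):
        # One concrete test is generated per parameter value.
        self.assertIn(flag, (True, False))

# Expands the @parametrize'd methods into concrete test cases on the class,
# which is what instantiate_parametrized_tests(TestAutograd) does above.
instantiate_parametrized_tests(MyTests)

if __name__ == "__main__":
    run_tests()
```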
torch/utils/checkpoint.py

@@ -1,6 +1,6 @@
 import torch
 import warnings
-from typing import Any, Iterable, List, Tuple
+from typing import Any, Iterable, List, Tuple, Union


 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -142,7 +142,7 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads


-def checkpoint(function, *args, **kwargs):
+def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
     r"""Checkpoint a model or part of the model

     Checkpointing works by trading compute for memory. Rather than storing all
@@ -165,10 +165,6 @@ def checkpoint(function, *args, **kwargs):
         consisting of Tensors, these Tensors nested in custom structures will not
         be considered as part of autograd.

-    .. warning::
-        Checkpointing currently only supports :func:`torch.autograd.backward`
-        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
-        is not supported.

     .. warning::
         If :attr:`function` invocation during backward does anything different
@@ -177,18 +173,30 @@ def checkpoint(function, *args, **kwargs):
         detected.

     .. warning::
-        If checkpointed segment contains tensors detached from the computational
-        graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
-        error. This is because `checkpoint` makes all the outputs require
-        gradients which causes issues when a tensor is defined to have no
-        gradient in the model. To circumvent this, detach the tensors outside of
-        the `checkpoint` function.
+        If ``use_reentrant=True`` is specified, then if the checkpointed segment
+        contains tensors detached from the computational graph by `detach()` or
+        `torch.no_grad()`, the backward pass will raise an error. This is
+        because `checkpoint` makes all the outputs require gradients which
+        causes issues when a tensor is defined to have no gradient in the model.
+        To circumvent this, detach the tensors outside of the `checkpoint`
+        function. Note that the checkpointed segment can contain tensors
+        detached from the computational graph if ``use_reentrant=False`` is
+        specified.

     .. warning::
-        At least one of the inputs needs to have :code:`requires_grad=True` if
-        grads are needed for model inputs, otherwise the checkpointed part of the
-        model won't have gradients. At least one of the outputs needs to have
-        :code:`requires_grad=True` as well.
+        If ``use_reentrant=True`` is specified, at least one of the inputs needs
+        to have :code:`requires_grad=True` if grads are needed for model inputs,
+        otherwise the checkpointed part of the model won't have gradients. At
+        least one of the outputs needs to have :code:`requires_grad=True` as
+        well. Note that this does not apply if ``use_reentrant=False`` is
+        specified.
+
+    .. warning::
+        If ``use_reentrant=True`` is specified, checkpointing currently only
+        supports :func:`torch.autograd.backward` and only if its `inputs`
+        argument is not passed. :func:`torch.autograd.grad`
+        is not supported. If ``use_reentrant=False`` is specified, checkpointing
+        will work with :func:`torch.autograd.grad`.

     Args:
         function: describes what to run in the forward pass of the model or
@@ -198,6 +206,13 @@ def checkpoint(function, *args, **kwargs):
             first input as ``activation`` and the second input as ``hidden``
         preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
            the RNG state during each checkpoint.
+        use_reentrant(bool, optional, default=True): Use checkpointing
+            implementation that requires re-entrant autograd.
+            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
+            implementation that does not require re-entrant autograd. This
+            allows ``checkpoint`` to support additional functionality, such as
+            working as expected with ``torch.autograd.grad``. Note that future
+            versions of PyTorch will default to ``use_reentrant=False``.
         args: tuple containing inputs to the :attr:`function`

     Returns:
@@ -208,7 +223,14 @@ def checkpoint(function, *args, **kwargs):
     if kwargs:
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

-    return CheckpointFunction.apply(function, preserve, *args)
+    if use_reentrant:
+        return CheckpointFunction.apply(function, preserve, *args)
+    else:
+        return _checkpoint_without_reentrant(
+            function,
+            preserve,
+            *args
+        )


 def checkpoint_sequential(functions, segments, input, **kwargs):
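To make the behavioral difference described in the warnings above concrete, a small illustrative sketch (not part of the diff); it mirrors the `test_checkpointing_without_reentrant_detached_tensor` test added earlier:

```python
import torch
from torch.utils.checkpoint import checkpoint

def segment(x):
    with torch.no_grad():  # output is detached from the autograd graph
        return x * 2

a = torch.randn(2, 2, requires_grad=True)

# use_reentrant=False tolerates a segment whose outputs do not require grad.
out = checkpoint(segment, a, use_reentrant=False) + a
out.sum().backward()

# The reentrant implementation raises at backward time in the same situation.
try:
    out = checkpoint(segment, a, use_reentrant=True) + a
    out.sum().backward()
except RuntimeError as e:
    print("reentrant checkpoint failed:", e)
```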
@@ -275,3 +297,78 @@ def checkpoint_sequential(functions, segments, input, **kwargs):
         input = checkpoint(run_function(start, end, functions), input,
                            preserve_rng_state=preserve)
     return run_function(end + 1, len(functions) - 1, functions)(input)
+
+
+def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args):
+    """Checkpointing without re-entrant autograd
+
+    Args:
+        function: describes what to run in the forward pass of the model or
+            part of the model. It should also know how to handle the inputs
+            passed as the tuple. For example, in LSTM, if user passes
+            ``(activation, hidden)``, :attr:`function` should correctly use the
+            first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
+            the RNG state during each checkpoint.
+        *args: Arguments to pass in to the given ``function``.
+    """
+    had_autocast_in_fwd = torch.is_autocast_enabled()
+
+    if preserve_rng_state:
+        fwd_cpu_state = torch.get_rng_state()
+        # Don't eagerly initialize the cuda context by accident.
+        # (If the user intends that the context is initialized later, within their
+        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+        # we have no way to anticipate this will happen before we run the function.
+        # If they do so, we raise an error.)
+        had_cuda_in_fwd = False
+        if torch.cuda._initialized:
+            had_cuda_in_fwd = True
+            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)
+
+    storage: List[Union[torch.Tensor, None]] = []
+    counter = 0
+
+    def pack(x):
+        nonlocal counter
+        counter += 1
+        # TODO(varal7): Instead of returning indices, we can return metadata (such as
+        # size, device, ...) to catch certain cases of non-deterministic behavior of the forward
+        return counter - 1
+
+    def unpack(x):
+        if len(storage) == 0:
+
+            def inner_pack(inner):
+                storage.append(inner)
+                return None
+
+            def inner_unpack(packed):
+                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
+
+            # Stash the surrounding rng state, and mimic the state that was
+            # present at this time during forward. Restore the surrounding state
+            # when we're done.
+            rng_devices = []
+            if preserve_rng_state and had_cuda_in_fwd:
+                rng_devices = fwd_gpu_devices
+            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
+                if preserve_rng_state:
+                    torch.set_rng_state(fwd_cpu_state)
+                    if had_cuda_in_fwd:
+                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
+                with torch.enable_grad(), torch.cuda.amp.autocast(had_autocast_in_fwd):
+                    with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                        _unused = function(*args)
+
+        return storage[x]
+
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        output = function(*args)
+        if torch.cuda._initialized and not had_cuda_in_fwd:
+            # Cuda was not initialized before running the forward, so we didn't
+            # stash the CUDA state.
+            raise RuntimeError(
+                "PyTorch's CUDA state was initialized in the forward pass "
+                "of a Checkpoint, which is not allowed. Please open an issue "
+                "if you need this feature.")
+
+    return output
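As background, a standalone sketch (not from this PR) of the `torch.autograd.graph.saved_tensors_hooks` mechanism that `_checkpoint_without_reentrant` builds on: `pack` runs whenever autograd saves a tensor for backward, and `unpack` runs when backward needs it back; the checkpoint code above returns an index from `pack` and recomputes the forward pass inside `unpack`.

```python
import torch

def pack_hook(t):
    print("packing a tensor of shape", tuple(t.shape))
    return t  # the checkpoint code returns an index here instead of the tensor

def unpack_hook(saved):
    print("unpacking")
    return saved  # the checkpoint code recomputes the forward pass here

a = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    b = a.exp()  # exp saves its output for backward, so pack_hook fires
b.sum().backward()  # backward retrieves the saved tensor, so unpack_hook fires
```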