# Owner(s): ["oncall: distributed"]

import unittest
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
    CheckpointWrapper,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import checkpoint

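# Quick reference for the APIs under test (a minimal usage sketch, not part of
# the test suite itself; `model` below stands for any nn.Module):
#
#     wrapped = checkpoint_wrapper(nn.Linear(10, 10))
#     apply_activation_checkpointing(
#         model,
#         checkpoint_wrapper_fn=checkpoint_wrapper,
#         check_fn=lambda m: isinstance(m, nn.Linear),
#     )
#
# Both entry points produce CheckpointWrapper instances, which recompute
# activations during backward instead of storing them.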
class CheckpointWrapperTest(TestCase):
    def setUp(self):
        super().setUp()

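    # The next test relies on checkpoint_wrapper exposing state_dict keys that
    # match the unwrapped module, so a checkpoint can be loaded in either
    # direction (wrapped -> plain and plain -> wrapped).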
    def test_load_activation_checkpointed_module(self):
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(
            lin,
            checkpoint_fn=checkpoint,
            # checkpoint kwargs
            use_reentrant=True,
            preserve_rng_state=False,
        )
        state_dict = deepcopy(lin.state_dict())
        # Load into non-checkpoint wrapped linear module
        lin_new = nn.Linear(10, 10, bias=False)
        lin_new.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)
            self.assertTrue(torch.allclose(p1, p2))

        # Load non-checkpoint wrapped module into checkpoint wrapped one
        # Make params different
        for p in lin_new.parameters():
            with torch.no_grad():
                p.add_(0.5)

        state_dict = deepcopy(lin_new.state_dict())
        # Verify checkpoint wrapped linear can load unwrapped linear
        lin.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)

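    # checkpoint_wrapper should forward positional args, keyword args, and even
    # extra unused kwargs straight through to the wrapped module's forward().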
    def test_checkpoint_wrapper_kwarg_support(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(10, 10)

            def forward(self, a, b, c=None, d=None, **kwargs):
                return (self.lin(a), self.lin(b), self.lin(c), self.lin(d))

        for wrapper in [
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
            partial(checkpoint_wrapper, offload_to_cpu=True),
        ]:
            with self.subTest(wrapper=wrapper):
                model = wrapper(MyModel())
                self.assertTrue(isinstance(model, CheckpointWrapper))
                # Verify kwargs can be passed in
                inp = torch.ones(4, 10, requires_grad=True)
                out = model(inp, inp, c=inp, d=inp, e=inp, f=inp)
                self.assertTrue(isinstance(out, tuple))
                self.assertEqual(4, len(out))
                # Without kwargs should have equivalent gradient requirements.
                out_no_kwarg = model(inp, inp, inp, inp)
                for t1, t2 in zip(out_no_kwarg, out):
                    self.assertEqual(t1, t2)
                    self.assertEqual(t1.requires_grad, t2.requires_grad)

        # Test model that enforces kwarg inputs
        class ModelEnforceKwarg(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(10, 10)

            def forward(self, *, a=None, b=None):
                return (self.lin(a), self.lin(b))

        model = checkpoint_wrapper(
            ModelEnforceKwarg(), checkpoint_impl=CheckpointImpl.REENTRANT
        )
        inp = torch.ones(4, 10, requires_grad=True)
        out = model(a=inp, b=inp)
        self.assertEqual(2, len(out))

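    # Peak-memory parity check: wrapping a block with checkpoint_wrapper should
    # cost the same as calling torch.utils.checkpoint.checkpoint on it directly,
    # for both the reentrant and non-reentrant implementations. The comparison
    # below measures torch.cuda.max_memory_allocated() after a forward and
    # backward pass over an identical stack of layers.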
    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpoint_wrapper_parity(self):
        """
        Tests that using checkpoint_wrapper or the functional
        torch.utils.checkpoint (with the same reentrant config)
        results in the same peak memory usage, i.e. the two are
        equivalent in terms of memory.
        """

        class Model(nn.Module):
            def __init__(
                self,
                n: int,
                use_cp: bool,
                use_wrapper: bool = False,
                use_reentrant: bool = True,
            ):
                super().__init__()
                self.layers = nn.ModuleList()
                self.n = n
                self.use_cp = use_cp
                self.use_wrapper = use_wrapper
                self.use_reentrant = use_reentrant
                wrp = partial(
                    checkpoint_wrapper,
                    checkpoint_impl=CheckpointImpl.REENTRANT
                    if use_reentrant
                    else CheckpointImpl.NO_REENTRANT,
                )
                for _ in range(self.n):
                    layer = nn.Sequential(
                        nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
                    )
                    if self.use_wrapper:
                        layer = wrp(layer)
                    self.layers.append(layer)

            def forward(self, x):
                for i in range(self.n):
                    if self.use_wrapper or not self.use_cp:
                        x = self.layers[i](x)
                    else:
                        x = checkpoint(
                            self.layers[i], x, use_reentrant=self.use_reentrant
                        )
                return x

        def test(use_checkpointing, use_wrapper, use_reentrant):
            a = Model(
                8, use_checkpointing, use_wrapper=use_wrapper, use_reentrant=use_reentrant
            ).cuda()
            x = torch.randn(10000, 256, requires_grad=True).cuda()
            torch.cuda.reset_peak_memory_stats()
            loss = a(x).sum()
            loss.backward()
            return torch.cuda.max_memory_allocated()

        functional_no_reentrant = test(
            use_checkpointing=True, use_wrapper=False, use_reentrant=False
        )
        wrapper_no_reentrant = test(
            use_checkpointing=False, use_wrapper=True, use_reentrant=False
        )
        self.assertEqual(functional_no_reentrant, wrapper_no_reentrant)

        functional_reentrant = test(
            use_checkpointing=True, use_wrapper=False, use_reentrant=True
        )
        wrapper_reentrant = test(
            use_checkpointing=False, use_wrapper=True, use_reentrant=True
        )
        self.assertEqual(functional_reentrant, wrapper_reentrant)

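    # CheckpointWrapper should be transparent for attribute access: indexing and
    # attributes not defined on the wrapper are forwarded to the wrapped module.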
    def test_forward_missing_attributes(self):
        lin = nn.Linear(1, 1)
        m = nn.Sequential(lin, lin)
        wrapped = CheckpointWrapper(m)
        # Test indexing is forwarded
        self.assertEqual(wrapped[0], lin)
        # Test missing attributes are forwarded.
        m._foo = 'bar'
        self.assertEqual(wrapped._foo, 'bar')

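    # apply_activation_checkpointing rewrites the model in place: every submodule
    # for which check_fn returns True is replaced with a CheckpointWrapper around
    # it, including modules nested inside nn.Sequential containers.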
    def test_apply_activation_checkpointing(self):
        """
        Ensures that `apply_activation_checkpointing` can be used
        to swap modules for their checkpoint-wrapped counterparts given
        a model.
        """

        class LinearWithBatchNorm(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(10, 10)
                self.bn = nn.BatchNorm1d(10)
                self.nested_linear = nn.Sequential(nn.Linear(10, 10))

            def forward(self, x):
                return self.bn(self.nested_linear(self.lin(x)))

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.seq = nn.Sequential(
                    LinearWithBatchNorm(), LinearWithBatchNorm(), LinearWithBatchNorm()
                )

            def forward(self, x):
                return self.seq(x)

        def check_fn(module):
            return isinstance(module, nn.Linear)

        n_linear = None

        for wrapper in [
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
        ]:
            model = MyModel()
            if n_linear is None:
                n_linear = sum(
                    1 if isinstance(x, nn.Linear) else 0 for x in model.modules()
                )

            with self.subTest(wrapper=wrapper):
                apply_activation_checkpointing(
                    model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn
                )
                n_linear_wrapped = sum(
                    1 if isinstance(x, nn.Linear) else 0 for x in model.modules()
                )
                n_checkpointed = sum(
                    1 if isinstance(x, CheckpointWrapper) else 0 for x in model.modules()
                )
                self.assertEqual(n_checkpointed, n_linear_wrapped)
                self.assertEqual(n_linear, n_linear_wrapped)
                for j in range(3):
                    self.assertTrue(isinstance(model.seq[j].lin, CheckpointWrapper))
                    self.assertTrue(
                        isinstance(model.seq[j].nested_linear[0], CheckpointWrapper)
                    )

                inp = torch.randn(4, 10, requires_grad=True)
                for _ in range(6):
                    # Kwarg input
                    loss = model(x=inp).sum()
                    self.assertTrue(loss.requires_grad)
                    loss.backward()
                    # Ensure checkpointed part of model has gradients
                    for j in range(3):
                        weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
                        bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
                        weight_nested_lin = (
                            model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
                        )
                        bias_nested_lin = (
                            model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
                        )
                        for param in [
                            weight_lin, bias_lin, weight_nested_lin, bias_nested_lin
                        ]:
                            self.assertTrue(param.requires_grad)
                            self.assertFalse(param.grad is None)

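    # named_parameters() FQNs and state_dict() keys should agree for a wrapped
    # module, so tooling that maps one onto the other keeps working.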
    def test_fqn(self):
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(lin)
        state_dict = lin.state_dict()
        for fqn, _ in lin.named_parameters():
            self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")

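    # offload_to_cpu=True is built on torch.autograd's saved_tensors_hooks:
    # tensors saved for backward are packed onto CPU and unpacked back to their
    # original device when backward needs them. The test below patches
    # saved_tensors_hooks.__init__ so the unpack step keeps tensors on CPU,
    # which makes the offload observable by walking the autograd graph.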
    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpoint_wrapper_cpu_offload(self):
        model = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ).cuda()

        # Patch saved_tensors_hooks so that the unpack hook keeps the tensor on
        # CPU for testing; otherwise, accessing the tensor during the DFS below
        # would run the original unpack hook and move the tensor back to GPU.
        def patched_init(saved_tensor_hook_obj, pack_hook, _):
            saved_tensor_hook_obj.pack_hook = pack_hook

            def testing_cpu_offload_unpack_hook(packed):
                _, tensor = packed
                return tensor

            saved_tensor_hook_obj.unpack_hook = testing_cpu_offload_unpack_hook

        orig_init = torch.autograd.graph.saved_tensors_hooks.__init__
        torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init

        model = checkpoint_wrapper(model, offload_to_cpu=True)
        inp = torch.randn(3, 10, device='cuda')
        loss = model(inp).sum()

        # All autograd saved tensors should be offloaded to CPU.
        offload_verified = False

        def dfs(grad_fn):
            nonlocal offload_verified
            for e in dir(grad_fn):
                if not e.startswith('_saved_'):
                    continue

                saved = getattr(grad_fn, e)
                if isinstance(saved, torch.Tensor):
                    self.assertEqual(torch.device("cpu"), saved.device)
                    offload_verified = True

            if hasattr(grad_fn, 'next_functions'):
                for next_grad_fn, _ in grad_fn.next_functions:
                    dfs(next_grad_fn)

        dfs(loss.grad_fn)
        self.assertTrue(offload_verified)

        # Restore the original saved_tensors_hooks constructor.
        torch.autograd.graph.saved_tensors_hooks.__init__ = orig_init


if __name__ == "__main__":
    run_tests()