# pytorch/test/distributed/fsdp/test_checkpoint_wrapper.py
# Owner(s): ["oncall: distributed"]
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    apply_activation_checkpointing,
    CheckpointWrapper,
    CheckpointImpl,
)
from torch.utils.checkpoint import checkpoint
from torch.testing._internal.common_utils import (
    run_tests,
    TestCase,
)
import unittest


class CheckpointWrapperTest(TestCase):
    def setUp(self):
        super().setUp()

    def test_load_activation_checkpointed_module(self):
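        """
        Tests that state_dicts are interchangeable between a
        checkpoint-wrapped nn.Linear and a plain one, in both
        load directions.
        """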
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(
            lin,
            checkpoint_fn=checkpoint,
            # checkpoint kwargs
            use_reentrant=True,
            preserve_rng_state=False,
        )
        state_dict = deepcopy(lin.state_dict())
        # Load into non-checkpoint wrapped linear module
        lin_new = nn.Linear(10, 10, bias=False)
        lin_new.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)
            self.assertTrue(torch.allclose(p1, p2))

        # Load non-checkpoint wrapped module into checkpoint wrapped one
        # Make params different
        for p in lin_new.parameters():
            with torch.no_grad():
                p.add_(0.5)

        state_dict = deepcopy(lin_new.state_dict())
        # Verify checkpoint wrapped linear can load unwrapped linear
        lin.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)

    def test_checkpoint_wrapper_kwarg_support(self):
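        """
        Tests that positional and keyword arguments (including extra
        **kwargs) are forwarded through the wrapper to the wrapped
        module's forward.
        """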
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(10, 10)

            def forward(self, a, b, c=None, d=None, **kwargs):
                return (
                    self.lin(a),
                    self.lin(b),
                    self.lin(c),
                    self.lin(d),
                )

        for wrapper in [
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
            partial(checkpoint_wrapper, offload_to_cpu=True),
        ]:
            with self.subTest(wrapper=wrapper):
                model = wrapper(MyModel())
                self.assertTrue(isinstance(model, CheckpointWrapper))
                # Verify kwargs can be passed in
                inp = torch.ones(4, 10, requires_grad=True)
                out = model(inp, inp, c=inp, d=inp, e=inp, f=inp)
                self.assertTrue(isinstance(out, tuple))
                self.assertEqual(4, len(out))
                # Without kwargs should have equivalent gradient requirements.
                out_no_kwarg = model(inp, inp, inp, inp)
                for t1, t2 in zip(out_no_kwarg, out):
                    self.assertEqual(t1, t2)
                    self.assertEqual(t1.requires_grad, t2.requires_grad)

        # Test model that enforces kwarg inputs
        class ModelEnforceKwarg(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(10, 10)

            def forward(self, *, a=None, b=None):
                return (self.lin(a), self.lin(b))

        model = checkpoint_wrapper(
            ModelEnforceKwarg(), checkpoint_impl=CheckpointImpl.REENTRANT
        )
        inp = torch.ones(4, 10, requires_grad=True)
        out = model(a=inp, b=inp)
        self.assertEqual(2, len(out))

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpoint_wrapper_parity(self):
"""
Tests that using checkpoint_wrapper or the functional
torch.utils.checkpoint (with the same reentrant config)
results in the same maximum memory usage, i.e. they are
equivalent memory usage wise.
"""

        class Model(nn.Module):
            def __init__(
                self,
                n: int,
                use_cp: bool,
                use_wrapper: bool = False,
                use_reentrant: bool = True,
            ):
                super().__init__()
                self.layers = nn.ModuleList()
                self.n = n
                self.use_cp = use_cp
                self.use_wrapper = use_wrapper
                self.use_reentrant = use_reentrant
                wrp = partial(
                    checkpoint_wrapper,
                    checkpoint_impl=CheckpointImpl.REENTRANT
                    if use_reentrant
                    else CheckpointImpl.NO_REENTRANT,
                )
                for i in range(self.n):
                    l = nn.Sequential(
                        nn.Linear(256, 256),
                        nn.Linear(256, 256),
                        nn.Linear(256, 256),
                    )
                    if self.use_wrapper:
                        l = wrp(l)
                    self.layers.append(l)

            def forward(self, x):
                for i in range(self.n):
                    if self.use_wrapper or not self.use_cp:
                        x = self.layers[i](x)
                    else:
                        x = checkpoint(
                            self.layers[i], x, use_reentrant=self.use_reentrant
                        )
                return x
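
        # Runs a single forward/backward pass on CUDA and returns the peak
        # memory allocated, so the wrapper and functional variants can be compared.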
        def test(use_checkpointing, use_wrapper, use_reentrant):
            a = Model(
                8,
                use_checkpointing,
                use_wrapper=use_wrapper,
                use_reentrant=use_reentrant,
            ).cuda()
            x = torch.randn(10000, 256, requires_grad=True).cuda()
            torch.cuda.reset_peak_memory_stats()
            loss = a(x).sum()
            loss.backward()
            return torch.cuda.max_memory_allocated()

        functional_no_reentrant = test(
            use_checkpointing=True, use_wrapper=False, use_reentrant=False
        )
        wrapper_no_reentrant = test(
            use_checkpointing=False, use_wrapper=True, use_reentrant=False
        )
        self.assertEqual(functional_no_reentrant, wrapper_no_reentrant)

        functional_reentrant = test(
            use_checkpointing=True, use_wrapper=False, use_reentrant=True
        )
        wrapper_reentrant = test(
            use_checkpointing=False, use_wrapper=True, use_reentrant=True
        )
        self.assertEqual(functional_reentrant, wrapper_reentrant)

    def test_forward_missing_attributes(self):
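        """
        Tests that indexing and missing attributes on a CheckpointWrapper
        are forwarded to the wrapped module.
        """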
        lin = nn.Linear(1, 1)
        m = nn.Sequential(lin, lin)
        wrapped = CheckpointWrapper(m)
        # Test indexing is forwarded
        self.assertEqual(wrapped[0], lin)
        # Test missing attributes are forwarded.
        m._foo = 'bar'
        self.assertEqual(wrapped._foo, 'bar')

    def test_apply_activation_checkpointing(self):
"""
Ensures that `apply_activation_checkpointing` can be used
to swap modules for their checkpoint-wrapped counterparts given
a model.
"""

        class LinearWithBatchNorm(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(10, 10)
                self.bn = nn.BatchNorm1d(10)
                self.nested_linear = nn.Sequential(nn.Linear(10, 10))

            def forward(self, x):
                return self.bn(self.nested_linear(self.lin(x)))

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.seq = nn.Sequential(
                    LinearWithBatchNorm(), LinearWithBatchNorm(), LinearWithBatchNorm()
                )

            def forward(self, x):
                return self.seq(x)

        def check_fn(l):
            return isinstance(l, nn.Linear)

        n_linear = None
        for wrapper in [
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
        ]:
            model = MyModel()
            if n_linear is None:
                n_linear = sum(
                    1 if isinstance(x, nn.Linear) else 0 for x in model.modules()
                )

            with self.subTest(wrapper=wrapper):
                apply_activation_checkpointing(
                    model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn
                )
                n_linear_wrapped = sum(
                    1 if isinstance(x, nn.Linear) else 0 for x in model.modules()
                )
                n_checkpointed = sum(
                    1 if isinstance(x, CheckpointWrapper) else 0
                    for x in model.modules()
                )
                self.assertEqual(n_checkpointed, n_linear_wrapped)
                self.assertEqual(n_linear, n_linear_wrapped)
                for j in range(3):
                    self.assertTrue(isinstance(model.seq[j].lin, CheckpointWrapper))
                    self.assertTrue(
                        isinstance(model.seq[j].nested_linear[0], CheckpointWrapper)
                    )

                inp = torch.randn(4, 10, requires_grad=True)
                for i in range(6):
                    # Kwarg input
                    loss = model(x=inp).sum()
                    self.assertTrue(loss.requires_grad)
                    loss.backward()
                    # ensure checkpointed part of model has gradients
                    for j in range(3):
                        weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
                        bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
                        weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
                        bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
                        for param in [
                            weight_lin,
                            bias_lin,
                            weight_nested_lin,
                            bias_nested_lin,
                        ]:
                            self.assertTrue(param.requires_grad)
                            self.assertFalse(param.grad is None)

    def test_fqn(self):
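        """
        Tests that FQNs from named_parameters() on a checkpoint-wrapped
        module match the keys of its state_dict.
        """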
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(lin)
        state_dict = lin.state_dict()
        for fqn, _ in lin.named_parameters():
            self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpoint_wrapper_cpu_offload(self):
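        """
        Tests that with offload_to_cpu=True, tensors saved for backward by
        the wrapped model are offloaded to CPU.
        """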
        model = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ).cuda()

        # Patch saved_tensor_hooks to make the unpack keep the tensor on CPU for
        # testing, otherwise the tensor access during the DFS will cause orig
        # unpack to run, transferring the tensor back to GPU.
        def patched_init(saved_tensor_hook_obj, pack_hook, _):
            saved_tensor_hook_obj.pack_hook = pack_hook

            def testing_cpu_offload_unpack_hook(packed):
                _, tensor = packed
                return tensor

            saved_tensor_hook_obj.unpack_hook = testing_cpu_offload_unpack_hook

        orig_init = torch.autograd.graph.saved_tensors_hooks.__init__
        torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init

        model = checkpoint_wrapper(model, offload_to_cpu=True)
        inp = torch.randn(3, 10, device='cuda')
        loss = model(inp).sum()

        # All autograd saved tensors should be offloaded to CPU.
        offload_verified = False
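
        # Walk the autograd graph depth-first and check the device of every
        # tensor exposed via a `_saved_*` attribute on each grad_fn.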
        def dfs(grad_fn):
            for e in dir(grad_fn):
                if not e.startswith('_saved_'):
                    continue
                saved = getattr(grad_fn, e)
                if isinstance(saved, torch.Tensor):
                    self.assertEqual(torch.device("cpu"), saved.device)
                    nonlocal offload_verified
                    offload_verified = True

            if hasattr(grad_fn, 'next_functions'):
                for next_grad_fn, _ in grad_fn.next_functions:
                    dfs(next_grad_fn)

        dfs(loss.grad_fn)
        self.assertTrue(offload_verified)
        torch.autograd.graph.saved_tensors_hooks.__init__ = orig_init
if __name__ == "__main__":
run_tests()