# Owner(s): ["oncall: distributed"] from copy import deepcopy from functools import partial import torch import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, apply_activation_checkpointing, CheckpointWrapper, CheckpointImpl ) from torch.utils.checkpoint import checkpoint from torch.testing._internal.common_utils import ( run_tests, TestCase, ) import unittest class CheckpointWrapperTest(TestCase): def setUp(self): super().setUp() def test_load_activation_checkpointed_module(self): lin = nn.Linear(10, 10, bias=False) lin = checkpoint_wrapper( lin, checkpoint_fn=checkpoint, # checkpoint kwargs use_reentrant=True, preserve_rng_state=False, ) state_dict = deepcopy(lin.state_dict()) # Load into non-checkpoint wrapped linear module lin_new = nn.Linear(10, 10, bias=False) lin_new.load_state_dict(state_dict) for p1, p2 in zip(lin.parameters(), lin_new.parameters()): self.assertEqual(p1, p2) self.assertTrue(torch.allclose(p1, p2)) # Load non-checkpoint wrapped module into checkpoint wrapped one # Make params different for p in lin_new.parameters(): with torch.no_grad(): p.add_(0.5) state_dict = deepcopy(lin_new.state_dict()) # Verify checkpoint wrapped linear can load unwrapped linear lin.load_state_dict(state_dict) for p1, p2 in zip(lin.parameters(), lin_new.parameters()): self.assertEqual(p1, p2) def test_checkpoint_wrapper_kwarg_support(self): class MyModel(nn.Module): def __init__(self): super().__init__() self.lin = nn.Linear(10, 10) def forward(self, a, b, c=None, d=None, **kwargs): return ( self.lin(a), self.lin(b), self.lin(c), self.lin(d) ) for wrapper in [ partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT), partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT), partial(checkpoint_wrapper, offload_to_cpu=True), ]: with self.subTest(wrapper=wrapper): model = wrapper(MyModel()) self.assertTrue(isinstance(model, CheckpointWrapper)) # Verify kwargs can be passed in inp = torch.ones(4, 10, requires_grad=True) out = model(inp, inp, c=inp, d=inp, e=inp, f=inp) self.assertTrue(isinstance(out, tuple)) self.assertEqual(4, len(out)) # Without kwargs should have equivalent gradient requirements. out_no_kwarg = model(inp, inp, inp, inp) for t1, t2 in zip(out_no_kwarg, out): self.assertEqual(t1, t2) self.assertEqual(t1.requires_grad, t2.requires_grad) # Test model that enforces kwarg inputs class ModelEnforceKwarg(nn.Module): def __init__(self): super().__init__() self.lin = nn.Linear(10, 10) def forward(self, *, a=None, b=None): return (self.lin(a), self.lin(b)) model = checkpoint_wrapper( ModelEnforceKwarg(), checkpoint_impl=CheckpointImpl.REENTRANT ) inp = torch.ones(4, 10, requires_grad=True) out = model(a=inp, b=inp) self.assertEqual(2, len(out)) @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") def test_checkpoint_wrapper_parity(self): """ Tests that using checkpoint_wrapper or the functional torch.utils.checkpoint (with the same reentrant config) results in the same maximum memory usage, i.e. they are equivalent memory usage wise. 
""" class Model(nn.Module): def __init__( self, n: int, use_cp: bool, use_wrapper: bool = False, use_reentrant: bool = True ): super().__init__() self.layers = nn.ModuleList() self.n = n self.use_cp = use_cp self.use_wrapper = use_wrapper self.use_reentrant = use_reentrant wrp = partial( checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT if use_reentrant else CheckpointImpl.NO_REENTRANT ) for i in range(self.n): l = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)) use_checkpoint_wrapper = self.use_wrapper if use_checkpoint_wrapper: l = wrp(l) self.layers.append(l) def forward(self, x): for i in range(self.n): if ( self.use_wrapper or not self.use_cp ): x = self.layers[i](x) else: x = checkpoint(self.layers[i], x, use_reentrant=self.use_reentrant) return x def test(use_checkpointing, use_wrapper, use_reentrant): a = Model(8, use_checkpointing, use_wrapper=use_wrapper, use_reentrant=use_reentrant).cuda() x = torch.randn(10000, 256, requires_grad=True).cuda() torch.cuda.reset_peak_memory_stats() loss = a(x).sum() loss.backward() return torch.cuda.max_memory_allocated() functional_no_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=False) wrapper_no_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=False) self.assertEqual(functional_no_reentrant, wrapper_no_reentrant) functional_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=True) wrapper_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=True) self.assertEqual(functional_reentrant, wrapper_reentrant) def test_forward_missing_attributes(self): lin = nn.Linear(1, 1) m = nn.Sequential(lin, lin) wrapped = CheckpointWrapper(m) # Test indexing is forwarded self.assertEqual(wrapped[0], lin) # Test missing attributes are forwarded. m._foo = 'bar' self.assertEqual(wrapped._foo, 'bar') def test_apply_activation_checkpointing(self): """ Ensures that `apply_activation_checkpointing` can be used to swap modules for their checkpoint-wrapped counterparts given a model. 
""" class LinearWithBatchNorm(nn.Module): def __init__(self): super().__init__() self.lin = nn.Linear(10, 10) self.bn = nn.BatchNorm1d(10) self.nested_linear = nn.Sequential(nn.Linear(10, 10)) def forward(self, x): return self.bn(self.nested_linear(self.lin(x))) class MyModel(nn.Module): def __init__(self): super().__init__() self.seq = nn.Sequential( LinearWithBatchNorm(), LinearWithBatchNorm(), LinearWithBatchNorm() ) def forward(self, x): return self.seq(x) def check_fn(l): return isinstance(l, nn.Linear) n_linear = None for wrapper in [ partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT), partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT), ]: model = MyModel() if n_linear is None: n_linear = sum( 1 if isinstance(x, nn.Linear) else 0 for x in model.modules() ) with self.subTest(wrapper=wrapper): apply_activation_checkpointing( model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn ) n_linear_wrapped = sum(1 if isinstance(x, nn.Linear) else 0 for x in model.modules()) n_checkpointed = sum(1 if isinstance(x, CheckpointWrapper) else 0 for x in model.modules()) self.assertEqual(n_checkpointed, n_linear_wrapped) self.assertEqual(n_linear, n_linear_wrapped) for j in range(3): self.assertTrue(isinstance(model.seq[j].lin, CheckpointWrapper)) self.assertTrue(isinstance(model.seq[j].nested_linear[0], CheckpointWrapper)) inp = torch.randn(4, 10, requires_grad=True) for i in range(6): # Kwarg input loss = model(x=inp).sum() self.assertTrue(loss.requires_grad) loss.backward() # ensure checkpointed part of model has gradients for j in range(3): weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]: self.assertTrue(param.requires_grad) self.assertFalse(param.grad is None) def test_fqn(self): lin = nn.Linear(10, 10, bias=False) lin = checkpoint_wrapper(lin) state_dict = lin.state_dict() for fqn, _ in lin.named_parameters(): self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") def test_checkpoint_wrapper_cpu_offload(self): model = nn.Sequential( nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), ).cuda() # Patch saved_tensor_hooks to make the unpack keep the tensor on CPU for # testing, otherwise the tensor access during the DFS will cause orig # unpack to run, transferring the tensor back to GPU. def patched_init(saved_tensor_hook_obj, pack_hook, _): saved_tensor_hook_obj.pack_hook = pack_hook def testing_cpu_offload_unpack_hook(packed): _, tensor = packed return tensor saved_tensor_hook_obj.unpack_hook = testing_cpu_offload_unpack_hook orig_init = torch.autograd.graph.saved_tensors_hooks.__init__ torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init model = checkpoint_wrapper(model, offload_to_cpu=True) inp = torch.randn(3, 10, device='cuda') loss = model(inp).sum() # All autograd saved tensors should be offloaded to CPU. 
        offload_verified = False

        def dfs(grad_fn):
            for e in dir(grad_fn):
                if not e.startswith('_saved_'):
                    continue

                saved = getattr(grad_fn, e)
                if isinstance(saved, torch.Tensor):
                    self.assertEqual(torch.device("cpu"), saved.device)
                    nonlocal offload_verified
                    offload_verified = True

            if hasattr(grad_fn, 'next_functions'):
                for next_grad_fn, _ in grad_fn.next_functions:
                    dfs(next_grad_fn)

        dfs(loss.grad_fn)
        self.assertTrue(offload_verified)

        torch.autograd.graph.saved_tensors_hooks.__init__ = orig_init


if __name__ == "__main__":
    run_tests()