[CheckpointWrapper] Replace generic mod prefix (#79830)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
Rohan Varma 2022-06-21 16:01:59 +00:00 committed by PyTorch MergeBot
parent 4b6ba340e2
commit 2ede28724d
2 changed files with 9 additions and 9 deletions
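In short, the wrapper now stores the inner module under the attribute _checkpoint_wrapped_module instead of the generic mod, so the prefix that appears in parameter FQNs is unambiguous. Below is a minimal sketch of the user-visible effect; it assumes the module's checkpoint_wrapper(...) helper (not shown in this diff) is used to construct the CheckpointWrapper.

# Sketch of the rename's effect; checkpoint_wrapper(...) is assumed to wrap the
# given module in CheckpointWrapper (the helper itself is not part of this diff).
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

lin = nn.Linear(4, 4)
wrapped = checkpoint_wrapper(lin)

# The inner module now lives under the new, non-generic attribute name.
assert wrapped._checkpoint_wrapped_module is lin

# Parameter FQNs therefore carry the new prefix instead of "mod." ...
print([name for name, _ in wrapped.named_parameters()])
# ['_checkpoint_wrapped_module.weight', '_checkpoint_wrapped_module.bias']

# ... while state_dict keys stay prefix-free thanks to the state_dict post hook,
# so the state_dict can still be loaded into a non-wrapped module.
print(list(wrapped.state_dict().keys()))  # ['weight', 'bias']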


@@ -168,10 +168,10 @@ class CheckpointWrapperTest(TestCase):
             loss.backward()
             # ensure checkpointed part of model has gradients
             for j in range(3):
-                weight_lin = model.seq[j].lin.mod.weight
-                bias_lin = model.seq[j].lin.mod.bias
-                weight_nested_lin = model.seq[j].nested_linear[0].mod.weight
-                bias_nested_lin = model.seq[j].nested_linear[0].mod.bias
+                weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
+                bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
+                weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
+                bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
                 for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]:
                     self.assertTrue(param.requires_grad)
                     self.assertFalse(param.grad is None)


@@ -10,7 +10,7 @@ import torch.nn as nn
 from typing import Dict, Any
 from functools import partial
 
-_CHECKPOINT_PREFIX = "mod"
+_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
 
 class CheckpointImpl(Enum):
     REENTRANT = auto()
@@ -28,7 +28,7 @@ class CheckpointWrapper(torch.nn.Module):
         offload_to_cpu: bool = False,
     ):
         super().__init__()
-        self.mod = mod
+        self._checkpoint_wrapped_module = mod
         self.checkpoint_impl = checkpoint_impl
         self.offload_to_cpu = offload_to_cpu
         # state_dict post hook to remove prefix to allow loading into a
@ -45,17 +45,17 @@ class CheckpointWrapper(torch.nn.Module):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.mod, name)
return getattr(self._checkpoint_wrapped_module, name)
def __getitem__(self, key: int) -> Any:
"""Forward indexing calls in case the module is a nn.Sequential."""
return self.mod.__getitem__(key) # type: ignore[operator]
return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator]
def forward(self, *args, **kwargs):
offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
with offload_mgr: # type: ignore[attr-defined]
return checkpoint(
self.mod,
self._checkpoint_wrapped_module,
use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
*args,
**kwargs,
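For context, a short usage sketch (not part of the commit) showing that the __getattr__/__getitem__ forwarding and the checkpointed forward() above keep working after the rename; the checkpoint_wrapper(...) helper is again assumed as the constructor.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# __getattr__ falls through to the wrapped module, so plain attribute access still works.
lin = checkpoint_wrapper(nn.Linear(4, 4))
assert lin.out_features == 4

# __getitem__ forwards indexing when the wrapped module is an nn.Sequential.
seq = checkpoint_wrapper(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
assert isinstance(seq[0], nn.Linear)

# forward() runs the wrapped module under activation checkpointing.
out = seq(torch.randn(2, 4, requires_grad=True))
out.sum().backward()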