[CheckpointWrapper] Replace generic mod prefix (#79830)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
Rohan Varma 2022-06-21 16:01:59 +00:00 committed by PyTorch MergeBot
parent 4b6ba340e2
commit 2ede28724d
2 changed files with 9 additions and 9 deletions
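
Not part of the diff: a minimal sketch of what the rename means in practice, assuming the module's checkpoint_wrapper() helper is used to wrap a layer. Parameter names reported by named_parameters() now carry the descriptive _checkpoint_wrapped_module prefix instead of the generic mod, while the state_dict post hook (see the second file below) still strips the prefix so checkpoints stay loadable into the unwrapped module.

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
    )

    wrapped = checkpoint_wrapper(nn.Linear(4, 4))

    # named_parameters() reflects the new, descriptive prefix.
    print([name for name, _ in wrapped.named_parameters()])
    # expected: ['_checkpoint_wrapped_module.weight', '_checkpoint_wrapped_module.bias']

    # The state_dict post hook strips the prefix, so keys stay compatible
    # with a plain nn.Linear.
    print(list(wrapped.state_dict().keys()))
    # expected: ['weight', 'bias']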


@@ -168,10 +168,10 @@ class CheckpointWrapperTest(TestCase):
         loss.backward()
         # ensure checkpointed part of model has gradients
         for j in range(3):
-            weight_lin = model.seq[j].lin.mod.weight
-            bias_lin = model.seq[j].lin.mod.bias
-            weight_nested_lin = model.seq[j].nested_linear[0].mod.weight
-            bias_nested_lin = model.seq[j].nested_linear[0].mod.bias
+            weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
+            bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
+            weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
+            bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
             for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]:
                 self.assertTrue(param.requires_grad)
                 self.assertFalse(param.grad is None)
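
Side note (not part of the diff): because CheckpointWrapper.__getattr__ forwards unknown attributes to the wrapped module (see the hunk at line 45 below), the explicit _checkpoint_wrapped_module hop used in the test is optional for plain attribute access; for the hypothetical wrapped layer from the sketch above, the following should hold.

    assert wrapped.weight is wrapped._checkpoint_wrapped_module.weight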


@@ -10,7 +10,7 @@ import torch.nn as nn
 from typing import Dict, Any
 from functools import partial
 
-_CHECKPOINT_PREFIX = "mod"
+_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
 
 class CheckpointImpl(Enum):
     REENTRANT = auto()
@@ -28,7 +28,7 @@ class CheckpointWrapper(torch.nn.Module):
         offload_to_cpu: bool = False,
     ):
         super().__init__()
-        self.mod = mod
+        self._checkpoint_wrapped_module = mod
         self.checkpoint_impl = checkpoint_impl
         self.offload_to_cpu = offload_to_cpu
         # state_dict post hook to remove prefix to allow loading into a
@@ -45,17 +45,17 @@ class CheckpointWrapper(torch.nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.mod, name)
+            return getattr(self._checkpoint_wrapped_module, name)
 
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
-        return self.mod.__getitem__(key)  # type: ignore[operator]
+        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]
 
     def forward(self, *args, **kwargs):
         offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
         with offload_mgr:  # type: ignore[attr-defined]
             return checkpoint(
-                self.mod,
+                self._checkpoint_wrapped_module,
                 use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
                 *args,
                 **kwargs,
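
Runtime behavior is unchanged by the rename: forward() still routes the call through torch.utils.checkpoint, so the forward pass is recomputed during backward. Continuing the hypothetical wrapped layer from the sketch above:

    import torch

    x = torch.randn(2, 4, requires_grad=True)  # reentrant checkpoint needs a grad-requiring input
    out = wrapped(x)          # runs the wrapped nn.Linear under checkpoint()
    out.sum().backward()      # forward is recomputed here
    print(wrapped._checkpoint_wrapped_module.weight.grad is not None)  # expected: True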