[CheckpointWrapper] Replace generic mod prefix (#79830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
This commit is contained in:
parent 4b6ba340e2
commit 2ede28724d
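Why the rename matters: whatever attribute name CheckpointWrapper stores the inner module under becomes a prefix in the wrapped module's state_dict keys and fully qualified parameter names, and a generic name like "mod" is easy to collide with or misread. Below is a minimal sketch of that effect; ToyWrapper is a hypothetical stand-in for illustration, not the PyTorch class.

import torch.nn as nn

class ToyWrapper(nn.Module):  # hypothetical stand-in, not the PyTorch class
    def __init__(self, mod: nn.Module):
        super().__init__()
        # the attribute name chosen here becomes the state_dict key prefix
        self._checkpoint_wrapped_module = mod

    def forward(self, *args, **kwargs):
        return self._checkpoint_wrapped_module(*args, **kwargs)

wrapped = ToyWrapper(nn.Linear(4, 4))
print(list(wrapped.state_dict().keys()))
# ['_checkpoint_wrapped_module.weight', '_checkpoint_wrapped_module.bias']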
@@ -168,10 +168,10 @@ class CheckpointWrapperTest(TestCase):
         loss.backward()
         # ensure checkpointed part of model has gradients
         for j in range(3):
-            weight_lin = model.seq[j].lin.mod.weight
-            bias_lin = model.seq[j].lin.mod.bias
-            weight_nested_lin = model.seq[j].nested_linear[0].mod.weight
-            bias_nested_lin = model.seq[j].nested_linear[0].mod.bias
+            weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
+            bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
+            weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
+            bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
             for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]:
                 self.assertTrue(param.requires_grad)
                 self.assertFalse(param.grad is None)
@@ -10,7 +10,7 @@ import torch.nn as nn
 from typing import Dict, Any
 from functools import partial
 
-_CHECKPOINT_PREFIX = "mod"
+_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
 
 class CheckpointImpl(Enum):
     REENTRANT = auto()
@@ -28,7 +28,7 @@ class CheckpointWrapper(torch.nn.Module):
         offload_to_cpu: bool = False,
     ):
         super().__init__()
-        self.mod = mod
+        self._checkpoint_wrapped_module = mod
         self.checkpoint_impl = checkpoint_impl
         self.offload_to_cpu = offload_to_cpu
         # state_dict post hook to remove prefix to allow loading into a
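The truncated comment above refers to the state_dict post hook that removes this prefix so a checkpoint-wrapped model's state_dict can still be loaded into an unwrapped module. A rough sketch of the key-rewriting idea, using the _CHECKPOINT_PREFIX constant changed earlier in this diff; strip_checkpoint_prefix is a hypothetical helper, not the actual hook:

_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"

def strip_checkpoint_prefix(state_dict):
    # hypothetical helper: drop one occurrence of the wrapper prefix per key,
    # e.g. "seq.0.lin._checkpoint_wrapped_module.weight" -> "seq.0.lin.weight"
    return {
        key.replace(f"{_CHECKPOINT_PREFIX}.", "", 1): value
        for key, value in state_dict.items()
    }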
@@ -45,17 +45,17 @@ class CheckpointWrapper(torch.nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.mod, name)
+            return getattr(self._checkpoint_wrapped_module, name)
 
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
-        return self.mod.__getitem__(key)  # type: ignore[operator]
+        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]
 
     def forward(self, *args, **kwargs):
         offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
         with offload_mgr:  # type: ignore[attr-defined]
             return checkpoint(
-                self.mod,
+                self._checkpoint_wrapped_module,
                 use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
                 *args,
                 **kwargs,
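The __getattr__ change keeps the wrapper transparent: lookups that nn.Module cannot resolve fall through to the wrapped module, now stored under the renamed attribute. A minimal sketch of the same forwarding pattern; ForwardingWrapper is a hypothetical stand-in, not the PyTorch class:

import torch.nn as nn

class ForwardingWrapper(nn.Module):  # hypothetical stand-in for illustration
    def __init__(self, mod: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = mod

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # fall through to the wrapped module, as in the hunk above
            return getattr(self._checkpoint_wrapped_module, name)

w = ForwardingWrapper(nn.Linear(2, 3))
print(w.out_features)  # forwarded to the wrapped nn.Linear -> 3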