[CheckpointWrapper] Replace generic mod prefix (#79830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830 Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
This commit is contained in:
parent 4b6ba340e2
commit 2ede28724d
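In short, the commit renames the attribute under which CheckpointWrapper stores the wrapped module from the generic "mod" to "_checkpoint_wrapped_module", and updates _CHECKPOINT_PREFIX to match. A minimal orientation sketch, assuming the private helper checkpoint_wrapper from torch.distributed.algorithms._checkpoint.checkpoint_wrapper (a non-public module whose exact location and signature may differ across PyTorch versions):

import torch.nn as nn
# Private module path; treated here as an assumption about this PyTorch version.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

lin = nn.Linear(4, 4)
wrapped = checkpoint_wrapper(lin)

# After this change the inner module lives under the distinctive attribute
# name rather than the old generic "mod".
assert wrapped._checkpoint_wrapped_module is lin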
@@ -168,10 +168,10 @@ class CheckpointWrapperTest(TestCase):
             loss.backward()
             # ensure checkpointed part of model has gradients
             for j in range(3):
-                weight_lin = model.seq[j].lin.mod.weight
-                bias_lin = model.seq[j].lin.mod.bias
-                weight_nested_lin = model.seq[j].nested_linear[0].mod.weight
-                bias_nested_lin = model.seq[j].nested_linear[0].mod.bias
+                weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
+                bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
+                weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
+                bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
                 for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]:
                     self.assertTrue(param.requires_grad)
                     self.assertFalse(param.grad is None)
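The test hunk above reaches parameters through the new attribute explicitly. Ordinary attribute access is still forwarded to the wrapped module by __getattr__ (see the last hunk below), so, continuing the sketch from the previous snippet, both spellings resolve to the same parameter; only code that named the old "mod" attribute literally needs updating:

# Explicit access through the renamed attribute, as in the updated test.
w_explicit = wrapped._checkpoint_wrapped_module.weight
# Forwarded access: "weight" is not defined on the wrapper itself, so
# __getattr__ falls through to the wrapped nn.Linear.
w_forwarded = wrapped.weight
assert w_explicit is w_forwarded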
@@ -10,7 +10,7 @@ import torch.nn as nn
 from typing import Dict, Any
 from functools import partial
 
-_CHECKPOINT_PREFIX = "mod"
+_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
 
 class CheckpointImpl(Enum):
     REENTRANT = auto()
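_CHECKPOINT_PREFIX has to track the attribute name because nn.Module builds state_dict keys from submodule attribute names; the state_dict post hook mentioned in the next hunk strips this prefix so a checkpoint saved from a wrapped model can still load into an unwrapped one. A self-contained illustration of the key naming, using a toy wrapper rather than the real class:

import torch.nn as nn

class ToyWrapper(nn.Module):
    # Toy stand-in: whatever attribute holds the child module becomes the
    # prefix of every state_dict key under it.
    def __init__(self, mod: nn.Module):
        super().__init__()
        self._checkpoint_wrapped_module = mod

    def forward(self, x):
        return self._checkpoint_wrapped_module(x)

print(list(ToyWrapper(nn.Linear(2, 2)).state_dict().keys()))
# ['_checkpoint_wrapped_module.weight', '_checkpoint_wrapped_module.bias']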
@@ -28,7 +28,7 @@ class CheckpointWrapper(torch.nn.Module):
         offload_to_cpu: bool = False,
     ):
         super().__init__()
-        self.mod = mod
+        self._checkpoint_wrapped_module = mod
         self.checkpoint_impl = checkpoint_impl
         self.offload_to_cpu = offload_to_cpu
         # state_dict post hook to remove prefix to allow loading into a
@@ -45,17 +45,17 @@ class CheckpointWrapper(torch.nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.mod, name)
+            return getattr(self._checkpoint_wrapped_module, name)
 
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
-        return self.mod.__getitem__(key)  # type: ignore[operator]
+        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]
 
     def forward(self, *args, **kwargs):
         offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
         with offload_mgr:  # type: ignore[attr-defined]
             return checkpoint(
-                self.mod,
+                self._checkpoint_wrapped_module,
                 use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
                 *args,
                 **kwargs,
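For completeness, the forward in the last hunk is the standard activation-checkpointing pattern around torch.utils.checkpoint.checkpoint; a minimal sketch of that pattern on its own, independent of the wrapper (the module and tensor shapes are just for illustration, and use_reentrant=True corresponds to CheckpointImpl.REENTRANT in the hunk above):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

mod = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
x = torch.randn(2, 4, requires_grad=True)

# Activations inside `mod` are recomputed during backward instead of stored.
out = checkpoint(mod, x, use_reentrant=True)
out.sum().backward()
assert all(p.grad is not None for p in mod.parameters())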