From 2ede28724d45996842c60d06aee0218a3ab7062e Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Tue, 21 Jun 2022 16:01:59 +0000
Subject: [PATCH] [CheckpointWrapper] Replace generic mod prefix (#79830)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
---
 test/distributed/fsdp/test_checkpoint_wrapper.py   |  8 ++++----
 .../algorithms/_checkpoint/checkpoint_wrapper.py   | 10 +++++-----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py
index 5f8c4f35b55..3c2b5957e44 100644
--- a/test/distributed/fsdp/test_checkpoint_wrapper.py
+++ b/test/distributed/fsdp/test_checkpoint_wrapper.py
@@ -168,10 +168,10 @@ class CheckpointWrapperTest(TestCase):
         loss.backward()
         # ensure checkpointed part of model has gradients
         for j in range(3):
-            weight_lin = model.seq[j].lin.mod.weight
-            bias_lin = model.seq[j].lin.mod.bias
-            weight_nested_lin = model.seq[j].nested_linear[0].mod.weight
-            bias_nested_lin = model.seq[j].nested_linear[0].mod.bias
+            weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
+            bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
+            weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight
+            bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias
             for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]:
                 self.assertTrue(param.requires_grad)
                 self.assertFalse(param.grad is None)
diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
index ebdc2142251..e3734903112 100644
--- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
+++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
@@ -10,7 +10,7 @@ import torch.nn as nn
 from typing import Dict, Any
 from functools import partial
 
-_CHECKPOINT_PREFIX = "mod"
+_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"
 
 class CheckpointImpl(Enum):
     REENTRANT = auto()
@@ -28,7 +28,7 @@ class CheckpointWrapper(torch.nn.Module):
         offload_to_cpu: bool = False,
     ):
         super().__init__()
-        self.mod = mod
+        self._checkpoint_wrapped_module = mod
         self.checkpoint_impl = checkpoint_impl
         self.offload_to_cpu = offload_to_cpu
         # state_dict post hook to remove prefix to allow loading into a
@@ -45,17 +45,17 @@ class CheckpointWrapper(torch.nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.mod, name)
+            return getattr(self._checkpoint_wrapped_module, name)
 
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
-        return self.mod.__getitem__(key)  # type: ignore[operator]
+        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]
 
     def forward(self, *args, **kwargs):
         offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
         with offload_mgr:  # type: ignore[attr-defined]
             return checkpoint(
-                self.mod,
+                self._checkpoint_wrapped_module,
                 use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
                 *args,
                 **kwargs,
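
Note (not part of the patch): below is a minimal sketch of how the renamed attribute surfaces to callers. It assumes CheckpointWrapper and CheckpointImpl are importable from torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py as shown in the diff; the toy nn.Linear module and variable names are illustrative only.

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    CheckpointWrapper,
)

# Wrap a toy module; the inner module is now registered under the less
# collision-prone name "_checkpoint_wrapped_module" instead of "mod".
lin = nn.Linear(4, 4)
wrapped = CheckpointWrapper(lin, checkpoint_impl=CheckpointImpl.REENTRANT)

# Fully qualified parameter names carry the new prefix, e.g.
# "_checkpoint_wrapped_module.weight" and "_checkpoint_wrapped_module.bias".
print([name for name, _ in wrapped.named_parameters()])

# Explicit access through the renamed attribute, mirroring the updated test:
weight = wrapped._checkpoint_wrapped_module.weight

# __getattr__ still forwards unknown attributes to the inner module,
# so plain attribute access keeps working.
assert wrapped.weight is weight

Per the (truncated) comment visible in the constructor hunk, a state_dict post-hook removes this prefix so checkpoints can load into the unwrapped module; the sketch above only exercises the attribute and parameter naming shown in the diff.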