[FSDP2] Make module-to-state mapping use weakrefs (#139650)

Without this, `del model` does not free the memory of a module with FSDP2 applied, because the global module-to-state mapping keeps a strong reference to each module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139650
Approved by: https://github.com/yf225
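
For illustration only (not part of this commit): a minimal sketch of why a strong-keyed mapping keeps modules alive while a weakref-based one does not. The names `strong_map` and `weak_map` are hypothetical.

    import gc
    import weakref

    import torch.nn as nn

    strong_map = {}                          # plain dict: keys are strong references
    weak_map = weakref.WeakKeyDictionary()   # keys are weak references

    m1, m2 = nn.Linear(4, 4), nn.Linear(4, 4)
    strong_map[m1] = "state"
    weak_map[m2] = "state"
    r1, r2 = weakref.ref(m1), weakref.ref(m2)

    del m1, m2
    gc.collect()

    assert r1() is not None  # leaked: strong_map still keeps the module alive
    assert r2() is None      # freed: the WeakKeyDictionary does not pin its keys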
Andrew Gu 2024-11-04 10:59:50 -08:00 committed by PyTorch MergeBot
parent 5008d15ae9
commit 9039fbb47e
2 changed files with 42 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import functools
+import gc
 
 import torch
 from torch.distributed._composable.fsdp import (
@@ -197,6 +198,36 @@ class TestFullyShardMemory(FSDPTest):
             expected_mem_mb += (2 * model_sharded_numel) * 4 / 1e6 + buffer_mb
         self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
 
+    @skip_if_lt_x_gpu(2)
+    def test_fully_shard_del_memory(self):
+        base_mem_mb = self._get_peak_active_memory_mb()
+        vocab_size = 32
+        model_args = ModelArgs(
+            vocab_size=vocab_size, n_layers=3, dim=768, n_heads=12, weight_tying=False
+        )
+        model = Transformer(model_args)
+        # Initializing the model on CPU should not change the GPU memory usage
+        post_model_init_mem_mb = self._get_peak_active_memory_mb()
+        self.assertEqual(base_mem_mb, post_model_init_mem_mb)
+
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard(module)
+        fully_shard(model)
+        unsharded_numel = sum(p.numel() for p in model.parameters())
+        sharded_numel = unsharded_numel // self.world_size
+        buffer_mb = 4
+        mem_mb = self._get_curr_active_memory_mb()
+        expected_mb = sharded_numel * 4 / 1e6 + buffer_mb
+        self.assertLessEqual(mem_mb - base_mem_mb, expected_mb)
+
+        # Deleting the model should free all of the FSDP-managed GPU memory
+        del model
+        # Manually call garbage collection since there are ref cycles in FSDP
+        gc.collect()
+        mem_mb = self._get_curr_active_memory_mb()
+        self.assertEqual(mem_mb, base_mem_mb)
+
     def _get_peak_active_memory_mb(self) -> int:
         mem_stats = torch.cuda.memory_stats()
         return round(mem_stats["active_bytes.all.peak"] / 1e6)
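
The `gc.collect()` call in the test above is needed because CPython's reference counting alone cannot reclaim reference cycles. A minimal, FSDP-independent illustration of that behavior (the `Node` class here is hypothetical):

    import gc
    import weakref


    class Node:
        def __init__(self) -> None:
            self.peer = None


    a, b = Node(), Node()
    a.peer, b.peer = b, a    # create a reference cycle
    probe = weakref.ref(a)

    del a, b                 # refcounts stay nonzero because of the cycle
    assert probe() is not None

    gc.collect()             # the cyclic garbage collector frees both objects
    assert probe() is None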

View File

@@ -1,4 +1,5 @@
-from typing import cast, Dict, Optional
+import weakref
+from typing import cast, Optional
 
 import torch.nn as nn
 
@@ -7,13 +8,15 @@ class _State:
     pass
 
 
-_module_state_mapping: Dict[nn.Module, _State] = {}
+_module_state_mapping: weakref.WeakKeyDictionary[
+    nn.Module, weakref.ReferenceType[_State]
+] = weakref.WeakKeyDictionary()
 
 
 def _insert_module_state(module: nn.Module, state: _State) -> None:
     global _module_state_mapping
     assert module not in _module_state_mapping, f"Inserting {module} more than once."
-    _module_state_mapping[module] = state
+    _module_state_mapping[module] = weakref.ref(state)
 
 
 def _get_module_state(module: nn.Module) -> Optional[_State]:
@@ -32,6 +35,10 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
     else:
         # https://github.com/pytorch/pytorch/issues/107054
         if module in _module_state_mapping:
-            return _module_state_mapping[module]
+            state_ref = _module_state_mapping[module]
+            state = state_ref()
+            if state is None:
+                raise AssertionError("State has already been garbage collected")
+            return state
         else:
             return None
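
For illustration only (not part of the diff): a self-contained sketch of the resulting behavior, using hypothetical names (`DemoState`, `demo_mapping`). The mapping weakly references the module key and stores a weakref to the state, so deleting the module drops its entry instead of keeping the module alive.

    import gc
    import weakref

    import torch.nn as nn


    class DemoState:  # hypothetical stand-in for the FSDP ``_State``
        pass


    demo_mapping = weakref.WeakKeyDictionary()  # nn.Module -> weakref to state

    module, state = nn.Linear(2, 2), DemoState()
    demo_mapping[module] = weakref.ref(state)

    # Lookup dereferences the stored weakref, as in the patched _get_module_state
    assert demo_mapping[module]() is state

    # Deleting the module removes its entry; the mapping no longer pins the module
    del module
    gc.collect()
    assert len(demo_mapping) == 0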