[FSDP2] Make module-to-state mapping use weakrefs (#139650)
Without this, `del model` does not free memory of a module with FSDP2 applied.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139650
Approved by: https://github.com/yf225
This commit is contained in:
parent
5008d15ae9
commit
9039fbb47e
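
For context, here is a minimal standalone sketch (hypothetical names, not the actual PyTorch internals) of why a weakref-based module-to-state mapping stops pinning memory: the `WeakKeyDictionary` drops its entry once the module key has no other references, and storing the state as `weakref.ref(state)` keeps the mapping from keeping the state object alive.

    # Hypothetical sketch of a weakref-based module-to-state mapping;
    # names are illustrative, not the PyTorch code itself.
    import gc
    import weakref

    import torch.nn as nn


    class DemoState:  # stand-in for a per-module FSDP state object
        pass


    mapping: "weakref.WeakKeyDictionary[nn.Module, weakref.ReferenceType[DemoState]]" = (
        weakref.WeakKeyDictionary()
    )

    module = nn.Linear(4, 4)
    state = DemoState()
    mapping[module] = weakref.ref(state)  # neither key nor value is kept alive
    assert mapping[module]() is state     # dereference while both are alive

    del module  # dropping the last strong reference to the module...
    gc.collect()
    assert len(mapping) == 0  # ...removes the entry, so `del model` can free memory
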
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import functools
+import gc
 
 import torch
 from torch.distributed._composable.fsdp import (
@@ -197,6 +198,36 @@ class TestFullyShardMemory(FSDPTest):
             expected_mem_mb += (2 * model_sharded_numel) * 4 / 1e6 + buffer_mb
         self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
 
+    @skip_if_lt_x_gpu(2)
+    def test_fully_shard_del_memory(self):
+        base_mem_mb = self._get_peak_active_memory_mb()
+        vocab_size = 32
+        model_args = ModelArgs(
+            vocab_size=vocab_size, n_layers=3, dim=768, n_heads=12, weight_tying=False
+        )
+        model = Transformer(model_args)
+        # Initializing the model on CPU should not change the GPU memory usage
+        post_model_init_mem_mb = self._get_peak_active_memory_mb()
+        self.assertEqual(base_mem_mb, post_model_init_mem_mb)
+
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard(module)
+        fully_shard(model)
+        unsharded_numel = sum(p.numel() for p in model.parameters())
+        sharded_numel = unsharded_numel // self.world_size
+        buffer_mb = 4
+        mem_mb = self._get_curr_active_memory_mb()
+        expected_mb = sharded_numel * 4 / 1e6 + buffer_mb
+        self.assertLessEqual(mem_mb - base_mem_mb, expected_mb)
+
+        # Deleting the model should free all of the FSDP-managed GPU memory
+        del model
+        # Manually call garbage collection since there are ref cycles in FSDP
+        gc.collect()
+        mem_mb = self._get_curr_active_memory_mb()
+        self.assertEqual(mem_mb, base_mem_mb)
+
     def _get_peak_active_memory_mb(self) -> int:
         mem_stats = torch.cuda.memory_stats()
         return round(mem_stats["active_bytes.all.peak"] / 1e6)
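
The explicit `gc.collect()` in the new test matters because objects caught in reference cycles are not freed by reference counting alone. A small standalone illustration of that point (hypothetical `Node` class, not FSDP code):

    # Hypothetical illustration of why the test calls gc.collect() after `del model`:
    # objects in a reference cycle are only reclaimed by the cycle collector.
    import gc
    import weakref


    class Node:
        pass


    a, b = Node(), Node()
    a.other, b.other = b, a   # create a reference cycle
    probe = weakref.ref(a)    # observe when `a` is actually collected

    del a, b
    assert probe() is not None  # the cycle keeps the objects alive...
    gc.collect()
    assert probe() is None      # ...until the garbage collector breaks it
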
@@ -1,4 +1,5 @@
-from typing import cast, Dict, Optional
+import weakref
+from typing import cast, Optional
 
 import torch.nn as nn
 
@@ -7,13 +8,15 @@ class _State:
     pass
 
 
-_module_state_mapping: Dict[nn.Module, _State] = {}
+_module_state_mapping: weakref.WeakKeyDictionary[
+    nn.Module, weakref.ReferenceType[_State]
+] = weakref.WeakKeyDictionary()
 
 
 def _insert_module_state(module: nn.Module, state: _State) -> None:
     global _module_state_mapping
     assert module not in _module_state_mapping, f"Inserting {module} more than once."
-    _module_state_mapping[module] = state
+    _module_state_mapping[module] = weakref.ref(state)
 
 
 def _get_module_state(module: nn.Module) -> Optional[_State]:
@@ -32,6 +35,10 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
     else:
         # https://github.com/pytorch/pytorch/issues/107054
         if module in _module_state_mapping:
-            return _module_state_mapping[module]
+            state_ref = _module_state_mapping[module]
+            state = state_ref()
+            if state is None:
+                raise AssertionError("State has already been garbage collected")
+            return state
         else:
             return None
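
Because the mapping now stores `weakref.ref` objects, `_get_module_state` has to dereference before returning, and a dead reference yields `None` (which the code above treats as an error). A minimal sketch of that dereference behavior, using an illustrative `DemoState` rather than the real state class:

    # Hypothetical sketch of dereferencing a stored weakref.ref, as in the lookup above.
    import gc
    import weakref


    class DemoState:
        pass


    state = DemoState()
    state_ref = weakref.ref(state)
    assert state_ref() is state  # dereference while the state is alive

    del state
    gc.collect()
    assert state_ref() is None   # a dead reference dereferences to None
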