[state_dict] Move _gather_state_dict to dcp module (#112835)
This API is used by more than just FSDP, so this PR moves it to the DCP module.

Differential Revision: [D50962966](https://our.internmc.facebook.com/intern/diff/D50962966/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112835
Approved by: https://github.com/wz337
This commit is contained in:
parent d98182e34e
commit a66f2a1b99
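For downstream callers, the visible effect of this change is the import path. A minimal before/after sketch (illustrative only; the call site and variable names are hypothetical):

```python
# Before this PR (import removed by this change):
# from torch.distributed.fsdp._shard_utils import _gather_state_dict

# After this PR, the helper lives under the distributed checkpoint (DCP) module:
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict

# Call sites are otherwise unchanged: gather the ShardedTensor/DTensor entries of a
# sharded state_dict into full tensors (assumes `sharded_state_dict` already exists).
# full_state_dict = _gather_state_dict(sharded_state_dict)
```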
@@ -9,8 +9,8 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import fully_shard
+from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
-from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer

@@ -4,8 +4,8 @@ import torch.distributed.checkpoint as dist_cp
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 
 from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
+from torch.distributed.checkpoint._state_dict_utils import _all_gather_sharded_tensor
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._shard_utils import _all_gather_sharded_tensor
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
 from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

@@ -11,6 +11,7 @@ import torch.nn as nn
 from torch.distributed._composable import fully_shard, replicate
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor, init_device_mesh
+from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.checkpoint.state_dict import (
     _patch_model_state_dict,
     _patch_optimizer_state_dict,
@@ -23,7 +24,6 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.distributed.optim import _apply_optimizer_in_backward
 from torch.nn.parallel import DistributedDataParallel as DDP

test/distributed/checkpoint/test_state_dict_utils.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+# Owner(s): ["oncall: distributed"]
+
+import torch
+import torch.distributed as dist
+import torch.distributed._functional_collectives as funcol
+
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import Shard
+from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    skip_if_lt_x_gpu,
+    with_comms,
+)
+
+
+class TestStateDictUtils(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 2
+
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_gather_state_dict_dtensor(self):
+        device_mesh = self.build_device_mesh()
+        shard_spec = [Shard(0)]
+        torch.random.manual_seed(dist.get_rank())
+        local_tensor = torch.randn(3, 3, 3)
+        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        state_dict = {"dtensor": dist_tensor}
+
+        gathered_state_dict = _gather_state_dict(state_dict)
+        expected_gathered_dtensor = funcol.all_gather_tensor(
+            dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
+        )
+        self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
+
+
+if __name__ == "__main__":
+    run_tests()

@@ -14,8 +14,8 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_WRAPPED_MODULE,
     apply_activation_checkpointing,
 )
+from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullOptimStateDictConfig,

@@ -21,6 +21,10 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.checkpoint._state_dict_utils import (
+    _all_gather_sharded_tensor,
+    _gather_state_dict,
+)
 from torch.distributed.fsdp import (
     CPUOffload,
     FullStateDictConfig,
@@ -30,10 +34,6 @@ from torch.distributed.fsdp import (
     ShardedStateDictConfig,
     StateDictType,
 )
-from torch.distributed.fsdp._shard_utils import (
-    _all_gather_sharded_tensor,
-    _gather_state_dict,
-)
 from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
 from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer

@@ -1,16 +1,11 @@
 # Owner(s): ["oncall: distributed"]
 
 import torch
 import torch.distributed as dist
-import torch.distributed._functional_collectives as funcol
-
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor.placement_types import Shard
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.distributed.fsdp._shard_utils import (
     _create_chunk_dtensor,
     _create_chunk_sharded_tensor,
-    _gather_state_dict,
 )
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import run_tests
@@ -76,22 +71,6 @@ class TestShardUtilsDistributedDTensor(DTensorTestBase):
             else:
                 self.assertEqual(self.rank >= len(tensor_chunks), True)
-
-    @with_comms
-    @skip_if_lt_x_gpu(2)
-    def test_gather_state_dict_dtensor(self):
-        device_mesh = self.build_device_mesh()
-        shard_spec = [Shard(0)]
-        torch.random.manual_seed(dist.get_rank())
-        local_tensor = torch.randn(3, 3, 3)
-        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
-        state_dict = {"dtensor": dist_tensor}
-
-        gathered_state_dict = _gather_state_dict(state_dict)
-        expected_gathered_dtensor = funcol.all_gather_tensor(
-            dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
-        )
-        self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
 
 
 if __name__ == "__main__":
     run_tests()

torch/distributed/checkpoint/_state_dict_utils.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+import math
+from typing import Any, Dict, Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed import distributed_c10d
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._tensor import DTensor, Replicate
+
+
+def _all_gather_sharded_tensor(
+    sharded_tensor: ShardedTensor,
+    pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+) -> torch.Tensor:
+    if pg is None:
+        pg = distributed_c10d._get_default_group()
+    world_size = dist.get_world_size(pg)
+    shards = sharded_tensor.local_shards()
+    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
+    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
+    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
+    pg_device = (
+        distributed_c10d._get_pg_default_device(pg) if device is None else device
+    )
+    if shards:
+        local_tensor = shards[0].tensor.flatten()
+        if local_tensor.device.type != pg_device.type:
+            local_tensor = local_tensor.to(pg_device)
+        num_padding = chunk_size - local_tensor.numel()
+        if num_padding > 0:
+            local_tensor = F.pad(local_tensor, [0, num_padding])
+    else:
+        local_tensor = torch.zeros(
+            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
+        )
+
+    tensor = torch.empty(
+        chunk_size * world_size,
+        dtype=local_tensor.dtype,
+        device=pg_device,
+    )
+    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
+
+    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
+    return tensor
+
+
+def _gather_state_dict(
+    state_dict: Dict[str, Any],
+    pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+) -> Dict[str, Any]:
+    """
+    Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict.
+    """
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if isinstance(value, ShardedTensor):
+            # ShardedTensor does not seem to record the original device type.
+            # So if the tensor is moved to CPU, we won't know the original type.
+            # As a result, we have to rely on the user to tell us the correct one.
+            output_tensor = _all_gather_sharded_tensor(value, pg, device)
+            local_shard_device = (
+                value.local_shards()[0].tensor.device
+                if value.local_shards()
+                else torch.device("cpu")
+            )
+            if output_tensor.device != local_shard_device:
+                value = output_tensor.to(local_shard_device)
+            else:
+                value = output_tensor
+        elif isinstance(value, DTensor):
+            if value.device != value.device_mesh.device_type:
+                value = value.to(value.device_mesh.device_type)
+            # FSDP all_gather: [Shard(0)] -> [Replicate()]
+            # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
+            # 2D FSDP + TP all_gather:
+            # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
+            # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
+            placements = [Replicate() for _ in value.placements]
+            value = value.redistribute(
+                device_mesh=value.device_mesh,
+                placements=placements,
+            )
+            value = value.to_local()
+        elif isinstance(value, dict):
+            value = _gather_state_dict(value, pg, device)
+
+        new_state_dict[key] = value
+    return new_state_dict

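For orientation only (not part of the diff), a short sketch of how the moved helper is meant to be called; it assumes torch.distributed is already initialized and that `model` is an FSDP-wrapped module whose state_dict contains ShardedTensor or DTensor values:

```python
# Sketch only. Assumes dist.init_process_group(...) has already run and that
# `model` is an FSDP-wrapped module whose state_dict holds ShardedTensor or
# DTensor entries rather than plain tensors.
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict

sharded_sd = model.state_dict()

# Every rank materializes the full tensors. Nested dicts (e.g. optimizer state)
# are gathered recursively; plain tensors and non-tensor values pass through.
full_sd = _gather_state_dict(sharded_sd)
```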
@@ -25,6 +25,7 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
 from torch.distributed.distributed_c10d import _get_pg_default_device
 from torch.distributed.fsdp._common_utils import (
     _apply_to_modules,
@@ -45,7 +46,6 @@ from torch.distributed.fsdp._runtime_utils import (
     _lazy_init,
     _reset_flat_param_grad_info_if_needed,
 )
-from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.utils._pytree import tree_map_only
 

@@ -1,11 +1,10 @@
 import copy
 import itertools
 import math
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from torch.distributed import distributed_c10d
 from torch.distributed._shard.sharded_tensor import (
     Shard,
@@ -15,94 +14,6 @@ from torch.distributed._shard.sharded_tensor import (
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
-from torch.distributed.fsdp._debug_utils import SimpleProfiler
-
-
-def _all_gather_sharded_tensor(
-    sharded_tensor: ShardedTensor,
-    pg: Optional[dist.ProcessGroup] = None,
-    device: Optional[torch.device] = None,
-) -> torch.Tensor:
-    if pg is None:
-        pg = distributed_c10d._get_default_group()
-    world_size = dist.get_world_size(pg)
-    shards = sharded_tensor.local_shards()
-    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
-    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
-    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
-    pg_device = (
-        distributed_c10d._get_pg_default_device(pg) if device is None else device
-    )
-    if shards:
-        local_tensor = shards[0].tensor.flatten()
-        with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
-            if local_tensor.device.type != pg_device.type:
-                local_tensor = local_tensor.to(pg_device)
-        num_padding = chunk_size - local_tensor.numel()
-        if num_padding > 0:
-            local_tensor = F.pad(local_tensor, [0, num_padding])
-    else:
-        local_tensor = torch.zeros(
-            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
-        )
-
-    tensor = torch.empty(
-        chunk_size * world_size,
-        dtype=local_tensor.dtype,
-        device=pg_device,
-    )
-    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
-
-    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
-    return tensor
-
-
-# TODO: Make this API work for both FSDP, and 2D. Move it outside of FSDP.
-# External users are interesting in using this API.
-def _gather_state_dict(
-    state_dict: Dict[str, Any],
-    pg: Optional[dist.ProcessGroup] = None,
-    device: Optional[torch.device] = None,
-) -> Dict[str, Any]:
-    """
-    Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict.
-    """
-    new_state_dict = {}
-    for key, value in state_dict.items():
-        if isinstance(value, ShardedTensor):
-            # ShardedTensor does not seem to record the original device type.
-            # So if the tensor is moved to CPU, we won't know the original type.
-            # As a result, we have to rely on the user to tell us the correct one.
-            output_tensor = _all_gather_sharded_tensor(value, pg, device)
-            local_shard_device = (
-                value.local_shards()[0].tensor.device
-                if value.local_shards()
-                else torch.device("cpu")
-            )
-            with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
-                if output_tensor.device != local_shard_device:
-                    value = output_tensor.to(local_shard_device)
-                else:
-                    value = output_tensor
-        elif isinstance(value, DTensor):
-            if value.device != value.device_mesh.device_type:
-                value = value.to(value.device_mesh.device_type)
-            # FSDP all_gather: [Shard(0)] -> [Replicate()]
-            # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
-            # 2D FSDP + TP all_gather:
-            # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
-            # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
-            placements = [Replicate() for _ in value.placements]
-            value = value.redistribute(
-                device_mesh=value.device_mesh,
-                placements=placements,
-            )
-            value = value.to_local()
-        elif isinstance(value, dict):
-            value = _gather_state_dict(value, pg, device)
-
-        new_state_dict[key] = value
-    return new_state_dict
 
 
 def _get_remote_device_str(rank, device_type, num_devices_per_node):