[state_dict] Move _gather_state_dict to dcp module (#112835)

This API is now used by more than just FSDP, so this PR moves it to the DCP module (torch.distributed.checkpoint).
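For existing callers, the change amounts to updating the import path; the function keeps the same name and signature:

# before this PR (FSDP-internal helper)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
# after this PR (DCP module)
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict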

Differential Revision: [D50962966](https://our.internmc.facebook.com/intern/diff/D50962966/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112835
Approved by: https://github.com/wz337
Chien-Chin Huang 2023-11-07 14:20:40 -08:00 committed by PyTorch MergeBot
parent d98182e34e
commit a66f2a1b99
10 changed files with 143 additions and 120 deletions

View File

@@ -9,8 +9,8 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer

View File

@@ -4,8 +4,8 @@ import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
from torch.distributed.checkpoint._state_dict_utils import _all_gather_sharded_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._shard_utils import _all_gather_sharded_tensor
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

View File

@@ -11,6 +11,7 @@ import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
@@ -23,7 +24,6 @@ from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn.parallel import DistributedDataParallel as DDP

View File

@@ -0,0 +1,41 @@
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)


class TestStateDictUtils(DTensorTestBase):
    @property
    def world_size(self):
        return 2

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_gather_state_dict_dtensor(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        torch.random.manual_seed(dist.get_rank())
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
        state_dict = {"dtensor": dist_tensor}
        gathered_state_dict = _gather_state_dict(state_dict)
        expected_gathered_dtensor = funcol.all_gather_tensor(
            dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )
        self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])


if __name__ == "__main__":
    run_tests()

View File

@@ -14,8 +14,8 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_WRAPPED_MODULE,
    apply_activation_checkpointing,
)
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,

View File

@@ -21,6 +21,10 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.checkpoint._state_dict_utils import (
    _all_gather_sharded_tensor,
    _gather_state_dict,
)
from torch.distributed.fsdp import (
    CPUOffload,
    FullStateDictConfig,
@@ -30,10 +34,6 @@ from torch.distributed.fsdp import (
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._shard_utils import (
    _all_gather_sharded_tensor,
    _gather_state_dict,
)
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer

View File

@@ -1,16 +1,11 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import (
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
    _gather_state_dict,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
@@ -76,22 +71,6 @@ class TestShardUtilsDistributedDTensor(DTensorTestBase):
        else:
            self.assertEqual(self.rank >= len(tensor_chunks), True)

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_gather_state_dict_dtensor(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        torch.random.manual_seed(dist.get_rank())
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
        state_dict = {"dtensor": dist_tensor}
        gathered_state_dict = _gather_state_dict(state_dict)
        expected_gathered_dtensor = funcol.all_gather_tensor(
            dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )
        self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])


if __name__ == "__main__":
    run_tests()

View File

@@ -0,0 +1,92 @@
import math
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Replicate


def _all_gather_sharded_tensor(
    sharded_tensor: ShardedTensor,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if local_tensor.device.type != pg_device.type:
            local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor


def _gather_state_dict(
    state_dict: Dict[str, Any],
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        if isinstance(value, ShardedTensor):
            # ShardedTensor does not seem to record the original device type.
            # So if the tensor is moved to CPU, we won't know the original type.
            # As a result, we have to rely on the user to tell us the correct one.
            output_tensor = _all_gather_sharded_tensor(value, pg, device)
            local_shard_device = (
                value.local_shards()[0].tensor.device
                if value.local_shards()
                else torch.device("cpu")
            )
            if output_tensor.device != local_shard_device:
                value = output_tensor.to(local_shard_device)
            else:
                value = output_tensor
        elif isinstance(value, DTensor):
            if value.device != value.device_mesh.device_type:
                value = value.to(value.device_mesh.device_type)
            # FSDP all_gather: [Shard(0)] -> [Replicate()]
            # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
            # 2D FSDP + TP all_gather:
            # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
            # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
            placements = [Replicate() for _ in value.placements]
            value = value.redistribute(
                device_mesh=value.device_mesh,
                placements=placements,
            )
            value = value.to_local()
        elif isinstance(value, dict):
            value = _gather_state_dict(value, pg, device)

        new_state_dict[key] = value
    return new_state_dict
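
A minimal sketch of how the relocated helper can be used, assuming an initialized process group and an FSDP-wrapped model (the FSDP.state_dict_type context manager and StateDictType.SHARDED_STATE_DICT are the public FSDP APIs for producing a sharded state_dict):

from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# `model` is assumed to be an nn.Module already wrapped in FSDP on an
# initialized process group, so its sharded state_dict holds ShardedTensor
# or DTensor values.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = model.state_dict()

# Gather every ShardedTensor/DTensor value into a full, unsharded tensor on
# each rank; pg=None falls back to the default process group and device=None
# to the process group's default device, per the implementation above.
full_sd = _gather_state_dict(sharded_sd, pg=None, device=None)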

View File

@@ -25,6 +25,7 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.checkpoint._state_dict_utils import _gather_state_dict
from torch.distributed.distributed_c10d import _get_pg_default_device
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
@@ -45,7 +46,6 @@ from torch.distributed.fsdp._runtime_utils import (
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.api import ShardingStrategy
from torch.utils._pytree import tree_map_only

View File

@@ -1,11 +1,10 @@
import copy
import itertools
import math
from typing import Any, Dict, Optional
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
@@ -15,94 +14,6 @@ from torch.distributed._shard.sharded_tensor import (
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed.fsdp._debug_utils import SimpleProfiler


def _all_gather_sharded_tensor(
    sharded_tensor: ShardedTensor,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
            if local_tensor.device.type != pg_device.type:
                local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor


# TODO: Make this API work for both FSDP, and 2D. Move it outside of FSDP.
# External users are interesting in using this API.
def _gather_state_dict(
    state_dict: Dict[str, Any],
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        if isinstance(value, ShardedTensor):
            # ShardedTensor does not seem to record the original device type.
            # So if the tensor is moved to CPU, we won't know the original type.
            # As a result, we have to rely on the user to tell us the correct one.
            output_tensor = _all_gather_sharded_tensor(value, pg, device)
            local_shard_device = (
                value.local_shards()[0].tensor.device
                if value.local_shards()
                else torch.device("cpu")
            )
            with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
                if output_tensor.device != local_shard_device:
                    value = output_tensor.to(local_shard_device)
                else:
                    value = output_tensor
        elif isinstance(value, DTensor):
            if value.device != value.device_mesh.device_type:
                value = value.to(value.device_mesh.device_type)
            # FSDP all_gather: [Shard(0)] -> [Replicate()]
            # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
            # 2D FSDP + TP all_gather:
            # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
            # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
            placements = [Replicate() for _ in value.placements]
            value = value.redistribute(
                device_mesh=value.device_mesh,
                placements=placements,
            )
            value = value.to_local()
        elif isinstance(value, dict):
            value = _gather_state_dict(value, pg, device)

        new_state_dict[key] = value
    return new_state_dict

def _get_remote_device_str(rank, device_type, num_devices_per_node):