Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
This PR gets rid of the dim_groups attribute from DeviceMesh. The main motivation is that c10d, not DeviceMesh, should store the process groups created during mesh construction; DeviceMesh should just handle ranks correctly. This could also make DTensor picklable (torch.save/load could become possible), which I will try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
357 lines
11 KiB
Python
import copy
import warnings
from typing import cast, List, NamedTuple, Optional, Tuple

import torch
import torch.distributed as dist

import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d

from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)

from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec

from torch.distributed._tensor import (
    DeviceMesh,
    DTensor as DistributedTensor,
    Shard as DShard,
)
from torch.distributed._tensor.placement_types import Placement

from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor

from torch.distributed.remote_device import _remote_device

__all__ = ["enable_2d_with_fsdp"]


def enable_2d_with_fsdp() -> bool:
    """
    The API registers the extension which is needed for Tensor Parallelism (TP)
    to work with FullyShardedDataParallel (FSDP). We first parallelize parameters
    within one module or sub_modules based on a parallelize_plan and will let FSDP
    reshard the local tensor of the distributed parameter, which is essentially a DTensor.

    Return:
        A `bool` indicating whether the extension registration succeeded or not.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.enable_2d_with_fsdp")

    try:
        from torch.distributed.fsdp._fsdp_extensions import (
            _set_fsdp_extensions,
            FSDPExtensions,
        )

        class DTensorExtensions(FSDPExtensions):
            def pre_flatten_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
                return _flatten_tensor(tensor)

            def post_unflatten_transform(
                self, tensor: torch.Tensor, param_extension: _STShardingInfo
            ) -> torch.Tensor:
                return _unflatten_tensor(tensor, param_extension)

            def chunk_tensor(
                self,
                tensor: torch.Tensor,
                rank: int,
                world_size: int,
                num_devices_per_node: int,
                pg: dist.ProcessGroup,
            ) -> torch.Tensor:
                return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

            def pre_load_state_dict_transform(
                self,
                tensor: torch.Tensor,
            ) -> Tuple[torch.Tensor, List[Shard]]:
                return _pre_load_state_dict(tensor)

        _set_fsdp_extensions(DTensorExtensions())
        return True

    except BaseException as e:
        warnings.warn(
            "PyTorch doesn't have the TensorFlattener extension point available, "
            "so 2D parallelism won't work with FSDP. "
            f"exception: {e}"
        )
        return False


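# Editor's note: the example below is an illustrative addition, not part of the
# original file. It sketches how `enable_2d_with_fsdp` was meant to be called when
# composing tensor parallelism with FSDP. The toy model, the 2x2 mesh, and the
# `parallelize_module` / `PairwiseParallel` usage (including `tp_mesh_dim`) are
# assumptions based on the TP API of the same era; run under torchrun with 4 GPUs.
def _example_enable_2d_usage() -> None:
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()
    assert enable_2d_with_fsdp()  # register the DTensor extension before FSDP wrapping
    mesh_2d = DeviceMesh("cuda", torch.arange(dist.get_world_size()).reshape(2, 2))
    model = parallelize_module(model, mesh_2d, PairwiseParallel(), tp_mesh_dim=1)
    model = FSDP(model)  # FSDP reshards the local tensors of the TP-produced DTensors

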
class _STShardingInfo(NamedTuple):
    """:class:`ShardedTensor` sharding information."""

    sharding_spec: Optional[shard_spec.ShardingSpec]
    global_size: Optional[torch.Size]
    process_group: Optional[c10d.ProcessGroup]
    device_mesh: Optional[DeviceMesh]
    placements: Optional[List[Placement]]


def _get_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DistributedTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    coord = device_mesh.get_coordinate()
    assert coord is not None
    return _get_box_for(tensor, coord[0])


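# Editor's note: the example below is an illustrative addition, not part of the
# original file. It reproduces the box arithmetic above with plain Python ints
# instead of a real DTensor: for a global tensor of size (16, 5) sharded on dim 0
# across a 1-D mesh of 4 ranks, `_get_box` yields per-chunk offsets (4, 0) and a
# local size of (4, 5), and `_get_box_for(dt, i)` scales the offsets by the mesh
# coordinate, so coordinate 2 owns the box (offsets=(8, 0), sizes=(4, 5)).
def _example_box_arithmetic() -> None:
    global_size, num_chunks, shard_dim = (16, 5), 4, 0
    chunk_size = global_size[shard_dim] // num_chunks  # 4, as in _get_box
    for coord in range(num_chunks):
        offsets = [0, 0]
        offsets[shard_dim] = chunk_size * coord  # as in _get_box_for
        assert tuple(offsets) == (4 * coord, 0)

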
def _create_shard_md_from_dt(dt: DistributedTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DistributedTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky: we have to produce a ShardedTensor that has full coverage
    # and yet has only one valid shard for the current rank.

    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1

    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )


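# Editor's note: the example below is an illustrative addition, not part of the
# original file. It spells out the "scapegoat" placement scheme above: shard 0 is
# attributed to the current rank and every other shard to a scapegoat rank, so the
# metadata covers the whole tensor while at most one shard is attributed to this
# rank. The rank, shard count, and device string are assumed values.
def _example_scapegoat_placements() -> None:
    my_rank, shard_count, device = 2, 4, "cuda:0"
    scapegoat_rank = 0 if my_rank > 0 else 1
    placements = [
        f"rank:{scapegoat_rank if i > 0 else my_rank}/{device}"
        for i in range(shard_count)
    ]
    assert placements == [
        "rank:2/cuda:0",
        "rank:0/cuda:0",
        "rank:0/cuda:0",
        "rank:0/cuda:0",
    ]

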
def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    dim_groups = mesh.get_dim_groups()
    assert isinstance(dim_groups, list)
    return dim_groups[0]


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily ships optimizer state to CPU, so if the original
    ShardingSpec produces CUDA metadata, ST construction bombs.
    """
    if not isinstance(spec, ChunkShardingSpec):
        return spec

    # let's see if we need to rewrite the spec at all
    rewrite = False
    for p in spec.placements:
        p = cast(_remote_device, p)
        if p.rank() == rank and p.device() != tensor.device:
            rewrite = True
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec


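# Editor's note: the example below is an illustrative addition, not part of the
# original file. It shows the rewrite in isolation; the spec values and the CPU
# tensor are assumptions standing in for optimizer state that FSDP moved to CPU.
def _example_rewrite_spec() -> None:
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
    cpu_tensor = torch.zeros(4, 4)  # pretend this is optimizer state offloaded to CPU
    new_spec = cast(ChunkShardingSpec, _rewrite_spec_if_needed(spec, cpu_tensor, rank=1))
    # only the placement owned by rank 1 is rewritten; the original spec is untouched
    assert cast(_remote_device, new_spec.placements[1]).device() == cpu_tensor.device
    assert new_spec is not spec

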
def _flatten_tensor(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]:
    if type(tensor) is ShardedTensor:
        return tensor.local_tensor(), _STShardingInfo(
            tensor.sharding_spec(),
            tensor.size(),
            tensor._process_group,
            None,
            None,
        )
    elif type(tensor) is DistributedTensor:
        tensor._local_tensor.requires_grad_()
        return tensor._local_tensor, _STShardingInfo(
            None,
            None,
            None,
            tensor.device_mesh,
            list(tensor.placements),
        )
    return tensor, None


def _unflatten_tensor(
    tensor: torch.Tensor, sharding_info: _STShardingInfo
) -> torch.Tensor:
    result: torch.Tensor

    if sharding_info.sharding_spec is not None:
        assert sharding_info.global_size is not None
        result = ShardedTensor._init_from_local_tensor(
            tensor,
            _rewrite_spec_if_needed(
                sharding_info.sharding_spec,
                tensor,
                dist.get_rank(sharding_info.process_group),
            ),
            sharding_info.global_size,
            process_group=cast(dist.ProcessGroup, sharding_info.process_group),
        )
    else:
        result = DistributedTensor.from_local(
            tensor,
            device_mesh=sharding_info.device_mesh,
            placements=sharding_info.placements,
            run_check=False,
        )

    _set_fsdp_flattened(result)
    return result


def _chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    if type(tensor) is ShardedTensor:
        assert len(tensor.local_shards()) == 1

        inner_param = tensor.local_tensor()
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

        outer_local_shard = tensor.local_shards()[0]
        shards: List[Shard] = [
            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
        ]
        st_meta = copy.deepcopy(tensor.metadata())
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=tensor._process_group,
            init_rrefs=False,
        )
        return st_outer
    elif type(tensor) is DistributedTensor:
        device_mesh = tensor.device_mesh
        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

        inner_param = tensor._local_tensor

        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            torch.cuda.device_count(),
            pg,
        )

        dt_pg = _get_dt_pg(tensor)
        # We do this differently here: we create an ST with no local shards and then patch it
        shards = [
            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
        ]

        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=dt_pg,
            init_rrefs=False,
        )

        return st_outer
    else:
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )


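# Editor's note: the comment below is an illustrative addition, not part of the
# original file. Reading the DTensor branch of `_chunk_tensor` above: the result is
# a two-level ShardedTensor. The inner ST chunks this rank's TP-local tensor across
# the FSDP process group `pg`; the outer ST is built over the TP group `dt_pg`, its
# single local shard (whose offsets come from `_create_shard_md_from_dt`) wraps
# that inner ST, and `_create_sharded_tensor_md_from_dt` fabricates metadata that
# covers the full global tensor.
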
def _pre_load_state_dict(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    shards = cast(ShardedTensor, tensor).local_shards()
    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
        inner_tensor = shards[0].tensor
        shards = inner_tensor.local_shards()  # pyre-ignore[16]
        tensor = inner_tensor

    return (tensor, shards if len(shards) > 0 else [])