# mypy: allow-untyped-defs
import copy
from typing import Any, cast, Optional

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)


__all__ = ["DTensorExtensions"]


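# A worked example of the box geometry below (illustrative, not part of the
# original source):
#
#   >>> # DTensor of global size (8, 4), placements (Shard(0),), 1D mesh of 4
#   >>> _get_box(dt)
#   (torch.Size([2, 0]), torch.Size([2, 4]))
#   >>> _get_box_for(dt, 3)  # rank 3's shard starts at row 2 * 3 = 6
#   (torch.Size([6, 0]), torch.Size([2, 4]))
#
# Like the integer division below, this assumes the sharded dim divides evenly
# by the number of chunks.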
def _get_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(mesh_dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DTensor, idx: int) -> tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    coord = device_mesh.get_coordinate()
    assert coord is not None
    return _get_box_for(tensor, coord[0])


def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky: we have to produce a ShardedTensor that has
    # full coverage and yet has only one valid shard for the current rank.

    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1
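    # Only the first shard is attributed to this rank; every other shard is
    # pinned to a "scapegoat" rank, so the metadata covers the full tensor
    # while this rank still owns exactly one valid shard.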
    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )


def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    return mesh.get_group()


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily moves optimizer state to CPU, so if the
    original ShardingSpec produces CUDA metadata, ShardedTensor construction fails.
    """
    if not isinstance(spec, ChunkShardingSpec):
        return spec

    # Check whether any placement for this rank disagrees with the tensor's device.
    rewrite = False
    for p in spec.placements:
        p = cast(_remote_device, p)
        if p.rank() == rank and p.device() != tensor.device:
            rewrite = True
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec


def _chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    if type(tensor) is ShardedTensor:
        assert len(tensor.local_shards()) == 1

        inner_param = tensor.local_tensor()
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

        outer_local_shard = tensor.local_shards()[0]
        shards: list[Shard] = [
            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
        ]
        st_meta = copy.deepcopy(tensor.metadata())
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=tensor._process_group,
            init_rrefs=False,
        )
        return st_outer
    elif type(tensor) is DTensor:
        device_mesh = tensor.device_mesh
        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

        inner_param = tensor._local_tensor

        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            torch.accelerator.device_count(),
            pg,
        )

        dt_pg = _get_dt_pg(tensor)
        # We do this differently here: we create an ST with no local shards and then patch it.
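        # The result is a two-level nesting: the outer ShardedTensor (over the
        # DTensor's process group) holds a single shard whose tensor is itself
        # the inner chunk ShardedTensor (over ``pg``). _pre_load_state_dict
        # below unwraps this nesting again.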
        shards = [
            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
        ]

        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=dt_pg,
            init_rrefs=False,
        )

        return st_outer
    else:
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )


def _chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor to chunks along the first dimension.

    The local rank gets its corresponding chunk as the local tensor to create a DTensor.
    """
    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    if root_mesh is None:
        raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
    if root_mesh.ndim < 2:
        raise RuntimeError(
            f"Found parent device_mesh of ndim={root_mesh.ndim}, "
            "but meshes must be at least 2D."
        )

    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.detach().clone()

    # When a layer is not involved in TP, its tensor will not be a DTensor.
    # e.g. when a layer is not specified in the parallelize_plan, TP has no effect on it.
    # e.g. when you apply PairwiseParallel to a 3-layer model, TP has no effect on the third layer.
    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
        # For plain tensors, the result is replicated across the TP dimension
        # and sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), Replicate()).
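        # An illustrative example (not from the original source): on a 2D root
        # mesh of shape (dp=2, tp=2), the tensor is first wrapped as fully
        # replicated, then redistributed to [Shard(0), Replicate()], so each
        # dp rank holds one row-chunk, replicated across its tp peers.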
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements[0] = DShard(0)  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )

    else:
        tp_placements = tensor.placements
        tp_placement = tp_placements[0]

        tensor = tensor.to_local()

        # For DTensors, the tensor is sharded across the TP dimension first and
        # then sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), tp_placement).
        # For higher dimensional meshes, it is replicated across the other dimensions. For example,
        # with HSDP the shard placements for the tensor are (Replicate(), Shard(0), tp_placement).
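        # An illustrative example (not from the original source): for FSDP + TP
        # on a (dp, tp) mesh with tp_placement = Shard(0), the local tensor is
        # wrapped with [Replicate(), Shard(0)] and then redistributed to
        # [Shard(0), Shard(0)].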
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        replicate_placements[-1] = tp_placement  # type: ignore[call-overload]
        shard_placements = [Replicate() for i in range(root_mesh.ndim)]  # type: ignore[misc]
        shard_placements[-2] = DShard(0)  # type: ignore[call-overload]
        shard_placements[-1] = tp_placement  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )


def _pre_load_state_dict(
    tensor: torch.Tensor,
) -> tuple[torch.Tensor, list[Shard]]:
    shards = cast(ShardedTensor, tensor).local_shards()
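    # _chunk_tensor above produces a two-level ShardedTensor nesting; unwrap it
    # here so that the caller sees the inner chunk shards directly.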
    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
        inner_tensor = shards[0].tensor
        shards = inner_tensor.local_shards()  # pyre-ignore[16]
        tensor = inner_tensor

    return (tensor, shards if len(shards) > 0 else [])


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """All gather a DTensor in its FSDP dimension and return the local tensor."""
    assert parent_mesh == tensor.device_mesh

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
    # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
    for i in range(0, len(placements) - 1):
        placements[i] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()


class DTensorExtensions(FSDPExtensions):
    """
    DTensorExtensions is the TensorFlattener extension needed for 2D FSDP + TP.

    This is the implementation of FSDPExtensions defined in
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py
    """
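    # An illustrative construction (not from the original source): FSDP wires
    # this up internally, but conceptually it is built with a device module
    # such as ``DTensorExtensions(device_handle=torch.cuda)``, whose
    # ``current_stream``/``stream`` APIs are used by post_unflatten_transform.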
    def __init__(self, device_handle) -> None:
        super().__init__()
        self.compute_stream = None
        self.device_handle = device_handle
        # We have to disable dynamo this way rather than via the decorator, as
        # the decorator form would trigger a build failure with torch deploy...
        self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform)  # type: ignore[method-assign]

    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> tuple[torch.Tensor, Optional[Any]]:
        return _flatten_tensor(tensor)

    def post_unflatten_transform(
        self, tensor: torch.Tensor, param_extension: Any
    ) -> torch.Tensor:
        stream = self.compute_stream or self.device_handle.current_stream()
        with self.device_handle.stream(stream):
            # At runtime we put the unflatten call on the compute stream, since
            # the unflattened tensor may participate in fwd/bwd computation that
            # we need to synchronize with properly.
            # TODO: this is a short-term fix; we should make get_unflat_views
            # happen directly on the compute stream.
            result = _unflatten_tensor(
                tensor,
                param_extension,
                device_handle=self.device_handle,
                compute_stream=self.compute_stream,
            )
            _set_fsdp_flattened(result)
            return result

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        return _chunk_dtensor(tensor, rank, device_mesh)

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> tuple[torch.Tensor, list[Shard]]:
        return _pre_load_state_dict(tensor)

    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        return _all_gather_dtensor(tensor, parent_mesh)