pytorch/torch/distributed/fsdp/_shard_utils.py
Wanchao Liang 2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor into the public namespace, so that we can formally add a
documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page, without much content yet (more will be added in the
  next PRs)
* To preserve BC for users still importing from `torch.distributed._tensor`,
  I added a shim script that redirects the old paths to the new module (see
  the sketch after this list)
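
To make the redirect concrete, here is a minimal sketch of what such a shim
could look like (the file path, re-exported names, and warning text are
illustrative assumptions, not the shim as actually landed):

    # hypothetical torch/distributed/_tensor/__init__.py
    import warnings

    # Re-export the public API from its new home so old import paths keep working.
    from torch.distributed.tensor import (  # noqa: F401
        DeviceMesh,
        DTensor,
        Replicate,
        Shard,
    )

    warnings.warn(
        "torch.distributed._tensor is deprecated; import from "
        "torch.distributed.tensor instead.",
        FutureWarning,
        stacklevel=2,
    )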

The BC preservation is evidenced by the fact that all DTensor tests are still
passing without changing their public imports, so it's safe to land the
changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00


# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    """Return a remote-device placement string such as 'rank:0/cuda:0'."""
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
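    # Illustrative worked example (comment only, not executed): for a tensor of
    # size [10, 4] and world_size=4, tensor.chunk(4, dim=0) returns chunks of
    # sizes [3, 4], [3, 4], [3, 4], [1, 4], so chunk_offsets becomes
    # [[0, 0], [3, 0], [6, 0], [9, 0]], which matches the
    # math.ceil(10 / 4) * rank offset used for the local shard above.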
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All gather a DTensor in its sharded dimension and return the local tensor.
    """
    assert (
        root_mesh == tensor.device_mesh
    ), "The device mesh of a tensor should be a root mesh."

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
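

# Illustrative usage sketch (not part of the original module). Assumptions:
# torch.distributed.init_process_group() has been called, each rank has selected
# its accelerator device, and `device_mesh` is a DeviceMesh spanning those ranks
# (e.g. built with torch.distributed.device_mesh.init_device_mesh). The helper
# name below is hypothetical and only shows how the functions above compose.
def _demo_chunk_and_gather(device_mesh: DeviceMesh) -> None:
    # Every rank starts from the same full tensor, matching FSDP's full-parameter input.
    full = torch.arange(
        16, dtype=torch.float32, device=device_mesh.device_type
    ).reshape(8, 2)
    # Shard along dim 0 into a DTensor ([Shard(0)] for FSDP, [Replicate(), Shard(0)] for HSDP).
    dtensor = _create_chunk_dtensor(full, dist.get_rank(), device_mesh)
    # All-gather back to a replicated layout; every rank recovers the full tensor.
    gathered = _all_gather_dtensor(dtensor, device_mesh)
    assert torch.equal(gathered, full)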