Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Moving DTensor into the public namespace, to formally add a documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without much content yet (more coming in the next PRs)
* a shim script that redirects old `torch.distributed._tensor` import paths to the new module, to preserve BC for users still using the old path

BC preservation is evidenced by the fact that all DTensor tests still pass without changing their public imports, so the change is safe to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
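A minimal sketch of what the BC shim means for user code, assuming the old `torch.distributed._tensor` module simply re-exports the names from the new `torch.distributed.tensor` package (the alias `LegacyDTensor` below is illustrative, not part of this PR):

# Old-path import keeps working through the shim described above.
from torch.distributed._tensor import DTensor as LegacyDTensor  # deprecated path, kept for BC
from torch.distributed.tensor import DTensor                    # new public path

print(LegacyDTensor is DTensor)  # expected True if the shim is a plain re-export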
138 lines
4.5 KiB
Python
# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard

def _get_remote_device_str(rank, device_type, num_devices_per_node):
    # Build a "rank:<rank>/<device>" remote-device string. CPU has no device index;
    # HPU uses the currently selected device; other accelerators use rank modulo
    # the number of devices per node.
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"

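# --- Hypothetical usage sketch (not part of the original module) -------------
# Illustrates the remote-device strings _get_remote_device_str produces; the
# ranks and device counts below are made-up values.
def _example_remote_device_strs():
    return [
        _get_remote_device_str(3, "cpu", 8),    # "rank:3/cpu"
        _get_remote_device_str(3, "cuda", 8),   # "rank:3/cuda:3"
        _get_remote_device_str(11, "cuda", 8),  # "rank:11/cuda:3"
    ]
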
def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank will get its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )

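# --- Hypothetical usage sketch (not part of the original module) -------------
# Shows how a rank might wrap a full parameter into a chunk-sharded ShardedTensor.
# Assumes dist.init_process_group() has already been called; the default group is
# fetched via a private c10d helper purely for illustration.
def _example_create_chunk_sharded_tensor():
    pg = distributed_c10d._get_default_group()
    full_param = torch.arange(64.0).reshape(16, 4)  # identical full tensor on every rank
    return _create_chunk_sharded_tensor(
        full_param,
        rank=dist.get_rank(pg),
        world_size=dist.get_world_size(pg),
        num_devices_per_node=torch.cuda.device_count() or 1,
        pg=pg,
    )
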
def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank will get its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )

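# --- Hypothetical usage sketch (not part of the original module) -------------
# The FSDP-style 1-D case from the comments above: the full tensor is treated as
# replicated, then redistributed so each rank keeps one dim-0 shard as a DTensor.
# Assumes an initialized default group and one "cuda" device per rank.
def _example_create_chunk_dtensor():
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))  # 1-D mesh -> [Shard(0)]
    full_param = torch.arange(64.0, device="cuda").reshape(16, 4)  # identical on every rank
    return _create_chunk_dtensor(full_param, rank=dist.get_rank(), device_mesh=mesh)
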
def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert (
        root_mesh == tensor.device_mesh
    ), "The device mesh of a tensor should be a root mesh."

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
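# --- Hypothetical usage sketch (not part of the original module) -------------
# Round-trips a tensor through _create_chunk_dtensor and _all_gather_dtensor:
# each rank shards the full tensor, then all-gathers it back, so every rank
# should recover the original values. Same setup assumptions as the sketch above.
def _example_all_gather_dtensor():
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    full_param = torch.arange(64.0, device="cuda").reshape(16, 4)  # identical on every rank
    dtensor = _create_chunk_dtensor(full_param, rank=dist.get_rank(), device_mesh=mesh)
    gathered = _all_gather_dtensor(dtensor, root_mesh=mesh)
    return torch.equal(gathered, full_param)  # expected True on every rank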