pytorch/torch/distributed/tensor/_redistribute.py
Dzmitry Huba 5e58420dff LocalTensor (#164537)
A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks.  A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally.  When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards.  A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.

NB, this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
autograd needs multithreading to keep up!). (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs,
but this is currently a non-goal.)

In order to handle MPMD, we provide a helper decorator that allows you to
run a function with no side effects for each LocalTensor shard and combine
results back into LocalTensor or LocalIntNode.

Note: This PR converts all DTensor ops and some DTensor tests to illustrate
intended usage and ensure correctness. In a subsequent PR more tests will be
converted. During test conversion we aim to share as much test logic as possible
between multi-process / multi-threaded and local tensor tests.
We would like developers to be able to run both flavors of the tests.

Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang
2025-10-12 20:06:41 +00:00

571 lines
22 KiB
Python

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import dataclasses
import logging
from collections import defaultdict
from collections.abc import Sequence
from functools import cache
from typing import cast, NamedTuple, Optional
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed._functional_collectives import _are_we_tracing
from torch.distributed.tensor._dtensor_spec import (
DTensorSpec,
ShardOrder,
ShardOrderEntry,
TensorMeta,
)
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
from torch.utils._debug_mode import get_active_debug_mode
# Module logger; redistribute_local_tensor logs every per-mesh-dim step that is
# not short-circuited (i.e. one that may issue a collective) at DEBUG level.
logger = logging.getLogger(__name__)
class _TransformInfo(NamedTuple):
    """A single step of a redistribution plan, scoped to one mesh dimension.

    Fields:
        mesh_dim: index of the mesh dimension this step acts on.
        src_dst_placements: ``(source, destination)`` placement pair on
            ``mesh_dim``.
        logical_shape: logical tensor shape as seen on ``mesh_dim``; used so
            that gathering uneven shards produces the correct output shape.
    """

    mesh_dim: int
    src_dst_placements: tuple[Placement, Placement]
    logical_shape: list[int]
# TODO(zpcore): complete the core algorithm of redistributing from source
# placement to target placement considering device ordering
class DTensorRedistributePlanner:
    """
    Plans the sequence of collective calls that turn the local shard of a
    DTensor laid out by its current spec into the layout of a target spec.
    """

    @dataclasses.dataclass(frozen=True, slots=True)
    class DistState:
        """Immutable snapshot of a distribution state.

        Holds the per-mesh-dim placements together with the shard order
        (tensor dim -> ordered mesh dims). The hash is computed once in
        ``__post_init__`` and cached in ``_hash``.
        """

        placements: tuple[Placement, ...]
        tensor_dim_to_mesh_dim: ShardOrder
        _hash: Optional[int] = dataclasses.field(
            default=None, init=False, repr=False, compare=False
        )

        def __str__(self):
            return DTensorSpec.format_shard_order_str(
                self.placements, self.tensor_dim_to_mesh_dim
            )

        def __repr__(self):
            return self.__str__()

        def __post_init__(self):
            # The dataclass is frozen, so go through object.__setattr__ to
            # store the precomputed hash once all fields are populated.
            object.__setattr__(self, "_hash", self._compute_hash())

        def __hash__(self) -> int:
            cached = self._hash
            if cached is None:
                return self._compute_hash()
            return cached

        def _compute_hash(self) -> int:
            return hash((self.placements, self.tensor_dim_to_mesh_dim))

        def __eq__(self, other: object) -> bool:
            if not isinstance(other, DTensorRedistributePlanner.DistState):
                return False
            # Differing cached hashes rule out equality without a full compare.
            return self._hash == other._hash and (
                self.placements,
                self.tensor_dim_to_mesh_dim,
            ) == (
                other.placements,
                other.tensor_dim_to_mesh_dim,
            )

    @staticmethod
    def _dict_to_ShardOrder(x: dict[int, list[int]]) -> ShardOrder:
        """Build a ShardOrder from a ``{tensor_dim: [mesh_dim, ...]}`` mapping.

        Entries are ordered by tensor dim; tensor dims with an empty mesh-dim
        list are omitted.
        """
        entries = []
        for tensor_dim in sorted(x):
            mesh_dims = x[tensor_dim]
            if not mesh_dims:
                continue
            entries.append(
                ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
            )
        return tuple(entries)

    @staticmethod
    def _ShardOrder_to_dict(x: ShardOrder) -> dict[int, list[int]]:
        """Expand a ShardOrder into a ``defaultdict(list)`` keyed by tensor dim."""
        return defaultdict(
            list, {entry.tensor_dim: list(entry.mesh_dims) for entry in x}
        )

    @staticmethod
    def stringify_transform_infos(
        mesh: DeviceMesh,
        transform_infos: Sequence[_TransformInfo],
        src_placement: tuple[Placement, ...],
        src_shard_order: Optional[ShardOrder] = None,
    ) -> str:
        """
        Render the chain of DistState transitions implied by ``transform_infos``.

        Args:
            mesh: The DeviceMesh the redistribution runs on.
            transform_infos: Ordered steps, each changing the placement on one
                mesh dimension.
            src_placement: Placements before the first step.
            src_shard_order: (Optional) Shard order before the first step. When
                None, the default shard order for ``src_placement`` is used.

        Returns:
            Every visited state (initial state included) joined by '->'.
        """
        assert len(src_placement) == mesh.ndim
        if src_shard_order is None:
            src_shard_order = DTensorSpec.compute_default_shard_order(src_placement)

        planner = DTensorRedistributePlanner
        placements = list(src_placement)
        mesh_dims_by_tensor_dim = planner._ShardOrder_to_dict(src_shard_order)
        states = [planner.DistState(tuple(placements), src_shard_order)]

        for info in transform_infos:
            before, after = info.src_dst_placements
            if before.is_shard():
                unsharded_dim = before.dim  # type: ignore[attr-defined]
                assert (
                    unsharded_dim in mesh_dims_by_tensor_dim
                    and len(mesh_dims_by_tensor_dim[unsharded_dim]) > 0
                )
                # drop the last mesh dim recorded for this tensor dim
                mesh_dims_by_tensor_dim[unsharded_dim].pop()
            if after.is_shard():
                sharded_dim = after.dim  # type: ignore[attr-defined]
                # defaultdict creates the list on first use
                mesh_dims_by_tensor_dim[sharded_dim].append(info.mesh_dim)
            placements[info.mesh_dim] = after
            states.append(
                planner.DistState(
                    tuple(placements),
                    planner._dict_to_ShardOrder(mesh_dims_by_tensor_dim),
                )
            )

        return "->".join(str(state) for state in states)
def _gen_transform_infos_non_cached(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> list[_TransformInfo]:
    """
    Generate the transform infos from the source placements to the target placements.

    Transforming from the source to the target placements may take multiple
    steps, e.g. Si -> Sj might be decomposed into Si -> R -> Sj.

    This also detects mis-aligned/nested shardings between the src/dst placements.
    E.g. suppose the redistribution to perform is
    (Shard(0), Shard(0)) -> (Replicate(), Shard(0)). Here Shard(0) -> Shard(0) on
    mesh dimension 1 actually needs resharding: in the source it is a nested
    sharding of tensor dimension 0 (which is already sharded on mesh dim 0),
    whereas in the target it is the first sharding of tensor dimension 0.

    Args:
        src_spec: spec describing the current placements and global shape.
        dst_spec: spec describing the desired placements.

    Returns:
        Ordered list of per-mesh-dim steps. Each step carries the logical shape
        seen on its mesh dim so uneven shards are gathered to the right size.

    Raises:
        AssertionError: (via ``assert``) if the current rank has no coordinate
            in ``src_spec.device_mesh``.
    """
    transform_infos: list[_TransformInfo] = []

    device_mesh = src_spec.device_mesh
    my_coordinate = device_mesh.get_coordinate()
    assert my_coordinate is not None

    # The logical shape records the logical tensor shape on each mesh dimension;
    # it ensures uneven sharding produces the correct output shape.
    initial_logical_shape = list(src_spec.shape)

    if device_mesh.ndim == 1:
        # if device_mesh is 1D, redistribute is a simple direct transformation
        transform_infos.append(
            _TransformInfo(
                mesh_dim=0,
                src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]),
                logical_shape=initial_logical_shape,
            )
        )
        return transform_infos

    # Handle multi-dim device mesh placement redistribution.
    # First, build the logical shape for each mesh dim so that uneven shards on
    # each mesh dim can be all-gathered correctly (with dynamic padding).
    # Entry k+1 is derived from entry k: a Shard on mesh dim k shrinks the
    # sharded tensor dim to this rank's local shard size, any other placement
    # leaves the shape unchanged. The last mesh dim feeds no later entry, so
    # only the first ndim - 1 placements need to be visited.
    mesh_dims_to_logical_shape = [initial_logical_shape]
    for mesh_dim, src in enumerate(src_spec.placements[: device_mesh.ndim - 1]):
        current_logical_shape = mesh_dims_to_logical_shape[mesh_dim]
        if isinstance(src, Shard):
            mesh_dim_size = device_mesh.size(mesh_dim=mesh_dim)
            local_shard_size, _ = src._local_shard_size_and_offset(
                current_logical_shape[src.dim],
                mesh_dim_size,
                my_coordinate[mesh_dim],
            )
            new_logical_shape = list(current_logical_shape)
            new_logical_shape[src.dim] = local_shard_size
            mesh_dims_to_logical_shape.append(new_logical_shape)
        else:
            mesh_dims_to_logical_shape.append(current_logical_shape)

    # Next, derive the transform infos from src to dst placements using a
    # greedy search with step-by-step state transformations.
    current_placements = list(src_spec.placements)
    target_placements = list(dst_spec.placements)

    if src_spec.num_shards > 1:
        # If src_spec has sharding, it may be misaligned with dst_spec; a common
        # case is nested sharding (e.g. (S(0), S(0)) -> (R, S(0))). To handle it,
        # first traverse from the inner to the outer placement, detecting
        # misaligned shardings and replicating nested shardings first.
        for mesh_dim in reversed(range(len(current_placements))):
            current = current_placements[mesh_dim]
            target = target_placements[mesh_dim]
            # A non-Shard target can be redistributed directly, because we are
            # traversing from inner to outer placements here.
            if isinstance(target, Shard):
                # For a Shard target, compare which mesh dims BEFORE this one
                # shard the same tensor dim in the current vs the target layout.
                # The prefix [:mesh_dim] has not been modified by this loop yet.
                shard_dim = target.dim
                current_mesh_sharding = [
                    i
                    for i, p in enumerate(current_placements[:mesh_dim])
                    if p.is_shard(shard_dim)
                ]
                target_mesh_sharding = [
                    i
                    for i, p in enumerate(target_placements[:mesh_dim])
                    if p.is_shard(shard_dim)
                ]
                if current_mesh_sharding != target_mesh_sharding:
                    # Misaligned sharding on the tensor dim before this mesh dim:
                    # replicate on this mesh dim first to clear the nested sharding.
                    target = Replicate()
            if current != target:
                transform_infos.append(
                    _TransformInfo(
                        mesh_dim=mesh_dim,
                        src_dst_placements=(current, target),
                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                    )
                )
                current_placements[mesh_dim] = target

    # Finally, always traverse from outer to inner placements to collect the
    # remaining transform infos (e.g. a replication that undid a nested sharding
    # may need to be followed by resharding to Shard again).
    for mesh_dim, (current, target) in enumerate(
        zip(current_placements, target_placements)
    ):
        if current != target:
            transform_infos.append(
                _TransformInfo(
                    mesh_dim=mesh_dim,
                    src_dst_placements=(current, target),
                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                )
            )
            current_placements[mesh_dim] = target

    return transform_infos
@cache
def _gen_transform_infos(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> list[_TransformInfo]:
    """Memoized front-end for ``_gen_transform_infos_non_cached``.

    Plans are cached per ``(src_spec, dst_spec)`` pair in an unbounded
    ``functools.cache``. A cache hit returns the very same list object, so
    callers should treat the result as read-only.
    """
    plan = _gen_transform_infos_non_cached(src_spec, dst_spec)
    return plan
def redistribute_local_tensor(
    local_tensor: torch.Tensor,
    current_spec: DTensorSpec,
    target_spec: DTensorSpec,
    *,
    async_op: bool = False,
    is_backward: bool = False,
) -> torch.Tensor:
    """
    Move this rank's local shard from the ``current_spec`` layout to the
    ``target_spec`` layout, issuing whatever collectives each mesh dim needs.

    Args:
        local_tensor: this rank's local shard laid out per ``current_spec``.
        current_spec: spec describing the current layout of ``local_tensor``.
        target_spec: desired layout; must use the same device mesh.
        async_op: when False, any ``funcol.AsyncCollectiveTensor`` produced by a
            step is waited on before moving to the next step.
        is_backward: selects backward-pass handling for transitions into
            ``Partial``: replicate -> partial keeps the replicated value, and
            shard -> partial gathers the shard into a replicated value.

    Returns:
        The redistributed local tensor. Ranks that are not part of the mesh get
        ``local_tensor`` back untouched.

    Raises:
        NotImplementedError: if the two specs use different meshes.
        RuntimeError: for placement transitions that are not supported.
    """
    if current_spec.mesh != target_spec.mesh:
        # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    device_mesh = current_spec.mesh
    my_coordinate = device_mesh.get_coordinate()
    if my_coordinate is None:
        # This rank is not in the mesh: skip redistribution and hand back the
        # input, which should be an empty tensor.
        return local_tensor

    if _are_we_tracing():
        # the plan cache is bypassed while tracing
        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
    else:
        transform_infos = _gen_transform_infos(current_spec, target_spec)

    debug_mode = get_active_debug_mode()
    redistribute_context = (
        contextlib.nullcontext()
        if debug_mode is None
        else debug_mode.record_redistribute_calls(  # type: ignore[union-attr]
            local_tensor,
            current_spec.placements,
            target_spec.placements,
            DTensorRedistributePlanner.stringify_transform_infos(
                device_mesh,
                transform_infos,
                current_spec.placements,
                current_spec.shard_order,
            ),
        )
    )

    result = local_tensor
    with redistribute_context:
        for step in transform_infos:
            mesh_dim = step.mesh_dim
            src, dst = step.src_dst_placements
            num_chunks = device_mesh.size(mesh_dim=mesh_dim)
            # Nothing to communicate if the placement is unchanged or if this
            # mesh dim holds a single rank: keep the current tensor.
            if src == dst or num_chunks == 1:
                continue

            logger.debug(
                "redistribute from %s to %s on mesh dim %s", src, dst, mesh_dim
            )
            result = _redistribute_single_step(
                result,
                device_mesh,
                mesh_dim,
                my_coordinate,
                src,
                dst,
                step.logical_shape,
                is_backward,
            )
            if not async_op and isinstance(result, funcol.AsyncCollectiveTensor):
                result = result.wait()

    return result


def _redistribute_single_step(
    tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    mesh_dim: int,
    my_coordinate: Sequence[int],
    src: Placement,
    dst: Placement,
    logical_shape: list[int],
    is_backward: bool,
) -> torch.Tensor:
    """Apply one ``src -> dst`` placement change on ``mesh_dim`` to ``tensor``.

    Returns ``tensor`` itself for transitions that need no data movement.
    """
    if dst.is_replicate():
        # Case 1: target is Replicate
        if src.is_partial():
            return cast(Partial, src)._reduce_value(tensor, device_mesh, mesh_dim)
        if src.is_shard():
            return cast(Shard, src)._to_replicate_tensor(
                tensor, device_mesh, mesh_dim, logical_shape
            )
        raise RuntimeError(f"redistribute from {src} to {dst} not supported yet")

    if dst.is_shard():
        # Case 2: target is Shard
        dst_shard = cast(Shard, dst)
        if src.is_partial():
            return cast(Partial, src)._reduce_shard_value(
                tensor, device_mesh, mesh_dim, dst_shard
            )
        if src.is_replicate():
            # split the tensor and keep this rank's (cloned) local shard
            return dst_shard._replicate_to_shard(
                tensor, device_mesh, mesh_dim, my_coordinate[mesh_dim]
            )
        assert src.is_shard(), f"Current placement should be shard but found {src}"
        src_shard = cast(Shard, src)
        if src_shard.dim == dst_shard.dim:
            return tensor
        return src_shard._to_new_shard_dim(
            tensor, device_mesh, mesh_dim, logical_shape, dst_shard.dim
        )

    if dst.is_partial():
        if src.is_replicate():
            # In backward, keep the gradient replicated rather than converting it
            # back to partial: the layout would be logically equivalent, but it
            # would force a more expensive reduce later on.
            if is_backward:
                return tensor
            return cast(Partial, dst)._partition_value(tensor, device_mesh, mesh_dim)
        if src.is_shard():
            if not is_backward:
                raise RuntimeError(
                    f"redistribute from {src} to {dst} not supported yet"
                )
            # backward shard -> partial only needs the shard made replicate
            return cast(Shard, src)._to_replicate_tensor(
                tensor, device_mesh, mesh_dim, logical_shape
            )
        # partial -> partial is a no-op (not expected to be reached)
        return tensor

    return tensor
class Redistribute(torch.autograd.Function):
    """Autograd function that redistributes a DTensor onto new placements.

    Forward moves ``input`` onto ``placements`` (casting to ``forward_dtype``
    first when it differs from the local dtype). Backward moves the gradient
    back to the input's placements, with ``Partial`` placements reported as
    ``Replicate``, and casts it back to the input's original dtype.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: tuple[Placement, ...],
        async_op: bool = False,
        forward_dtype: Optional[torch.dtype] = None,
        backward_dtype: Optional[torch.dtype] = None,
    ):
        source_local = input._local_tensor
        ctx.async_op = async_op
        ctx.backward_dtype = backward_dtype
        ctx.original_dtype = source_local.dtype

        if forward_dtype is None or forward_dtype == source_local.dtype:
            local_tensor = source_local
            current_spec = input._spec
        else:
            # cast first, and describe the cast tensor with a matching spec
            local_tensor = source_local.to(dtype=forward_dtype)
            current_spec = DTensorSpec(
                mesh=device_mesh,
                placements=input._spec.placements,
                tensor_meta=TensorMeta(
                    shape=input.shape,
                    stride=input.stride(),
                    dtype=forward_dtype,
                ),
            )
        ctx.current_spec = current_spec

        if current_spec.placements == placements:
            # identical placements: reuse the local tensor and spec as-is
            target_spec = current_spec
            output = local_tensor
        else:
            target_spec = DTensorSpec(
                device_mesh, placements, tensor_meta=current_spec.tensor_meta
            )
            output = redistribute_local_tensor(
                local_tensor, current_spec, target_spec, async_op=async_op
            )

        return dtensor.DTensor(
            output,
            target_spec,
            requires_grad=input.requires_grad,
        )

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_spec = ctx.current_spec
        # without an explicit backward dtype, use the input's original dtype
        backward_dtype = ctx.backward_dtype or ctx.original_dtype
        grad_local = grad_output._local_tensor

        if backward_dtype == grad_local.dtype:
            local_tensor = grad_local
            current_spec = grad_output._spec
        else:
            local_tensor = grad_local.to(dtype=backward_dtype)
            current_spec = DTensorSpec(
                mesh=grad_output._spec.device_mesh,
                placements=grad_output._spec.placements,
                tensor_meta=TensorMeta(
                    shape=grad_output.shape,
                    stride=grad_output.stride(),
                    dtype=backward_dtype,
                ),
            )
            # the destination spec must describe the same (cast) dtype
            previous_spec = DTensorSpec(
                mesh=previous_spec.device_mesh,
                placements=previous_spec.placements,
                tensor_meta=current_spec.tensor_meta,
            )

        output = redistribute_local_tensor(
            local_tensor,
            current_spec,
            previous_spec,
            async_op=ctx.async_op,
            is_backward=True,
        )
        if output.dtype != ctx.original_dtype:
            output = output.to(ctx.original_dtype)

        # Partial placements on the forward input are reported as Replicate.
        normalized_placements: tuple[Placement, ...] = tuple(
            Replicate() if placement.is_partial() else placement
            for placement in previous_spec.placements
        )
        spec = DTensorSpec(
            previous_spec.device_mesh,
            normalized_placements,
            tensor_meta=TensorMeta(
                shape=grad_output.shape,
                stride=grad_output.stride(),
                dtype=output.dtype,
            ),
        )
        output_dtensor = dtensor.DTensor(
            output,
            spec,
            requires_grad=grad_output.requires_grad,
        )
        # one gradient slot per forward input; only `input` receives one
        return (output_dtensor, None, None, None, None, None)