pytorch/torch/distributed/tensor/_utils.py
Scott Wolchok 6a5a436624 DTensor: C++ compute_global_tensor_info (#162990)
compute_global_tensor_info is on the hot path for DTensor.{from,to}_local. More incremental progress toward C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162990
Approved by: https://github.com/ezyang
2025-10-30 15:10:54 +00:00

412 lines
17 KiB
Python

from collections import defaultdict
from collections.abc import Sequence
from typing import cast, Optional
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch._prims_common import ShapeType
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import (
_StridedShard,
Partial,
Placement,
Replicate,
Shard,
)
from torch.utils._typing_utils import not_none
def _explicit_order_placements(
mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:
"""
Replace Strided Shards with regular shards in an adjusted order.
Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order.
ex.
[Shard(0), _StridedShard(0, split_factor=2), Shard(0)] ->
[(0, Shard(0)), (2, Shard(0)), (1, Shard(0))]
"""
if not len(placements) == len(mesh_shape):
raise RuntimeError(
"Expected one placement per mesh dim, "
f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims."
)
ordered = []
deferred_strided_placements = defaultdict(list)
strided_part_ended_for_dim = set()
for mesh_dim, p in enumerate(placements):
if isinstance(p, _StridedShard):
# validate the stride is the correct multiple of the meshdim and the earlier shard
deferred_strided_placements[p.dim].append((mesh_dim, p))
else:
ordered.append((mesh_dim, p))
if isinstance(p, Shard):
if p.dim in strided_part_ended_for_dim:
raise NotImplementedError(
f"Strided sharding does not allow Shard() to appear after "
f"the strided part has ended. {p} at mesh dim {mesh_dim} in "
f"{placements} violates this assumption."
)
if p.dim in deferred_strided_placements:
strided_part_ended_for_dim.add(p.dim)
strided_placements = deferred_strided_placements.pop(p.dim)
aggregate_size = mesh_shape[mesh_dim]
while len(strided_placements) > 0:
strided_mesh_dim, strided = strided_placements.pop()
if not strided.split_factor == aggregate_size:
raise RuntimeError(
f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})"
f" == aggregate mesh size ({aggregate_size})"
)
aggregate_size *= mesh_shape[strided_mesh_dim]
ordered.append((strided_mesh_dim, Shard(p.dim)))
return ordered
def compute_local_shape_and_global_offset(
global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""
Compute the local tensor shape and the global offsets into the original tensor
of a DTensor on its current global rank. This is useful for checkpointing purpose.
Example:
global_tensor = [[0, 1, 2, 3, 4], sharded on mesh (DP=2, TP=2) with (Shard(1), Shard(1))
[10, 11, 12, 13, 14]]
This table shows the return value of local_shape and global_offset for each rank.
(`local_tensor` is for illustration only).
Note how the first coordinate of global_offset is always 0, corresponding to tensor dim 0 being replicated.
Rank local_tensor local_shape global_offset
-------------------------------------------------------------
0 [[0, 1], (2, 2) (0, 0)
[10, 11]]
1 [[2], (2, 1) (0, 2)
[12]]
2 [[3], (2, 1) (0, 3)
[13]]
3 [[4], (2, 1) (0, 4)
[14]]
Args:
global_shape (ShapeType): The global shape of the DTensor.
mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on.
placements (Sequence[:class:`Placement`]]): The placements of the DTensor.
Return:
local_shape: the shape of the DTensor's _local_tensor on the current rank.
global_offset: a tuple of offsets for each dimension of the global tensor shape,
identifying how this shard fits into the global tensor in each dimension.
"""
return _compute_local_shape_and_global_offset(
global_shape, mesh.shape, mesh.get_coordinate(), placements
)
# accept 'plain data types' to enable simpler unit testing without creating device mesh
def _compute_local_shape_and_global_offset(
global_shape: ShapeType,
mesh_shape: ShapeType,
my_coordinate: Optional[list[int]],
placements: Sequence[Placement],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""
Suppose you have a full tensor with size global_shape, and you have sharded
it according to placements for mesh_shape. This function returns, for a
specific coordinate my_coordinate in the device mesh:
- The size of your local shard WITHOUT padding (i.e., if you have
an uneven split, your size might be smaller than the other entries
in your dim), and
- Where the data for your shard begins, in the full tensor.
This function is fairly simple if your tensor is evenly sharded; the complication
is around uneven splits. There is also some complication for handling StridedShard,
which changes the order you should apply sharding.
"""
if my_coordinate is None:
# if rank not in the mesh, return empty offset
return ((0,), ())
# StridedShard implies a non-standard order to apply shards; get the
# correct order to start applying splits
ordered_placements = _explicit_order_placements(mesh_shape, placements)
local_shape = list(global_shape)
# We'll compute the data for where the shard begins on a per-dim basis.
# However, a single dim can be sharded multiple times, so we will end up
# doing a Sum(size*stride) like computation to determine the location of our
# shard for each of the shardings on that dim.
global_offset = [0] * len(global_shape)
for mesh_dim, placement in ordered_placements:
mesh_dim_size = mesh_shape[mesh_dim]
if isinstance(placement, Shard):
shard_dim = placement.dim
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_and_offset(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[mesh_dim],
)
local_shape[shard_dim] = shard_size
shard_global_offset = global_offset[shard_dim] + not_none(shard_offset)
zero_global_offset = global_shape[shard_dim]
if isinstance(shard_global_offset, torch.SymInt) and not isinstance(
zero_global_offset, torch.SymInt
):
zero_global_offset = torch.SymInt(zero_global_offset)
global_offset[shard_dim] = torch.sym_ite(
shard_size == 0,
# Special case to fill in a standardized non-garbage value for
# the global_offset of zero-sized shards. This value is out
# of bounds of the tensor, so it won't conflict with any real
# offsets. DCP may rely on this value to de-duplicate shards.
# Note that you can end up with zero-size shards that are
# still otherwise in bounds for the tensor (TODO: give an
# example).
zero_global_offset,
# As we successively shard the same dimension, we keep
# advancing our pointer beyond our original offset until we
# get to the final chunk start.
shard_global_offset,
)
# NOTE: the offset compute relies on the local shard index and it has no
# problem when strided sharding is not present. To correctly compute, we assume
# that the ``_StridedShard.split_factor`` field encodes how many partitions
# each local tensor will be further split into when sharding on higher mesh
# dimensions. However, this number is only correct if the DTensor is not
# sharded after the strided sharding completes. For example,
# [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
# where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
# device mesh dim-2, and last on mesh dim-1. We define the
# "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
# part because strided sharding happens on mesh dim-1 and it was caused by
# the fact that sharding on dim-2 occurred ahead. In this case, there's no
# further sharding after this strided sharding part and ``split_factor``
# correctly encodes the number. Another example is
# [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
# dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
# dim-2. This violates our assumption that no further sharding shall occur
# after the strided sharding part and ``split_factor`` won't correctly
# encode the number of further split. So far, the only case where _StridedShard
# placement would appear is FSDP2 + TP on 2D mesh and the above case could only
# happen on mesh of 3 or more dimensions.
# TODO: change this function to correctly address this.
# TODO: this logic can be applied to contiguous sharding as well
return tuple(local_shape), tuple(global_offset)
compute_global_tensor_info = torch._C._DTensor_compute_global_tensor_info
def compute_local_tensor_info(
global_tensor: torch.Tensor,
mesh: DeviceMesh,
placements: Sequence[Placement],
) -> tuple[list[int], list[int]]:
"""
Compute the local size and stride of a DTensor from the given global tensor info.
For example, if we have a global tensor with size (4, 8, 4) and stride (32, 1, 8).
If the DTensor placements are [Shard(2)] and world_size is 2;
then the local size is (4, 8, 2) and stride is (16, 1, 8).
Args:
tensor (:class:`torch.Tensor`):
Global tensor which DTensor will distribute
mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
placements (Sequence[:class:`Placement`]):
The attribute of the DTensor that describes its layout
on the mesh topology.
Returns:
local_shape: A List of int which specifies the size of the local tensor.
local_stride: A List of int which specifies the stride of the local tensor.
"""
local_shape = list(global_tensor.size())
local_stride = list(global_tensor.stride())
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if placement.is_shard():
shard_placement = cast(Shard, placement)
if shard_placement.dim < 0:
raise AssertionError(
"Shard placements should have negative dims normalized in "
f"the user-facing APIs: {shard_placement}"
)
shard_dim = shard_placement.dim
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)} "
f"for placement number {idx}."
)
global_dim_size = local_shape[shard_dim]
assert global_dim_size % mesh_dim_size == 0, (
f"Global dim {global_dim_size} not divisible by mesh size {mesh_dim_size}"
)
local_shape[shard_dim] = global_dim_size // mesh_dim_size
# shrink strides that were scaled up globally
for i in range(len(local_stride)):
if (
i != shard_dim
and local_stride[i] >= local_stride[shard_dim] * mesh_dim_size
):
local_stride[i] = local_stride[i] // mesh_dim_size
elif not isinstance(placement, (Replicate, Partial)):
raise RuntimeError(f"placement type {type(placement)} not supported!")
return local_shape, local_stride
def compute_global_tensor_shape(
shape: torch.Size, mesh: DeviceMesh, placements: Sequence[Placement]
) -> torch.Size:
"""
Compute the global size of a DTensor from the given local tensor shape,
the mesh and placements. Different from `compute_global_tensor_info`,
which assumes sharding is even, this util allgathers local shards' shapes
from all ranks and thus can support uneven sharding.
NOTE: Currently this function only supports 1D mesh.
Args:
shape (:class:`torch.Size`):
Shape of the local tensor
mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
placements (Sequence[:class:`Placement`]]):
The attribute of the DTensor that describes its layout
on the mesh topology.
Return:
tensor_shape: Shape of the global DTensor.
"""
if len(placements) != 1:
raise NotImplementedError(
"compute_global_tensor_shape only supports 1 placement for now."
)
if len(placements) != mesh.ndim:
raise RuntimeError(
"Expected one placement per mesh dim, "
f"but found {len(placements)} placements and {mesh.ndim} mesh dims."
)
if isinstance(placements[0], Replicate):
return shape
elif isinstance(placements[0], Shard):
local_shape = torch.tensor(list(shape), device=mesh.device_type)
gathered_shaped_tensors = [
torch.empty_like(local_shape, device=local_shape.device)
for _ in range(mesh.size())
]
funcol.all_gather_inplace(gathered_shaped_tensors, local_shape, mesh)
sharded_dim_sum = 0
shard_dim = placements[0].dim
other_dims = [d for d in range(mesh.ndim) if d != shard_dim]
for shape_tensor in gathered_shaped_tensors:
if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]):
raise RuntimeError(
"Non-sharded dimensions should have identical size across ranks."
)
shape_tensor_list = shape_tensor.tolist()
sharded_dim_sum += shape_tensor_list[shard_dim]
global_shape = list(shape)
global_shape[placements[0].dim] = sharded_dim_sum
return torch.Size(global_shape)
else:
raise NotImplementedError(
f"Placement type {type(placements[0])} not supported."
)
def try_find_mesh_from_args(
op_call: torch._ops.OpOverload, args: Sequence[object]
) -> DeviceMesh:
"""
Find the device mesh object from args.
It returns None if no mesh is found.
NOTE: we can optimize this search if needed
"""
for arg in args:
if isinstance(arg, (dtensor.DTensor, DTensorSpec)):
return arg.device_mesh
elif (
isinstance(arg, (list, tuple))
and len(arg) > 0
and isinstance(arg[0], (dtensor.DTensor, DTensorSpec))
):
return arg[0].device_mesh
raise ValueError(f"Cannot find device mesh from args for op : {op_call}.")
def compute_local_stride(
global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[int, ...]:
"""
Compute the stride of a local tensor shard, given the global stride of the DTensor.
NOTE: Currently this function is assuming the DTensor is evenly shardable.
"""
stride_divisors = [1] * len(global_stride)
for mesh_idx, p in enumerate(placements):
if p.is_shard():
i = cast(Shard, p).dim
# tensor dimension i is sharded on mesh dimension mesh_idx,
# so we need to divide all the strides larger than stride[i]
# (by the submesh size)
for j in range(len(global_stride)):
if global_stride[j] > global_stride[i]:
stride_divisors[j] *= mesh.size(mesh_idx)
return tuple(
global_stride[i] // stride_divisors[i] for i in range(len(global_stride))
)
def normalize_to_torch_size(size) -> torch.Size: # type: ignore[no-untyped-def]
"""
Unify variable types of size argument to torch.Size
Acceptable types include:
int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
or torch.Size
"""
if isinstance(size, torch.Size):
return size
if isinstance(size, int):
torch_size = [size]
elif len(size) == 1 and isinstance(size[0], Sequence):
torch_size = list(size[0])
else:
torch_size = list(size)
return torch.Size(torch_size)