Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Fixes #159590. This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, which led to incorrect objects for stateful placements (e.g. `_MaskPartial`) as in issue #159601. This adds little to no overhead in eager ([see past benchmarks](https://github.com/pytorch/pytorch/pull/156868#issuecomment-3047831149)). It also handles cases such as #159590, where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of the sharding propagation during compile. Tests are added/modified to cover these cases, as well as list/tuple inputs with the cat op. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160798 Approved by: https://github.com/bdhirsh
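To illustrate the class of hazard the PR describes (a minimal sketch, not the exact two-cache misalignment fixed here; `compute_transform_infos` is a hypothetical helper): a memoized function keyed on specs returns whatever objects were captured on the first call, so placements that hash/compare equal while carrying different runtime state can surface stale instances.

    from functools import cache

    @cache
    def memoized_transform_infos(src_spec, dst_spec):
        # Returns the result computed on the FIRST call with hash-equal specs.
        # Stateful placements (e.g. _MaskPartial) can compare equal yet carry
        # different state, so later callers may receive stale objects.
        return compute_transform_infos(src_spec, dst_spec)  # hypothetical helper

This is why the code below takes a non-cached path while tracing.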
402 lines · 16 KiB · Python
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from functools import cache
from typing import cast, NamedTuple, Optional

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed._functional_collectives import _are_we_tracing
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


logger = logging.getLogger(__name__)


class _TransformInfo(NamedTuple):
    mesh_dim: int
    src_dst_placements: tuple[Placement, Placement]
    # logical_shape on this mesh dimension
    logical_shape: list[int]


def _gen_transform_infos_non_cached(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> list[_TransformInfo]:
    """
    Generate the transform infos from the source placements to the target placements.

    Transforming from the source to the target placement may take multiple steps,
    i.e. it may decompose Si -> Sj into Si -> R -> Sj.
    This also detects misaligned/nested shardings between the src/dst placements.
    E.g. suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)):
    here Shard(0) -> Shard(0) on mesh dimension 1 actually needs resharding, because
    the former is a nested sharding of a tensor already sharded on dimension 0, whereas
    the latter is the first sharding on tensor dimension 0.
    """
    transform_infos: list[_TransformInfo] = []

    device_mesh = src_spec.device_mesh
    my_coordinate = device_mesh.get_coordinate()
    assert my_coordinate is not None

    # logical shape records the logical tensor shape on each mesh dimension;
    # this is useful to ensure uneven sharding gets the correct output shape
    initial_logical_shape = list(src_spec.shape)
    mesh_dims_to_logical_shape = [initial_logical_shape]

    if device_mesh.ndim == 1:
        # if device_mesh is 1D, redistribute is a simple direct transformation
        transform_infos.append(
            _TransformInfo(
                mesh_dim=0,
                src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]),
                logical_shape=initial_logical_shape,
            )
        )
        return transform_infos

    # Handle multi-dim device mesh placement redistribution.
    # First, we need to build the logical shape for each mesh dim so that uneven
    # shards on each mesh dim are all-gathered correctly (with dynamic padding).
    for i, src in enumerate(src_spec.placements):
        current_logical_shape = mesh_dims_to_logical_shape[i]
        if isinstance(src, Shard):
            if i < device_mesh.ndim - 1:
                # calculate and save the logical shape for this sharding
                mesh_dim_size = device_mesh.size(mesh_dim=i)
                local_shard_size, _ = src._local_shard_size_and_offset(
                    current_logical_shape[src.dim],
                    mesh_dim_size,
                    my_coordinate[i],
                )
                new_logical_shape = list(current_logical_shape)
                new_logical_shape[src.dim] = local_shard_size
                mesh_dims_to_logical_shape.append(new_logical_shape)
        else:
            mesh_dims_to_logical_shape.append(current_logical_shape)
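
    # For example, a tensor dim of global size 5 sharded over a mesh dim of size 2
    # yields local shards of size 3 and 2 (torch.chunk-style split); recording the
    # local size as the logical shape for the next (inner) mesh dim is what lets
    # uneven shards be all-gathered back to the right global shape.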

    # Next, we need to derive the transform infos from src to dst placements;
    # here we use a greedy search with step-by-step state transformations
    current_placements = list(src_spec.placements)
    target_placements = list(dst_spec.placements)

    if src_spec.num_shards > 1:
        # If src_spec has sharding, it could potentially be misaligned with dst_spec;
        # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))).
        # In those cases, we first traverse from inner placement to outer placement
        # to detect misaligned shardings and properly replicate nested shardings first.
        for mesh_dim in reversed(range(len(current_placements))):
            current = current_placements[mesh_dim]
            target = target_placements[mesh_dim]
            # If target is not Shard, we can directly redistribute since we are
            # traversing from inner to outer placements here
            if isinstance(target, Shard):
                # If target is Shard, check for nested sharding on the tensor dim
                # BEFORE the current mesh_dim
                shard_dim = target.dim
                current_mesh_sharding, target_mesh_sharding = [], []
                for i, (s, p) in enumerate(zip(current_placements, target_placements)):
                    if i >= mesh_dim:
                        break
                    if s.is_shard(shard_dim):
                        current_mesh_sharding.append(i)
                    if p.is_shard(shard_dim):
                        target_mesh_sharding.append(i)

                if current_mesh_sharding != target_mesh_sharding:
                    # if current/target placements have misaligned sharding on the
                    # tensor dim BEFORE the current mesh_dim, we need to replicate
                    # the tensor on this mesh dim first to clear the nested sharding
                    target = Replicate()

            if current != target:
                transform_infos.append(
                    _TransformInfo(
                        mesh_dim=mesh_dim,
                        src_dst_placements=(current, target),
                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                    )
                )
                current_placements[mesh_dim] = target

    # We always traverse from outer placement to inner placement to collect the
    # remaining needed transform infos (i.e. the replication from nested sharding
    # might need to further perform resharding to Shard again)
    for mesh_dim, (current, target) in enumerate(
        zip(current_placements, target_placements)
    ):
        if current != target:
            transform_infos.append(
                _TransformInfo(
                    mesh_dim=mesh_dim,
                    src_dst_placements=(current, target),
                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
                )
            )
            current_placements[mesh_dim] = target

    return transform_infos


@cache
def _gen_transform_infos(
    src_spec: DTensorSpec,
    dst_spec: DTensorSpec,
) -> list[_TransformInfo]:
    return _gen_transform_infos_non_cached(src_spec, dst_spec)
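

# Note on the cached variant above: functools.cache memoizes on the
# (src_spec, dst_spec) argument pair, which relies on DTensorSpec being hashable;
# specs that hash/compare equal share a single cached result.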


def redistribute_local_tensor(
    local_tensor: torch.Tensor,
    current_spec: DTensorSpec,
    target_spec: DTensorSpec,
    *,
    async_op: bool = False,
    is_backward: bool = False,
) -> torch.Tensor:
    """
    Redistribute the local tensor (torch.Tensor) from the current DTensorSpec to
    the target DTensorSpec, which involves the necessary collective calls to transform
    the local shard of the DTensor from its current spec to the target spec.
    """
    if current_spec.mesh != target_spec.mesh:
        # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    new_local_tensor = local_tensor
    device_mesh = current_spec.mesh

    my_coordinate = device_mesh.get_coordinate()

    if my_coordinate is None:
        # if the rank is not part of the mesh, we skip redistribute and simply
        # return local_tensor, which should be an empty tensor
        return local_tensor
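
    # While tracing we bypass the cache: memoized transform infos can hand back
    # stale objects for stateful placements such as _MaskPartial (see the PR
    # description above), which would be incorrect to reuse during compile.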
    if _are_we_tracing():
        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
    else:
        transform_infos = _gen_transform_infos(current_spec, target_spec)

    for transform_info in transform_infos:
        i = transform_info.mesh_dim
        current, target = transform_info.src_dst_placements

        if current == target:
            # short cut, just use the original local tensor
            new_local_tensor = local_tensor
            continue

        logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)

        if target.is_replicate():
            # Case 1: target is Replicate
            if current.is_partial():
                partial_spec = cast(Partial, current)
                new_local_tensor = partial_spec._reduce_value(
                    local_tensor, device_mesh, i
                )
            elif current.is_shard():
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, device_mesh, i, transform_info.logical_shape
                )
            else:
                raise RuntimeError(
                    f"redistribute from {current} to {target} not supported yet"
                )
        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)
            if current.is_partial():
                partial_spec = cast(Partial, current)
                new_local_tensor = partial_spec._reduce_shard_value(
                    local_tensor, device_mesh, i, target_placement
                )
            elif current.is_replicate():
                # split the tensor and return the corresponding cloned local shard
                new_local_tensor = target_placement._replicate_to_shard(
                    local_tensor, device_mesh, i, my_coordinate[i]
                )
            else:
                assert current.is_shard(), (
                    f"Current placement should be shard but found {current}"
                )
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    new_local_tensor = shard_spec._to_new_shard_dim(
                        local_tensor,
                        device_mesh,
                        i,
                        transform_info.logical_shape,
                        target_placement.dim,
                    )
        elif target.is_partial():
            if current.is_replicate():
                partial_spec = cast(Partial, target)
                # Skip the replicate -> partial transformation when we are in a
                # backward pass; in this case we keep the grad as replicate. We
                # don't want to convert the replicated gradients back to partial:
                # although that logically conforms to the same layout, converting
                # the gradients back to partial is useless, since you would have
                # to reduce them later anyway, which is more expensive than
                # keeping them replicated. For this reason, we keep the
                # replicated grad here.
                new_local_tensor = (
                    partial_spec._partition_value(local_tensor, device_mesh, i)
                    if not is_backward
                    else local_tensor
                )
            elif current.is_shard():
                if not is_backward:
                    raise RuntimeError(
                        f"redistribute from {current} to {target} not supported yet"
                    )
                # for backward shard -> partial, we just need to convert the shard to replicate
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, device_mesh, i, transform_info.logical_shape
                )
            else:
                # partial -> partial is a no-op, should never be hit
                new_local_tensor = local_tensor

        local_tensor = new_local_tensor

    if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
        new_local_tensor = new_local_tensor.wait()

    return new_local_tensor
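

# Illustrative usage (a sketch, not from this file; assumes a 1-D mesh and a
# TensorMeta `meta` describing the full tensor, both built elsewhere):
#   src_spec = DTensorSpec(mesh, (Shard(0),), tensor_meta=meta)
#   dst_spec = DTensorSpec(mesh, (Replicate(),), tensor_meta=meta)
#   full_tensor = redistribute_local_tensor(local_shard, src_spec, dst_spec)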


class Redistribute(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: tuple[Placement, ...],
        async_op: bool = False,
        forward_dtype: Optional[torch.dtype] = None,
        backward_dtype: Optional[torch.dtype] = None,
    ):
        ctx.async_op = async_op
        ctx.backward_dtype = backward_dtype
        ctx.original_dtype = input._local_tensor.dtype
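
        # If forward_dtype is set, cast the local tensor up front so the collectives
        # issued by redistribute_local_tensor run in that dtype; the spec has to be
        # rebuilt because TensorMeta records the dtype.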
        if forward_dtype is not None and forward_dtype != input._local_tensor.dtype:
            local_tensor = input._local_tensor.to(dtype=forward_dtype)
            current_spec = DTensorSpec(
                mesh=device_mesh,
                placements=input._spec.placements,
                tensor_meta=TensorMeta(
                    shape=input.shape,
                    stride=input.stride(),
                    dtype=forward_dtype,
                ),
            )
        else:
            local_tensor = input._local_tensor
            current_spec = input._spec

        ctx.current_spec = current_spec

        if current_spec.placements != placements:
            target_spec = DTensorSpec(
                device_mesh, placements, tensor_meta=current_spec.tensor_meta
            )

            output = redistribute_local_tensor(
                local_tensor, current_spec, target_spec, async_op=async_op
            )
        else:
            # use the same local tensor if the placements are the same
            output = local_tensor
            target_spec = current_spec

        return dtensor.DTensor(
            output,
            target_spec,
            requires_grad=input.requires_grad,
        )

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_spec = ctx.current_spec
        async_op = ctx.async_op
        backward_dtype = ctx.backward_dtype or ctx.original_dtype

        if backward_dtype != grad_output._local_tensor.dtype:
            local_tensor = grad_output._local_tensor.to(dtype=backward_dtype)
            current_spec = DTensorSpec(
                mesh=grad_output._spec.device_mesh,
                placements=grad_output._spec.placements,
                tensor_meta=TensorMeta(
                    shape=grad_output.shape,
                    stride=grad_output.stride(),
                    dtype=backward_dtype,
                ),
            )
            previous_spec = DTensorSpec(
                mesh=previous_spec.device_mesh,
                placements=previous_spec.placements,
                tensor_meta=current_spec.tensor_meta,
            )
        else:
            local_tensor = grad_output._local_tensor
            current_spec = grad_output._spec

        output = redistribute_local_tensor(
            local_tensor,
            current_spec,
            previous_spec,
            async_op=async_op,
            is_backward=True,
        )

        if output.dtype != ctx.original_dtype:
            output = output.to(ctx.original_dtype)

        # normalize the target placement to replicate if it is partial
        normalized_placements: list[Placement] = []
        for previous_placement in previous_spec.placements:
            if previous_placement.is_partial():
                # keep the target placement as replicate instead of partial in this case
                normalized_placements.append(Replicate())
            else:
                normalized_placements.append(previous_placement)

        spec = DTensorSpec(
            previous_spec.device_mesh,
            tuple(normalized_placements),
            tensor_meta=TensorMeta(
                shape=grad_output.shape,
                stride=grad_output.stride(),
                dtype=output.dtype,
            ),
        )
        output_dtensor = dtensor.DTensor(
            output,
            spec,
            requires_grad=grad_output.requires_grad,
        )

        return (
            output_dtensor,
            None,
            None,
            None,
            None,
            None,
        )
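

# Illustrative call path (a sketch; `dt` is a hypothetical DTensor): the public
# DTensor.redistribute API ultimately dispatches to this autograd function, e.g.
#   out = Redistribute.apply(dt, device_mesh, (Replicate(),))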