[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)

Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But what users essentially want is to concatenate some submeshes into one bigger mesh and use it as an SPMD mesh. This PR tentatively implements such an API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.
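
As a rough usage sketch (assuming an 8-rank job; the device type and mesh dim names are illustrative, and `DeviceMesh._concatenate` is the private API added in this PR):

```python
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# A 3D root mesh: pipeline / data / tensor parallel dimensions (assumes 8 ranks).
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

# Build the SPMD mesh by concatenating submeshes of the same root mesh,
# instead of re-slicing ("dp", "tp") out of the root mesh by hand.
spmd_mesh = DeviceMesh._concatenate([mesh_3d["dp"], mesh_3d["tp"]])

# Equivalent to slicing the ("dp", "tp") submesh directly from the root mesh.
assert spmd_mesh.mesh.equal(mesh_3d["dp", "tp"].mesh)
```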

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163358
Approved by: https://github.com/fegin
ghstack dependencies: #166003
fduwjj 2025-10-23 14:07:19 -07:00 committed by PyTorch MergeBot
parent 47f638eae7
commit 5a4997dcae
2 changed files with 64 additions and 1 deletion


@@ -1051,6 +1051,34 @@ class TestDeviceMeshGetItem(DTensorTestBase):
        )
        w.wait()

    @with_comms
    def test_concatenate_2d(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("dp", "tp")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        concatenated_mesh = DeviceMesh._concatenate([mesh_2d["dp"], mesh_2d["tp"]])
        self.assertEqual(concatenated_mesh.mesh, mesh_2d.mesh)
        self.assertEqual(concatenated_mesh.get_group("dp"), mesh_2d.get_group("dp"))
        self.assertEqual(concatenated_mesh.get_group("tp"), mesh_2d.get_group("tp"))

    @with_comms
    def test_concatenate_3d(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("pp", "dp", "tp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        concatenated_mesh = DeviceMesh._concatenate([mesh_3d["dp"], mesh_3d["tp"]])
        dp_tp_mesh = mesh_3d["dp", "tp"]
        self.assertEqual(concatenated_mesh.mesh, dp_tp_mesh.mesh)
        self.assertEqual(concatenated_mesh.get_group("dp"), dp_tp_mesh.get_group("dp"))
        self.assertEqual(concatenated_mesh.get_group("tp"), dp_tp_mesh.get_group("tp"))
        self.assertEqual(
            mesh_3d, DeviceMesh._concatenate([mesh_3d["pp", "dp"], mesh_3d["tp"]])
        )

    @with_comms
    def test_reconstruct_mesh_with_flatten_dim(self):
        mesh_3d = init_device_mesh(


@@ -11,7 +11,7 @@ from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.distributed._mesh_layout import _MeshLayout
-from torch.distributed._pycute import is_int, suffix_product
+from torch.distributed._pycute import IntTuple, is_int, suffix_product
from torch.utils._typing_utils import not_none
@@ -1176,6 +1176,41 @@ else:
            backend_override_tuple,
        )

    @staticmethod
    def _concatenate(device_mesh_list: list["DeviceMesh"]) -> "DeviceMesh":
        concat_dim_names: list[str] = []
        concat_sizes: list[IntTuple] = []
        concat_strides: list[IntTuple] = []
        concat_dim_group_name: list[str] = []
        flatten_rank_map = device_mesh_list[0]._flatten_rank_map
        for dm in device_mesh_list:
            for i in range(len(dm._layout)):
                concat_sizes.append(dm._layout[i].sizes)
                concat_strides.append(dm._layout[i].strides)
            concat_dim_names.extend(not_none(dm.mesh_dim_names))
            concat_dim_group_name.extend(not_none(dm._dim_group_names))
            # Concatenating device meshes that have different root mesh tensors is
            # meaningless, because the concatenated indices must be indexed by the
            # same root mesh tensor.
            if dm._flatten_rank_map != flatten_rank_map:
                raise RuntimeError(
                    "Cannot concatenate DeviceMeshes derived from different device meshes"
                )
        concat_mesh_layout = _MeshLayout(tuple(concat_sizes), tuple(concat_strides))
        if not concat_mesh_layout.check_non_overlap():
            raise RuntimeError(
                f"Cannot concatenate overlapping meshes: {device_mesh_list}"
            )
        res_mesh = DeviceMesh(
            device_mesh_list[0].device_type,
            _layout=concat_mesh_layout,
            _rank_map=device_mesh_list[0]._rank_map,
            mesh_dim_names=tuple(concat_dim_names),
            _root_mesh=device_mesh_list[0]._get_root_mesh(),
            _init_backend=False,
        )
        res_mesh._dim_group_names = concat_dim_group_name
        return res_mesh

    def _normalize_backend_override(
        backend_override: dict[
            Union[int, str],