mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[DeviceMesh] Implement a device mesh concatenate api for submesh and SPMD use case (#163358)
Today FSDP needs to slice out the SPMD mesh from the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate several submeshes into one bigger mesh and use it as an SPMD mesh. This PR tentatively implements that API for users. One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163358 Approved by: https://github.com/fegin ghstack dependencies: #166003
This commit is contained in:
parent
47f638eae7
commit
5a4997dcae
|
|
@ -1051,6 +1051,34 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
|||
)
|
||||
w.wait()
|
||||
|
||||
@with_comms
def test_concatenate_2d(self):
    """Concatenating the "dp" and "tp" submeshes of a (2, 4) mesh matches that mesh."""
    dim_names = ("dp", "tp")
    full_mesh = init_device_mesh(
        self.device_type, (2, 4), mesh_dim_names=dim_names
    )
    rebuilt = DeviceMesh._concatenate([full_mesh["dp"], full_mesh["tp"]])
    # The rank layout and every per-dimension process group must match.
    self.assertEqual(rebuilt.mesh, full_mesh.mesh)
    for dim_name in dim_names:
        self.assertEqual(
            rebuilt.get_group(dim_name), full_mesh.get_group(dim_name)
        )
|
||||
|
||||
@with_comms
def test_concatenate_3d(self):
    """Concatenating submeshes of a (2, 2, 2) mesh matches slicing the same dims."""
    root_mesh = init_device_mesh(
        self.device_type, (2, 2, 2), mesh_dim_names=("pp", "dp", "tp")
    )
    # Concatenating "dp" and "tp" should agree with the directly sliced 2D submesh.
    rebuilt = DeviceMesh._concatenate([root_mesh["dp"], root_mesh["tp"]])
    sliced = root_mesh["dp", "tp"]
    self.assertEqual(rebuilt.mesh, sliced.mesh)
    for dim_name in ("dp", "tp"):
        self.assertEqual(rebuilt.get_group(dim_name), sliced.get_group(dim_name))
    # Concatenating a 2D submesh with the remaining 1D submesh gives the full mesh.
    full_again = DeviceMesh._concatenate([root_mesh["pp", "dp"], root_mesh["tp"]])
    self.assertEqual(root_mesh, full_again)
|
||||
|
||||
@with_comms
|
||||
def test_reconstruct_mesh_with_flatten_dim(self):
|
||||
mesh_3d = init_device_mesh(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Optional, TYPE_CHECKING, Union
|
|||
import torch
|
||||
from torch.distributed import is_available
|
||||
from torch.distributed._mesh_layout import _MeshLayout
|
||||
from torch.distributed._pycute import is_int, suffix_product
|
||||
from torch.distributed._pycute import IntTuple, is_int, suffix_product
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
|
||||
|
|
@ -1176,6 +1176,41 @@ else:
|
|||
backend_override_tuple,
|
||||
)
|
||||
|
||||
@staticmethod
def _concatenate(device_mesh_list: list["DeviceMesh"]) -> "DeviceMesh":
    """
    Build a single DeviceMesh from the dimensions of several meshes, in order.

    The layout sizes/strides, mesh dim names and dim group names of each input
    mesh are appended in list order. The result takes its device type, rank map
    and root mesh from the first input, is constructed with
    ``_init_backend=False``, and has its ``_dim_group_names`` set to the
    concatenated group names of the inputs.

    Args:
        device_mesh_list: Meshes to concatenate. Each mesh's ``mesh_dim_names``
            and ``_dim_group_names`` are unwrapped with ``not_none``, so both
            are expected to be set.

    Returns:
        The concatenated DeviceMesh.

    Raises:
        ValueError: If ``device_mesh_list`` is empty.
        RuntimeError: If a mesh's ``_flatten_rank_map`` differs from that of the
            first mesh, or if the concatenated layout fails
            ``check_non_overlap()``.
    """
    if not device_mesh_list:
        # Fail with a clear message instead of an IndexError on the lookup below.
        raise ValueError("Cannot concatenate an empty list of DeviceMeshes")

    first_mesh = device_mesh_list[0]
    flatten_rank_map = first_mesh._flatten_rank_map

    concat_dim_names: list[str] = []
    concat_sizes: list[IntTuple] = []
    concat_strides: list[IntTuple] = []
    concat_dim_group_name: list[str] = []

    for dm in device_mesh_list:
        # Concatenating device meshes that have different root mesh tensors is
        # meaningless, because the concatenated indices must all index into the
        # same root mesh tensor. Check this before collecting anything from dm.
        if dm._flatten_rank_map != flatten_rank_map:
            raise RuntimeError(
                "Cannot concatenate DeviceMeshes derived from different device meshs"
            )
        layout = dm._layout
        for i in range(len(layout)):
            dim_layout = layout[i]
            concat_sizes.append(dim_layout.sizes)
            concat_strides.append(dim_layout.strides)
        concat_dim_names.extend(not_none(dm.mesh_dim_names))
        concat_dim_group_name.extend(not_none(dm._dim_group_names))

    concat_mesh_layout = _MeshLayout(tuple(concat_sizes), tuple(concat_strides))
    if not concat_mesh_layout.check_non_overlap():
        raise RuntimeError(
            f"Cannot concatenate overlapping meshes: {device_mesh_list}"
        )

    res_mesh = DeviceMesh(
        first_mesh.device_type,
        _layout=concat_mesh_layout,
        _rank_map=first_mesh._rank_map,
        mesh_dim_names=tuple(concat_dim_names),
        _root_mesh=first_mesh._get_root_mesh(),
        _init_backend=False,
    )
    # The backend is not initialized above; carry over the group names that
    # were collected from the input meshes instead.
    res_mesh._dim_group_names = concat_dim_group_name
    return res_mesh
|
||||
|
||||
def _normalize_backend_override(
|
||||
backend_override: dict[
|
||||
Union[int, str],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user