[device_mesh] replace dim_group_info with group_name (#150898)

As titled: there is no need to maintain a per-dimension dim_group_info tuple anymore; we can
simply maintain a list of group_name strings instead. This simplifies the logic.
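
A minimal sketch of the bookkeeping change, using the attribute names from the device_mesh diff below (the standalone variables here are illustrative, not the actual class attributes):

    # Before: DeviceMesh kept one (group_tag, ranks, group_name) tuple per mesh dimension.
    dim_group_infos: list[tuple[str, list[int], str]] = []

    # After: DeviceMesh keeps only the registered process-group name per mesh dimension;
    # the tag and the ranks are derived from the ProcessGroup itself where still needed.
    dim_group_names: list[str] = []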

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150898
Approved by: https://github.com/tianyu-l, https://github.com/fegin
Author: Wanchao Liang
Date: 2025-05-13 06:01:22 +00:00
Committer: PyTorch MergeBot
Commit: 9df9d9ded0 (parent: 9c3cef437c)

4 changed files with 55 additions and 88 deletions

View File

@ -840,7 +840,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
# since the ref has a parent mesh, while the `from_group` one does not
self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh)
self.assertEqual(dp_mesh._coordinate_on_dim, ref_dp_mesh._coordinate_on_dim)
self.assertEqual(dp_mesh._dim_group_infos, ref_dp_mesh._dim_group_infos)
self.assertEqual(dp_mesh._dim_group_names, ref_dp_mesh._dim_group_names)
# Check 1D FSDP forward/backward parity over the DP mesh
# NOTE: We cannot use 2D DTensor-based training here because the DP
@ -916,12 +916,6 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
)
self.assertEqual(mesh.mesh, ref_mesh.mesh)
self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim)
for (_, ranks, _), (_, ref_ranks, _) in zip(
mesh._dim_group_infos, ref_mesh._dim_group_infos
):
# Since we manually constructed new subgroups, the test and ref
# groups are not the same
self.assertEqual(ranks, ref_ranks)
for mesh_dim_name in mesh_dim_names:
child_mesh = mesh[mesh_dim_name]
ref_child_mesh = ref_mesh[mesh_dim_name]

View File

@ -3,6 +3,7 @@
import os
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
@ -197,7 +198,7 @@ class DeviceMeshTest(DTensorTestBase):
local_tensor = torch.randn(2, 8)
global_tensor = funcol.all_gather_tensor(
local_tensor, gather_dim=0, group=(mesh, 0)
)
).wait()
self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))
@with_comms
@ -208,7 +209,7 @@ class DeviceMeshTest(DTensorTestBase):
mesh_pg = ref_global_mesh.get_group()
global_mesh = DeviceMesh.from_group(mesh_pg, self.device_type)
self.assertEqual(ref_global_mesh, global_mesh)
self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names)
self.assertEqual(
ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
)
@ -217,7 +218,7 @@ class DeviceMeshTest(DTensorTestBase):
mesh_pg, self.device_type, mesh=torch.arange(self.world_size)
)
self.assertEqual(ref_global_mesh, global_mesh)
self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names)
self.assertEqual(
ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
)
@ -396,24 +397,20 @@ class DeviceMeshTestNDim(DTensorTestBase):
mesh_dim_names=("dp_replicate", "dp_shard"),
)
ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
for (_, ref_ranks, _), (_, ranks, _) in zip(
ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
):
self.assertEqual(ref_ranks, ranks)
ref_mesh_dp_dim_group_names = ref_mesh._dim_group_names[:2]
self.assertEqual(ref_mesh_dp_dim_group_names, dp_mesh._dim_group_names[:2])
# Cannot check directly for mesh equality since parent meshes are not
# the same since the ref's parent mesh is 3D
self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh)
for (_, ref_ranks, _), (_, ranks, _) in zip(
dp_mesh["dp_replicate"]._dim_group_infos,
ref_mesh["dp_replicate"]._dim_group_infos,
):
self.assertEqual(ref_ranks, ranks)
self.assertEqual(
dp_mesh["dp_replicate"]._dim_group_names,
ref_mesh["dp_replicate"]._dim_group_names,
)
self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh)
for (_, ref_ranks, _), (_, ranks, _) in zip(
dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos
):
self.assertEqual(ref_ranks, ranks)
self.assertEqual(
dp_mesh["dp_shard"]._dim_group_names,
ref_mesh["dp_shard"]._dim_group_names,
)
@with_comms()
def test_from_group_with_mesh_shape_2d(self):
@ -456,12 +453,13 @@ class DeviceMeshTestNDim(DTensorTestBase):
mesh_dim_names=("dp_replicate", "dp_shard"),
)
ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
for (_, ref_ranks, _), (_, ranks, _) in zip(
ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
# self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
for mesh_dim_group, ref_mesh_dim_group in zip(
dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
):
self.assertEqual(ref_ranks, ranks)
mesh_dim_group_ranks = dist.get_process_group_ranks(mesh_dim_group)
ref_mesh_dim_group_ranks = dist.get_process_group_ranks(ref_mesh_dim_group)
self.assertEqual(mesh_dim_group_ranks, ref_mesh_dim_group_ranks)
# check both the 2d mesh and the submeshes are exactly the same.
self.assertEqual(dp_mesh, ref_mesh)
self.assertEqual(dp_mesh["dp_replicate"], ref_mesh["dp_replicate"])

View File

@ -731,8 +731,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
# TODO: it should run collective in the whole mesh instead of dim 0
tag, rankset, _ = group._dim_group_infos[0]
pg = group.get_group()
rankset = dist.get_process_group_ranks(pg)
group_size = len(rankset)
tag = tag or c10d._get_group_tag(pg)
elif isinstance(group, tuple):
if (
len(group) == 2
@ -741,8 +743,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
):
dmesh = group[0]
dim = group[1]
tag, rankset, _ = dmesh._dim_group_infos[dim]
pg = dmesh.get_group(dim)
rankset = dist.get_process_group_ranks(pg)
group_size = len(rankset)
tag = tag or c10d._get_group_tag(pg)
else:
raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
else:
@ -767,7 +771,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
assert group.ndim == 1, (
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
return group._dim_group_infos[0][2]
return group._dim_group_names[0]
elif isinstance(group, tuple):
if (
len(group) == 2
@ -776,7 +780,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
):
dmesh = group[0]
dim = group[1]
return dmesh._dim_group_infos[dim][2]
return dmesh._dim_group_names[dim]
else:
raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
elif isinstance(group, list):
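
The tag and rank set that _expand_group used to read straight out of _dim_group_infos are now derived from the ProcessGroup on demand. A sketch of that derivation; `derive_group_info` is a hypothetical helper, and `pg` is assumed to be an already-created ProcessGroup:

    import torch.distributed as dist
    import torch.distributed.distributed_c10d as c10d

    def derive_group_info(pg: dist.ProcessGroup, tag: str = "") -> tuple[str, list[int], int]:
        # Ranks come from c10d's bookkeeping for the group, the tag from the group's
        # registered tag, and the group size is simply the number of ranks.
        rankset = dist.get_process_group_ranks(pg)
        return tag or c10d._get_group_tag(pg), rankset, len(rankset)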

View File

@ -38,9 +38,8 @@ if not is_available():
else:
from torch._C._distributed_c10d import Backend as C10dBackend
from torch.distributed.distributed_c10d import (
_find_pg_by_ranks_and_tag,
_get_default_group,
_get_group_tag,
_resolve_process_group,
get_backend,
get_process_group_ranks,
get_rank,
@ -103,7 +102,7 @@ else:
mesh_tensor = device_mesh.mesh
# slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims.
slice_dim_idx = []
slice_dim_group_info = []
slice_dim_group_name = []
# keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
# flattened mesh tensor.
num_dims_flatten = 0
@ -121,15 +120,15 @@ else:
# then the final slice_dim_idx should be [0, 1, 2].
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
num_dims_flatten += len(mesh_dim_indices) - 1
slice_dim_group_info.append(
slice_dim_group_name.append(
self.root_to_flatten_mapping[device_mesh][
mesh_dim_name
]._dim_group_infos[0]
]._dim_group_names[0]
)
else:
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
slice_dim_group_info.append(
device_mesh._dim_group_infos[mesh_dim_indices[0]]
slice_dim_group_name.append(
device_mesh._dim_group_names[mesh_dim_indices[0]]
)
# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
@ -155,7 +154,7 @@ else:
if cur_rank in mesh_nd:
res_submesh = submesh
res_submesh._dim_group_infos = slice_dim_group_info # type: ignore[possibly-undefined]
res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined]
self.child_to_root_mapping[res_submesh] = device_mesh
return res_submesh
@ -360,8 +359,8 @@ else:
mesh_dim_names=(mesh_dim_name,),
_init_backend=False,
)
submesh._dim_group_infos = (
[device_mesh._dim_group_infos[mesh_dim]]
submesh._dim_group_names = (
[device_mesh._dim_group_names[mesh_dim]]
if cur_rank in mesh_1d
else []
)
@ -496,13 +495,10 @@ else:
return _get_default_group()
def _init_process_groups(self):
# tag/ranks/group_name associated with each mesh dimension, each
# group_name associated with each mesh dimension, each
# mesh dimension should have one sub-group per rank
#
# TODO(yifu): remove tag and ranks once we fully migrate to native
# functional collectives. See details in:
# https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
dim_group_infos: list[tuple[str, list[int], str]] = []
dim_group_names: list[str] = []
default_group = _get_default_group()
if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
@ -519,13 +515,7 @@ else:
and get_backend(default_group) == "gloo"
else default_group
)
dim_group_infos.append(
(
_get_group_tag(dim_group),
ranks,
dim_group.group_name,
)
)
dim_group_names.append(dim_group.group_name)
else:
# create sub pgs base on the mesh argument specified
for dim in range(self.mesh.ndim):
@ -579,10 +569,9 @@ else:
has_split_group = True
# If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
# and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
# the current rank is in the subgroup.
# and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
# Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
# along with appending information to the `dim_group_infos` list whenever necessary.
# along with appending information to the `dim_group_names` list whenever necessary.
for dim_mesh in pg_ranks_by_dim:
subgroup_ranks = dim_mesh.tolist()
@ -599,19 +588,13 @@ else:
# only add to dim_groups if the current rank in the subgroup
if self.get_rank() in subgroup_ranks:
if len(dim_group_infos) > dim:
if len(dim_group_names) > dim:
raise RuntimeError(
f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
f"in {subgroup_ranks}!"
)
dim_group_infos.append(
(
_get_group_tag(not_none(dim_group)),
subgroup_ranks,
dim_group.group_name,
)
)
self._dim_group_infos = dim_group_infos
dim_group_names.append(dim_group.group_name)
self._dim_group_names = dim_group_names
def __enter__(self) -> "DeviceMesh":
# set this mesh as the current mesh in mesh env
@ -745,7 +728,7 @@ else:
Returns:
A :class:`ProcessGroup` object.
"""
if not hasattr(self, "_dim_group_infos"):
if not hasattr(self, "_dim_group_names"):
raise RuntimeError("DeviceMesh process groups not initialized!")
if self.mesh.ndim > 1 and mesh_dim is None:
@ -758,28 +741,25 @@ else:
# Quick return if the current device_mesh is a 1D mesh.
if self.mesh.ndim == 1 and mesh_dim is None:
return not_none(
_find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) # type: ignore[index]
)
return not_none(_resolve_process_group(self._dim_group_names[0]))
root_mesh = _mesh_resources.get_root_mesh(self)
root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get(
root_mesh, None
)
if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
dim_group_infos = root_to_flatten_mapping[
dim_group_name = root_to_flatten_mapping[
mesh_dim # type: ignore[index]
]._dim_group_infos[0][:2]
return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
]._dim_group_names[0]
return not_none(_resolve_process_group(dim_group_name))
else:
mesh_dim = (
_mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
if isinstance(mesh_dim, str)
else mesh_dim
)
return not_none(
_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index]
)
assert isinstance(mesh_dim, int)
return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))
def get_all_groups(self) -> list[ProcessGroup]:
"""
@ -852,9 +832,7 @@ else:
mesh_dim_names=mesh_dim_names,
_init_backend=False,
)
device_mesh._dim_group_infos = [
(_get_group_tag(group), group_ranks, group.group_name)
]
device_mesh._dim_group_names = [group.group_name]
return device_mesh
# nD scenario
@ -880,14 +858,7 @@ else:
device_mesh = DeviceMesh(
device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
)
device_mesh._dim_group_infos = [
(
_get_group_tag(group),
get_process_group_ranks(group),
group.group_name,
)
for group in groups
]
device_mesh._dim_group_names = [group.group_name for group in groups]
return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int:
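
End-to-end, from_group now records just the group names. A single-process sketch under stated assumptions (gloo backend available, no process group initialized yet; the address and port values are placeholders):

    import os
    import torch.distributed as dist
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.distributed_c10d import _get_default_group

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Build a 1D mesh from an existing ProcessGroup; only the group's registered
    # name is stored for that mesh dimension.
    world_pg = _get_default_group()
    mesh = DeviceMesh.from_group(world_pg, "cpu")
    assert mesh._dim_group_names == [world_pg.group_name]

    dist.destroy_process_group()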