[device_mesh] replace dim_group_info with group_name (#150898)
As titled: there is no need to maintain a dim_group_info anymore; we can simply maintain a list of group_name instead. This simplifies the logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150898
Approved by: https://github.com/tianyu-l, https://github.com/fegin
This commit is contained in:
parent 9c3cef437c
commit 9df9d9ded0
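
For orientation, a minimal sketch of what the change amounts to (illustrative code, not taken from this PR): before, each DeviceMesh dimension carried a (tag, ranks, group_name) tuple and the process group was looked up by tag and ranks; after, each dimension records only the group name and the group is resolved by name. The two lookup helpers below are the private c10d functions touched by this diff's import changes; the two wrapper classes are hypothetical stand-ins for the real DeviceMesh bookkeeping.

# Hedged sketch of the bookkeeping change; the real DeviceMesh carries far more state.
from torch.distributed.distributed_c10d import (
    _find_pg_by_ranks_and_tag,  # old lookup: needs (tag, ranks)
    _resolve_process_group,     # new lookup: needs only the group name
)


class _OldStyleMesh:
    """Before: one (tag, ranks, group_name) tuple per mesh dimension."""

    def __init__(self, dim_group_infos: list[tuple[str, list[int], str]]) -> None:
        self._dim_group_infos = dim_group_infos

    def get_group(self, mesh_dim: int):
        tag, ranks, _name = self._dim_group_infos[mesh_dim]
        return _find_pg_by_ranks_and_tag(tag, ranks)


class _NewStyleMesh:
    """After: only the group name per mesh dimension."""

    def __init__(self, dim_group_names: list[str]) -> None:
        self._dim_group_names = dim_group_names

    def get_group(self, mesh_dim: int):
        # Tag and ranks can be recovered from the ProcessGroup itself when needed,
        # e.g. via dist.get_process_group_ranks(pg), as the updated _expand_group does.
        return _resolve_process_group(self._dim_group_names[mesh_dim])

The rest of the diff is this substitution applied throughout: assertions on _dim_group_infos become assertions on _dim_group_names, each _find_pg_by_ranks_and_tag lookup becomes a _resolve_process_group lookup, and call sites that still need the ranks (such as _expand_group) recover them with dist.get_process_group_ranks on the resolved group.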
@@ -840,7 +840,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
         # since the ref has a parent mesh, while the `from_group` one does not
         self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh)
         self.assertEqual(dp_mesh._coordinate_on_dim, ref_dp_mesh._coordinate_on_dim)
-        self.assertEqual(dp_mesh._dim_group_infos, ref_dp_mesh._dim_group_infos)
+        self.assertEqual(dp_mesh._dim_group_names, ref_dp_mesh._dim_group_names)

         # Check 1D FSDP forward/backward parity over the DP mesh
         # NOTE: We cannot use 2D DTensor-based training here because the DP
@@ -916,12 +916,6 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
         )
         self.assertEqual(mesh.mesh, ref_mesh.mesh)
         self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim)
-        for (_, ranks, _), (_, ref_ranks, _) in zip(
-            mesh._dim_group_infos, ref_mesh._dim_group_infos
-        ):
-            # Since we manually constructed new subgroups, the test and ref
-            # groups are not the same
-            self.assertEqual(ranks, ref_ranks)
         for mesh_dim_name in mesh_dim_names:
             child_mesh = mesh[mesh_dim_name]
             ref_child_mesh = ref_mesh[mesh_dim_name]
@@ -3,6 +3,7 @@
 import os

 import torch
+import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
@@ -197,7 +198,7 @@ class DeviceMeshTest(DTensorTestBase):
         local_tensor = torch.randn(2, 8)
         global_tensor = funcol.all_gather_tensor(
             local_tensor, gather_dim=0, group=(mesh, 0)
-        )
+        ).wait()
         self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))

     @with_comms
@@ -208,7 +209,7 @@ class DeviceMeshTest(DTensorTestBase):
         mesh_pg = ref_global_mesh.get_group()
         global_mesh = DeviceMesh.from_group(mesh_pg, self.device_type)
         self.assertEqual(ref_global_mesh, global_mesh)
-        self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
+        self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names)
         self.assertEqual(
             ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
         )
@@ -217,7 +218,7 @@ class DeviceMeshTest(DTensorTestBase):
             mesh_pg, self.device_type, mesh=torch.arange(self.world_size)
         )
         self.assertEqual(ref_global_mesh, global_mesh)
-        self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
+        self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names)
         self.assertEqual(
             ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
         )
@@ -396,24 +397,20 @@ class DeviceMeshTestNDim(DTensorTestBase):
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )

-        ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
-        ):
-            self.assertEqual(ref_ranks, ranks)
+        ref_mesh_dp_dim_group_names = ref_mesh._dim_group_names[:2]
+        self.assertEqual(ref_mesh_dp_dim_group_names, dp_mesh._dim_group_names[:2])
         # Cannot check directly for mesh equality since parent meshes are not
         # the same since the ref's parent mesh is 3D
         self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh)
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            dp_mesh["dp_replicate"]._dim_group_infos,
-            ref_mesh["dp_replicate"]._dim_group_infos,
-        ):
-            self.assertEqual(ref_ranks, ranks)
+        self.assertEqual(
+            dp_mesh["dp_replicate"]._dim_group_names,
+            ref_mesh["dp_replicate"]._dim_group_names,
+        )
         self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh)
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos
-        ):
-            self.assertEqual(ref_ranks, ranks)
+        self.assertEqual(
+            dp_mesh["dp_shard"]._dim_group_names,
+            ref_mesh["dp_shard"]._dim_group_names,
+        )

     @with_comms()
     def test_from_group_with_mesh_shape_2d(self):
@@ -456,12 +453,13 @@ class DeviceMeshTestNDim(DTensorTestBase):
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )

-        ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
+        # self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
+        for mesh_dim_group, ref_mesh_dim_group in zip(
+            dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
         ):
-            self.assertEqual(ref_ranks, ranks)
-
+            mesh_dim_group_ranks = dist.get_process_group_ranks(mesh_dim_group)
+            ref_mesh_dim_group_ranks = dist.get_process_group_ranks(ref_mesh_dim_group)
+            self.assertEqual(mesh_dim_group_ranks, ref_mesh_dim_group_ranks)
         # check both the 2d mesh and the submeshes are exactly the same.
         self.assertEqual(dp_mesh, ref_mesh)
         self.assertEqual(dp_mesh["dp_replicate"], ref_mesh["dp_replicate"])
@@ -731,8 +731,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]:
             "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
         )
         # TODO: it should run collective in the whole mesh instead of dim 0
-        tag, rankset, _ = group._dim_group_infos[0]
+        pg = group.get_group()
+        rankset = dist.get_process_group_ranks(pg)
         group_size = len(rankset)
+        tag = tag or c10d._get_group_tag(pg)
     elif isinstance(group, tuple):
         if (
             len(group) == 2
@@ -741,8 +743,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]:
         ):
             dmesh = group[0]
             dim = group[1]
-            tag, rankset, _ = dmesh._dim_group_infos[dim]
+            pg = dmesh.get_group(dim)
+            rankset = dist.get_process_group_ranks(pg)
             group_size = len(rankset)
+            tag = tag or c10d._get_group_tag(pg)
         else:
             raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
     else:
@@ -767,7 +771,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
         assert group.ndim == 1, (
             "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
         )
-        return group._dim_group_infos[0][2]
+        return group._dim_group_names[0]
     elif isinstance(group, tuple):
         if (
             len(group) == 2
@@ -776,7 +780,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
         ):
             dmesh = group[0]
             dim = group[1]
-            return dmesh._dim_group_infos[dim][2]
+            return dmesh._dim_group_names[dim]
         else:
             raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
     elif isinstance(group, list):
@@ -38,9 +38,8 @@ if not is_available():
 else:
     from torch._C._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
-        _find_pg_by_ranks_and_tag,
         _get_default_group,
-        _get_group_tag,
+        _resolve_process_group,
         get_backend,
         get_process_group_ranks,
         get_rank,
@@ -103,7 +102,7 @@ else:
             mesh_tensor = device_mesh.mesh
             # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims.
             slice_dim_idx = []
-            slice_dim_group_info = []
+            slice_dim_group_name = []
             # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
             # flattened mesh tensor.
             num_dims_flatten = 0
@@ -121,15 +120,15 @@ else:
                     # then the final slice_dim_idx should be [0, 1, 2].
                     slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
                     num_dims_flatten += len(mesh_dim_indices) - 1
-                    slice_dim_group_info.append(
+                    slice_dim_group_name.append(
                         self.root_to_flatten_mapping[device_mesh][
                             mesh_dim_name
-                        ]._dim_group_infos[0]
+                        ]._dim_group_names[0]
                     )
                 else:
                     slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
-                    slice_dim_group_info.append(
-                        device_mesh._dim_group_infos[mesh_dim_indices[0]]
+                    slice_dim_group_name.append(
+                        device_mesh._dim_group_names[mesh_dim_indices[0]]
                     )

             # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
@@ -155,7 +154,7 @@ else:
                 if cur_rank in mesh_nd:
                     res_submesh = submesh

-            res_submesh._dim_group_infos = slice_dim_group_info  # type: ignore[possibly-undefined]
+            res_submesh._dim_group_names = slice_dim_group_name  # type: ignore[possibly-undefined]
             self.child_to_root_mapping[res_submesh] = device_mesh

             return res_submesh
@@ -360,8 +359,8 @@ else:
                     mesh_dim_names=(mesh_dim_name,),
                     _init_backend=False,
                 )
-                submesh._dim_group_infos = (
-                    [device_mesh._dim_group_infos[mesh_dim]]
+                submesh._dim_group_names = (
+                    [device_mesh._dim_group_names[mesh_dim]]
                     if cur_rank in mesh_1d
                     else []
                 )
@@ -496,13 +495,10 @@ else:
             return _get_default_group()

     def _init_process_groups(self):
-        # tag/ranks/group_name associated with each mesh dimension, each
+        # group_name associated with each mesh dimension, each
         # mesh dimension should have one sub-group per rank
-        #
-        # TODO(yifu): remove tag and ranks once we fully migrate to native
-        # functional collectives. See details in:
-        # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
-        dim_group_infos: list[tuple[str, list[int], str]] = []
+        dim_group_names: list[str] = []
         default_group = _get_default_group()

         if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
@@ -519,13 +515,7 @@ else:
                 and get_backend(default_group) == "gloo"
                 else default_group
             )
-            dim_group_infos.append(
-                (
-                    _get_group_tag(dim_group),
-                    ranks,
-                    dim_group.group_name,
-                )
-            )
+            dim_group_names.append(dim_group.group_name)
         else:
             # create sub pgs base on the mesh argument specified
             for dim in range(self.mesh.ndim):
@@ -579,10 +569,9 @@ else:
                         has_split_group = True

                 # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
-                # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
-                # the current rank is in the subgroup.
+                # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
                 # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
-                # along with appending information to the `dim_group_infos` list whenever necessary.
+                # along with appending information to the `dim_group_names` list whenever necessary.
                 for dim_mesh in pg_ranks_by_dim:
                     subgroup_ranks = dim_mesh.tolist()

@@ -599,19 +588,13 @@ else:

                     # only add to dim_groups if the current rank in the subgroup
                     if self.get_rank() in subgroup_ranks:
-                        if len(dim_group_infos) > dim:
+                        if len(dim_group_names) > dim:
                             raise RuntimeError(
                                 f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                 f"in {subgroup_ranks}!"
                             )
-                        dim_group_infos.append(
-                            (
-                                _get_group_tag(not_none(dim_group)),
-                                subgroup_ranks,
-                                dim_group.group_name,
-                            )
-                        )
-        self._dim_group_infos = dim_group_infos
+                        dim_group_names.append(dim_group.group_name)
+        self._dim_group_names = dim_group_names

     def __enter__(self) -> "DeviceMesh":
         # set this mesh as the current mesh in mesh env
@@ -745,7 +728,7 @@ else:
         Returns:
             A :class:`ProcessGroup` object.
         """
-        if not hasattr(self, "_dim_group_infos"):
+        if not hasattr(self, "_dim_group_names"):
             raise RuntimeError("DeviceMesh process groups not initialized!")

         if self.mesh.ndim > 1 and mesh_dim is None:
@@ -758,28 +741,25 @@ else:

         # Quick return if the current device_mesh is a 1D mesh.
         if self.mesh.ndim == 1 and mesh_dim is None:
-            return not_none(
-                _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])  # type: ignore[index]
-            )
+            return not_none(_resolve_process_group(self._dim_group_names[0]))

         root_mesh = _mesh_resources.get_root_mesh(self)
         root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get(
             root_mesh, None
         )
         if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
-            dim_group_infos = root_to_flatten_mapping[
+            dim_group_name = root_to_flatten_mapping[
                 mesh_dim  # type: ignore[index]
-            ]._dim_group_infos[0][:2]
-            return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
+            ]._dim_group_names[0]
+            return not_none(_resolve_process_group(dim_group_name))
         else:
             mesh_dim = (
                 _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
                 if isinstance(mesh_dim, str)
                 else mesh_dim
             )
-            return not_none(
-                _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2])  # type: ignore[index]
-            )
+            assert isinstance(mesh_dim, int)
+            return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))

     def get_all_groups(self) -> list[ProcessGroup]:
         """
@@ -852,9 +832,7 @@ else:
                 mesh_dim_names=mesh_dim_names,
                 _init_backend=False,
             )
-            device_mesh._dim_group_infos = [
-                (_get_group_tag(group), group_ranks, group.group_name)
-            ]
+            device_mesh._dim_group_names = [group.group_name]
             return device_mesh

         # nD scenario
@@ -880,14 +858,7 @@ else:
         device_mesh = DeviceMesh(
             device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
         )
-        device_mesh._dim_group_infos = [
-            (
-                _get_group_tag(group),
-                get_process_group_ranks(group),
-                group.group_name,
-            )
-            for group in groups
-        ]
+        device_mesh._dim_group_names = [group.group_name for group in groups]
         return device_mesh

     def size(self, mesh_dim: Optional[int] = None) -> int: