[DeviceMesh] Isolate pg creation logic in Device Mesh into a separate func _init_one_process_group (#166614)

To make the upcoming pg cache change easier and to modularize the code, we isolate the process group creation logic into a separate function named `_init_one_process_group`.
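The net effect is that `_init_process_groups` becomes a thin per-dimension loop that delegates to the new helper. A minimal sketch of that shape (simplified signatures for illustration only; see the diff below for the actual implementation):

    # Sketch only: the real signatures take a _MeshLayout, a rank_map tensor,
    # and BackendConfig overrides.
    def _init_one_process_group(sub_layout, rank_map, dim_name, backend_override):
        # Create (or reuse) the one subgroup this rank belongs to for a single
        # mesh dimension and return its group name (or None).
        ...

    def _init_process_groups(layout, rank_map, mesh_dim_names, backend_override):
        dim_group_names = []
        for dim in range(len(layout)):
            dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
            dim_group_names.append(
                _init_one_process_group(layout[dim], rank_map, dim_name, backend_override[dim])
            )
        return dim_group_names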

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166614
Approved by: https://github.com/lw
fduwjj 2025-10-30 08:53:45 -07:00 committed by PyTorch MergeBot
parent 694d205143
commit ba71e9ca9a


@@ -350,22 +350,33 @@ else:
         return _get_default_group()
 
     @staticmethod
-    def _init_process_groups(
-        layout: _MeshLayout,
+    def _init_one_process_group(
+        sub_layout: _MeshLayout,
         rank_map: torch.Tensor,
-        mesh_dim_names: Optional[tuple[str, ...]],
-        backend_override: tuple[BackendConfig, ...],
-    ) -> list[str]:
-        # group_name associated with each mesh dimension, each
-        # mesh dimension should have one sub-group per rank
-        #
-        dim_group_names: list[str] = []
+        dim_name: str,
+        backend_override: BackendConfig,
+    ) -> Optional[str]:
+        # Generate a 2D global mesh tensor for the current dim for PG creation.
+        pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
+        backend, pg_options = backend_override
+        # We need to explicitly pass in timeout when specified in option, otherwise
+        # the default timeout will be used to override the timeout set in option.
+        # TODO: remove this once we have fixed inside c10d level.
+        timeout = pg_options._timeout if pg_options else None
+
+        # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
+        # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
+        # If the mesh doesn't have a mesh_dim_names, then the group description of the
+        # subgroup would be `mesh_dim_0` and `mesh_dim_1`.
+        group_desc = f"mesh_{dim_name}"
+
+        dim_group = None
         default_group = _get_default_group()
 
-        if (
-            len(layout) == 1
-            and layout.numel() == get_world_size()
-            and backend_override[0] == (None, None)
+        # Early return if there is only one sub_layout in the mesh layout.
+        if sub_layout.numel() == get_world_size() and backend_override == (
+            None,
+            None,
         ):
             # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
             # Otherwise, create new pg.
@@ -380,90 +391,80 @@ else:
                 and get_backend(default_group) == "gloo"
                 else default_group
             )
-            dim_group_names.append(dim_group.group_name)
-        else:
-            # create sub pgs base on the mesh argument specified
-            for dim in range(len(layout)):
-                # swap the current dim to the last dim
-                # then reshape to flatten out other dims
-                pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
-                backend, pg_options = backend_override[dim]
-                # We need to explicitly pass in timeout when specified in option, otherwise
-                # the default timeout will be used to override the timeout set in option.
-                # TODO: remove this once we have fixed inside c10d level.
-                timeout = pg_options._timeout if pg_options else None
+            return dim_group.group_name  # type: ignore[union-attr]
 
-                # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
-                # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
-                # If the mesh doesn't not have a mesh_dim_names, then the group description of the
-                # subgroup would be `mesh_dim_0` and `mesh_dim_1`.
-                group_desc = (
-                    f"mesh_{mesh_dim_names[dim]}"
-                    if mesh_dim_names
-                    else f"mesh_dim_{dim}"
-                )
+        # If bound_device_id exists, it means the nccl communicator has been eagerly initialized
+        # so that we can use `split_group` to create subgroups through `ncclCommSplit`.
+        # In this case, we only need to make one API call (`split_group``) for the subgroup creation
+        # for each mesh dimension. In a 2 * 4 mesh, we only need to make two API calls per ranks to create
+        # all the subgroups.
+        # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
+        # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
+        # mesh, we need to make two API calls per ranks to create all the subgroups.
+        if (
+            getattr(default_group, "bound_device_id", None) is not None
+            and torch.cuda.is_available()
+            and (
+                backend is None
+                or default_group._get_backend(torch.device("cuda")).name()
+                == backend
+            )
+        ):
+            dim_group = split_group(
+                parent_pg=default_group,
+                timeout=timeout,
+                pg_options=pg_options,
+                split_ranks=pg_ranks_by_dim.tolist(),
+                group_desc=group_desc,
+            )
+            return dim_group.group_name  # type: ignore[union-attr]
+
+        # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
+        # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
+        # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
+        # along with appending information to the `dim_group_names` list whenever necessary.
+        pg_name = None
+        for dim_mesh in pg_ranks_by_dim:
+            subgroup_ranks = dim_mesh.tolist()
+            dim_group = new_group(
+                ranks=subgroup_ranks,
+                timeout=timeout,
+                backend=backend,
+                pg_options=pg_options,
+                group_desc=group_desc,
+            )
+            # only add to dim_groups if the current rank in the subgroup
+            if get_rank() in subgroup_ranks:
+                if pg_name is not None:
+                    raise RuntimeError(
+                        f"Each device mesh dimension should get only one process group, but got {get_rank()} "
+                        f"in {subgroup_ranks}!"
+                    )
+                pg_name = dim_group.group_name
+        return pg_name
+
+    @staticmethod
+    def _init_process_groups(
+        layout: _MeshLayout,
+        rank_map: torch.Tensor,
+        mesh_dim_names: Optional[tuple[str, ...]],
+        backend_override: tuple[BackendConfig, ...],
+    ) -> list[str]:
+        # group_name associated with each mesh dimension, each
+        # mesh dimension should have one sub-group per rank
+        dim_group_names: list[str] = []
+        # create sub pgs base on the mesh argument specified
+        for dim in range(len(layout)):
+            dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
+            dim_group_names.append(
+                DeviceMesh._init_one_process_group(  # type: ignore[arg-type]
+                    layout[dim], rank_map, dim_name, backend_override[dim]
+                )
-                # If bound_device_id exists, it means the nccl communicator has been eagerly initialized
-                # so that we can use `split_group` to create subgroups through `ncclCommSplit`.
-                # In this case, we only need to make one API call (`split_group``) for the subgroup creation
-                # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create
-                # all the subgroups.
-                # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
-                # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
-                # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups.
-                dim_group = None
-                has_split_group = False
-                if (
-                    (
-                        bound_device_id := getattr(
-                            default_group, "bound_device_id", None
-                        )
-                    )
-                    is not None
-                    and torch.cuda.is_available()
-                    and (
-                        backend is None
-                        or default_group._get_backend(torch.device("cuda")).name()
-                        == backend
-                    )
-                ):
-                    dim_group = split_group(
-                        parent_pg=default_group,
-                        timeout=timeout,
-                        pg_options=pg_options,
-                        split_ranks=pg_ranks_by_dim.tolist(),
-                        group_desc=group_desc,
-                    )
-                    has_split_group = True
-                # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
-                # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
-                # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
-                # along with appending information to the `dim_group_names` list whenever necessary.
-                for dim_mesh in pg_ranks_by_dim:
-                    subgroup_ranks = dim_mesh.tolist()
-                    # We temporarily revert the reuse subgroup, since it breaks two internal tests.
-                    # Temporarily reverting to resolve test timeout while root-causing.
-                    # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
-                    # pyrefly: ignore [unbound-name]
-                    if bound_device_id is None or not has_split_group:
-                        dim_group = new_group(
-                            ranks=subgroup_ranks,
-                            timeout=timeout,
-                            backend=backend,
-                            pg_options=pg_options,
-                            group_desc=group_desc,
-                        )
-                    # only add to dim_groups if the current rank in the subgroup
-                    if get_rank() in subgroup_ranks:
-                        if len(dim_group_names) > dim:
-                            raise RuntimeError(
-                                f"Each device mesh dimension should get only one process group, but got {get_rank()} "
-                                f"in {subgroup_ranks}!"
-                            )
-                        dim_group_names.append(dim_group.group_name)  # type: ignore[union-attr]
+            )
         if any(n is None for n in dim_group_names):
             assert all(n is None for n in dim_group_names)
             return []
         return dim_group_names
 
     def _get_root_mesh(self) -> "DeviceMesh":
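For reference, the API-call counts mentioned in the comments above can be sanity-checked with a small, hypothetical helper (illustrative only, not part of the change): for a world size of 8 laid out as a 2 * 4 mesh, the eager `split_group` path issues one call per mesh dimension, while the `new_group` path issues one call per subgroup.

    # Hypothetical illustration of the per-rank process-group creation call counts
    # described in the diff comments above (not part of the PyTorch change).
    def creation_calls(mesh_shape, eager_nccl):
        world_size = 1
        for dim_size in mesh_shape:
            world_size *= dim_size
        if eager_nccl:
            # `split_group` path: one ncclCommSplit-backed call per mesh dimension.
            return len(mesh_shape)
        # `new_group` path: one call per subgroup; a dimension of size d
        # has world_size // d subgroups.
        return sum(world_size // dim_size for dim_size in mesh_shape)

    assert creation_calls((2, 4), eager_nccl=True) == 2
    assert creation_calls((2, 4), eager_nccl=False) == 6  # 4 + 2 subgroups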