[DeviceMesh] Isolate pg creation logic in Device Mesh into a separate func _init_one_process_group (#166614)
To make the pg cache change easier and to modularize the code, we isolate the process group creation logic into a separate function named `_init_one_process_group`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166614
Approved by: https://github.com/lw
This commit is contained in:
parent 694d205143
commit ba71e9ca9a
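For illustration, here is a minimal, self-contained sketch of the structure this refactor introduces: a per-dimension helper that derives the subgroup rank lists for one mesh dimension, and a driver that simply loops over the dimensions and delegates to it. The names `ranks_for_one_dim` and `init_process_groups_sketch` are hypothetical; they only mirror the shape of `_init_one_process_group` / `_init_process_groups` from the diff below and do not create real process groups.

import torch


def ranks_for_one_dim(mesh: torch.Tensor, dim: int) -> list[list[int]]:
    """Enumerate the subgroup rank lists along one dimension of a global mesh tensor."""
    # Move `dim` to the end and flatten the remaining dims: each row is one subgroup.
    return [row.tolist() for row in mesh.movedim(dim, -1).reshape(-1, mesh.size(dim))]


def init_process_groups_sketch(mesh: torch.Tensor) -> list[list[list[int]]]:
    """Driver that delegates the per-dimension work to the helper, one call per dim."""
    groups_per_dim = []
    for dim in range(mesh.ndim):
        # A real implementation would create one process group per rank list here
        # (e.g. via new_group / split_group) and keep the one containing the current rank.
        groups_per_dim.append(ranks_for_one_dim(mesh, dim))
    return groups_per_dim


if __name__ == "__main__":
    mesh = torch.arange(8).reshape(2, 4)  # a 2 x 4 device mesh
    for dim, groups in enumerate(init_process_groups_sketch(mesh)):
        print(f"dim {dim}: {groups}")
    # dim 0: [[0, 4], [1, 5], [2, 6], [3, 7]]
    # dim 1: [[0, 1, 2, 3], [4, 5, 6, 7]]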
@@ -350,22 +350,33 @@ else:
         return _get_default_group()

     @staticmethod
-    def _init_process_groups(
-        layout: _MeshLayout,
+    def _init_one_process_group(
+        sub_layout: _MeshLayout,
         rank_map: torch.Tensor,
-        mesh_dim_names: Optional[tuple[str, ...]],
-        backend_override: tuple[BackendConfig, ...],
-    ) -> list[str]:
-        # group_name associated with each mesh dimension, each
-        # mesh dimension should have one sub-group per rank
-        #
-        dim_group_names: list[str] = []
+        dim_name: str,
+        backend_override: BackendConfig,
+    ) -> Optional[str]:
+        # Generate a 2D global mesh tensor for the current dim for PG creation.
+        pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
+        backend, pg_options = backend_override
+        # We need to explicitly pass in timeout when specified in option, otherwise
+        # the default timeout will be used to override the timeout set in option.
+        # TODO: remove this once we have fixed inside c10d level.
+        timeout = pg_options._timeout if pg_options else None
+
+        # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
+        # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
+        # If the mesh doesn't have a mesh_dim_names, then the group description of the
+        # subgroup would be `mesh_dim_0` and `mesh_dim_1`.
+        group_desc = f"mesh_{dim_name}"
+
+        dim_group = None
         default_group = _get_default_group()

-        if (
-            len(layout) == 1
-            and layout.numel() == get_world_size()
-            and backend_override[0] == (None, None)
+        # Early return if there is only one sub_layout in the mesh layout.
+        if sub_layout.numel() == get_world_size() and backend_override == (
+            None,
+            None,
         ):
             # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
             # Otherwise, create new pg.
@@ -380,90 +391,80 @@ else:
                 and get_backend(default_group) == "gloo"
                 else default_group
             )
-            dim_group_names.append(dim_group.group_name)
-        else:
-            # create sub pgs base on the mesh argument specified
-            for dim in range(len(layout)):
-                # swap the current dim to the last dim
-                # then reshape to flatten out other dims
-                pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map)
-                backend, pg_options = backend_override[dim]
-                # We need to explicitly pass in timeout when specified in option, otherwise
-                # the default timeout will be used to override the timeout set in option.
-                # TODO: remove this once we have fixed inside c10d level.
-                timeout = pg_options._timeout if pg_options else None
+            return dim_group.group_name  # type: ignore[union-attr]

-                # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description
-                # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`.
-                # If the mesh doesn't not have a mesh_dim_names, then the group description of the
-                # subgroup would be `mesh_dim_0` and `mesh_dim_1`.
-                group_desc = (
-                    f"mesh_{mesh_dim_names[dim]}"
-                    if mesh_dim_names
-                    else f"mesh_dim_{dim}"
-                )
+        # If bound_device_id exists, it means the nccl communicator has been eagerly initialized
+        # so that we can use `split_group` to create subgroups through `ncclCommSplit`.
+        # In this case, we only need to make one API call (`split_group``) for the subgroup creation
+        # for each mesh dimension. In a 2 * 4 mesh, we only need to make two API calls per ranks to create
+        # all the subgroups.
+        # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
+        # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
+        # mesh, we need to make two API calls per ranks to create all the subgroups.
+        if (
+            getattr(default_group, "bound_device_id", None) is not None
+            and torch.cuda.is_available()
+            and (
+                backend is None
+                or default_group._get_backend(torch.device("cuda")).name()
+                == backend
+            )
+        ):
+            dim_group = split_group(
+                parent_pg=default_group,
+                timeout=timeout,
+                pg_options=pg_options,
+                split_ranks=pg_ranks_by_dim.tolist(),
+                group_desc=group_desc,
+            )
+            return dim_group.group_name  # type: ignore[union-attr]
+
+        # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
+        # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
+        # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
+        # along with appending information to the `dim_group_names` list whenever necessary.
+        pg_name = None
+        for dim_mesh in pg_ranks_by_dim:
+            subgroup_ranks = dim_mesh.tolist()
+            dim_group = new_group(
+                ranks=subgroup_ranks,
+                timeout=timeout,
+                backend=backend,
+                pg_options=pg_options,
+                group_desc=group_desc,
+            )
+
+            # only add to dim_groups if the current rank in the subgroup
+            if get_rank() in subgroup_ranks:
+                if pg_name is not None:
+                    raise RuntimeError(
+                        f"Each device mesh dimension should get only one process group, but got {get_rank()} "
+                        f"in {subgroup_ranks}!"
+                    )
+                pg_name = dim_group.group_name
+        return pg_name
+
+    @staticmethod
+    def _init_process_groups(
+        layout: _MeshLayout,
+        rank_map: torch.Tensor,
+        mesh_dim_names: Optional[tuple[str, ...]],
+        backend_override: tuple[BackendConfig, ...],
+    ) -> list[str]:
+        # group_name associated with each mesh dimension, each
+        # mesh dimension should have one sub-group per rank
+        dim_group_names: list[str] = []
+        # create sub pgs base on the mesh argument specified
+        for dim in range(len(layout)):
+            dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
+            dim_group_names.append(
+                DeviceMesh._init_one_process_group(  # type: ignore[arg-type]
+                    layout[dim], rank_map, dim_name, backend_override[dim]
                )
-
-                # If bound_device_id exists, it means the nccl communicator has been eagerly initialized
-                # so that we can use `split_group` to create subgroups through `ncclCommSplit`.
-                # In this case, we only need to make one API call (`split_group``) for the subgroup creation
-                # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create
-                # all the subgroups.
-                # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The
-                # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
-                # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups.
-                dim_group = None
-                has_split_group = False
-                if (
-                    (
-                        bound_device_id := getattr(
-                            default_group, "bound_device_id", None
-                        )
-                    )
-                    is not None
-                    and torch.cuda.is_available()
-                    and (
-                        backend is None
-                        or default_group._get_backend(torch.device("cuda")).name()
-                        == backend
-                    )
-                ):
-                    dim_group = split_group(
-                        parent_pg=default_group,
-                        timeout=timeout,
-                        pg_options=pg_options,
-                        split_ranks=pg_ranks_by_dim.tolist(),
-                        group_desc=group_desc,
-                    )
-                    has_split_group = True
-
-                # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
-                # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
-                # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
-                # along with appending information to the `dim_group_names` list whenever necessary.
-                for dim_mesh in pg_ranks_by_dim:
-                    subgroup_ranks = dim_mesh.tolist()
-
-                    # We temporarily revert the reuse subgroup, since it breaks two internal tests.
-                    # Temporarily reverting to resolve test timeout while root-causing.
-                    # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
-                    # pyrefly: ignore [unbound-name]
-                    if bound_device_id is None or not has_split_group:
-                        dim_group = new_group(
-                            ranks=subgroup_ranks,
-                            timeout=timeout,
-                            backend=backend,
-                            pg_options=pg_options,
-                            group_desc=group_desc,
-                        )
-
-                    # only add to dim_groups if the current rank in the subgroup
-                    if get_rank() in subgroup_ranks:
-                        if len(dim_group_names) > dim:
-                            raise RuntimeError(
-                                f"Each device mesh dimension should get only one process group, but got {get_rank()} "
-                                f"in {subgroup_ranks}!"
-                            )
-                        dim_group_names.append(dim_group.group_name)  # type: ignore[union-attr]
+            )
+        if any(n is None for n in dim_group_names):
+            assert all(n is None for n in dim_group_names)
+            return []
         return dim_group_names

     def _get_root_mesh(self) -> "DeviceMesh":
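A note on the `split_group` versus `new_group` counting in the code comments above: with an eagerly initialized NCCL communicator, a single `ncclCommSplit`-backed `split_group` call per mesh dimension covers all of that dimension's subgroups, while `new_group` requires one collective call per subgroup. The helper below is only a sketch of that arithmetic under the assumption stated in the comments (one `new_group` call per subgroup on every rank); `creation_calls` is a hypothetical name, not a PyTorch API.

import math


def creation_calls(mesh_shape: tuple[int, ...]) -> tuple[int, int]:
    """Return (split_group calls, new_group calls) needed to cover all subgroups."""
    world_size = math.prod(mesh_shape)
    split_calls = len(mesh_shape)  # one split_group call per mesh dimension
    # new_group: one call per subgroup; a dimension of size n has world_size // n subgroups.
    new_group_calls = sum(world_size // n for n in mesh_shape)
    return split_calls, new_group_calls


if __name__ == "__main__":
    print(creation_calls((2, 4)))  # (2, 6): 2 calls vs 4 + 2 = 6 calls for a 2 * 4 mesh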