mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
set pg name based on ranks (#166182)
Summary: - in torchft we have multiple default PGs, one for each task group - for flight recorder to work, each of these needs to have a different name, so entries can be matched - change the `init_process_group` API to optionally take a list of ranks. If provided, we use the hash of the ranks as the name of the PG. For torchft, we'll pass global ranks here so the default PG has a different name on each task group Pull Request resolved: https://github.com/pytorch/pytorch/pull/166182 Approved by: https://github.com/fduwjj
This commit is contained in:
parent
d1a6e006e0
commit
fc540cefd4
|
|
@ -1583,6 +1583,7 @@ def init_process_group(
|
|||
group_name: str = "",
|
||||
pg_options: Optional[Any] = None,
|
||||
device_id: Optional[Union[torch.device, int]] = None,
|
||||
_ranks: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the default distributed process group.
|
||||
|
|
@ -1657,6 +1658,8 @@ def init_process_group(
|
|||
want to know NCCL initialization error early, you can also use this
|
||||
field. If an `int` is provided, the API assumes that the accelerator
|
||||
type at compile time will be used.
|
||||
_ranks: The ranks in the process group. If provided, the process
|
||||
group name will be the hash of all the ranks in the group.
|
||||
|
||||
.. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
|
||||
on a system that supports MPI.
|
||||
|
|
@ -1761,7 +1764,10 @@ def init_process_group(
|
|||
internals of c10d. This means we can ignore the value
|
||||
they provide as it not exposed in a public way.
|
||||
"""
|
||||
group_name = _process_group_name([], use_hashed_name=False)
|
||||
if _ranks is None or len(_ranks) == 0:
|
||||
group_name = _process_group_name([], use_hashed_name=False)
|
||||
else:
|
||||
group_name = _process_group_name(_ranks, use_hashed_name=True)
|
||||
if backend == Backend.MPI:
|
||||
if world_size != -1 or rank != -1:
|
||||
warnings.warn(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user