diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0cebfaff6d6..1fdc2a13bcd 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1583,6 +1583,7 @@ def init_process_group(
     group_name: str = "",
     pg_options: Optional[Any] = None,
     device_id: Optional[Union[torch.device, int]] = None,
+    _ranks: Optional[list[int]] = None,
 ) -> None:
     """
     Initialize the default distributed process group.
@@ -1657,6 +1658,8 @@
             want to know NCCL initialization error early, you can also use this
             field. If an `int` is provided, the API assumes that the accelerator
             type at compile time will be used.
+        _ranks: The ranks in the process group. If provided, the process
+            group name will be the hash of all the ranks in the group.
 
     .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
         on a system that supports MPI.
@@ -1761,7 +1764,10 @@
     internals of c10d. This means we can ignore the value they provide as it
     not exposed in a public way.
     """
-    group_name = _process_group_name([], use_hashed_name=False)
+    if _ranks is None or len(_ranks) == 0:
+        group_name = _process_group_name([], use_hashed_name=False)
+    else:
+        group_name = _process_group_name(_ranks, use_hashed_name=True)
     if backend == Backend.MPI:
         if world_size != -1 or rank != -1:
             warnings.warn(
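
For context, a minimal single-process sketch of how the new private `_ranks` argument would be exercised, assuming this patch is applied and a gloo backend is available; the MASTER_ADDR/MASTER_PORT values and the `group_name` read at the end are illustrative only:

import os

import torch.distributed as dist

# Illustrative rendezvous settings for a single local process.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# With a non-empty `_ranks` (private argument added by this patch), the default
# group's name is derived by hashing the given ranks instead of using the
# fixed, unhashed name.
dist.init_process_group(backend="gloo", rank=0, world_size=1, _ranks=[0])

# Inspect the resulting name of the default group (for illustration).
print(dist.group.WORLD.group_name)

dist.destroy_process_group()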