Fix docstring errors (#112693)

This PR reduces the docstring errors from a total of 128 to 0. This can be verified by running `pydocstyle path-to-distributed_c10d.py --count`,

where `path-to-distributed_c10d.py` is `torch/distributed/distributed_c10d.py`.

BEFORE the PR:
`pydocstyle torch/distributed/distributed_c10d.py --count`
128

AFTER the PR:
`pydocstyle torch/distributed/distributed_c10d.py --count`
0
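
For context, nearly all of the fixes in the diff below follow the same pattern that pydocstyle (which checks PEP 257 conventions) enforces: a one-line summary in the imperative mood, ending with a period, with a blank line before any further description. Below is a minimal before/after sketch, modeled on the `_get_group_size` change in this patch; the function bodies and the rule codes in the comments are illustrative standard pydocstyle checks, not output captured from this run.

```python
def _get_group_size_before(group):
    """
    Helper that gets a given group's world size
    """
    # Flagged by pydocstyle, e.g.: D200 (one-line docstring should fit on one
    # line with the quotes), D401 (first line should be in imperative mood),
    # D400 (first line should end with a period).
    ...


def _get_group_size_after(group):
    """Get a given group's world size."""
    # Passes: single-line summary, imperative mood, trailing period.
    ...
```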

Fixes #112640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112693
Approved by: https://github.com/H-Huang
Authored by Sahdev Zala on 2023-11-06 18:45:00 +00:00; committed by PyTorch MergeBot
parent 5248bc9c8e
commit c6ecd018d5


@ -1,3 +1,5 @@
"""Distributed Collective Communication (c10d)."""
import itertools
import collections.abc
import contextlib
@ -134,6 +136,7 @@ PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
# We'd like calls to unsupported ops to error out accordingly,
# rather than returning garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
"""Return true if reduce ops is supported. False otherwise."""
denyList = [
ReduceOp.MAX,
ReduceOp.MIN,
@ -147,8 +150,9 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
class Backend:
"""
An enum-like class of available backends: GLOO, NCCL, UCC, MPI, and other registered
backends.
An enum-like class for backends.
Available backends: GLOO, NCCL, UCC, MPI, and other registered backends.
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
be accessed as attributes, e.g., ``Backend.NCCL``.
@ -195,6 +199,7 @@ class Backend:
}
def __new__(cls, name: str):
"""Create and return a new instance of the class."""
if not isinstance(name, str):
raise ValueError(f"Backend name must be a string, but got: {name}")
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
@ -206,7 +211,7 @@ class Backend:
@classmethod
def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None):
"""
Registers a new backend with the given name and instantiating function.
Register a new backend with the given name and instantiating function.
This class method is used by 3rd party ``ProcessGroup`` extension to
register new backends.
@ -266,8 +271,10 @@ class Backend:
Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)
class BackendConfig:
"""Backend configuration class."""
def __init__(self, backend: Union[str, Backend]):
"""Init."""
self.device_backend_map: Dict[torch.device, Backend] = {}
if backend == Backend.UNDEFINED:
@ -325,16 +332,18 @@ class BackendConfig:
)
def __repr__(self):
# string with all the device:backend pairs separated by commas
"""Return all the device:backend pairs separated by commas."""
return ",".join(f"{device}:{backend}" for device, backend in self.device_backend_map.items())
def get_device_backend_map(self):
"""Return backend map of the device."""
return self.device_backend_map
class _reduce_op:
r"""
Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, and ``MAX``.
Deprecated enum-like class.
For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``.
:class:`~torch.distributed.ReduceOp` is recommended to use instead.
"""
@ -377,6 +386,7 @@ class P2POp:
def __init__(self, op: Callable, tensor: torch.Tensor, peer: int,
group: Optional[ProcessGroup] = None, tag: int = 0):
"""Init."""
self.op = op
self.tensor = tensor
self.peer = peer
@ -385,6 +395,7 @@ class P2POp:
def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int,
group: Optional[ProcessGroup] = None, tag: int = 0):
"""Create and return a new instance of the class."""
_check_op(op)
_check_single_tensor(tensor, "tensor")
return object.__new__(cls)
@ -426,11 +437,13 @@ _pg_to_tag: Dict[ProcessGroup, str] = {}
class _World:
"""
Container class for c10d process group state.
This is used during registration and lookup of PG state.
.. warning:: This is an experimental API intended to expose the inner workings
of c10d and is subject to change..
"""
def __init__(self):
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
@ -439,8 +452,10 @@ class _World:
@property
def default_pg(self):
"""
The default ProcessGroup includes all ranks of the cluster.
This is used by c10d APIs when a ProcessGroup is needed but None is provided.
Process group that includes all ranks of the cluster.
This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed
but None is provided.
"""
return self._default_pg
@ -451,7 +466,8 @@ class _World:
@property
def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Optional[Store]]]:
"""
Cached process groups
Provide Mapping from ProcessGroup to backend name and store.
For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
For MPI pg, it is a map from ProcessGroup to (Backend, None)
@ -473,7 +489,8 @@ class _World:
@property
def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
"""
Process group's global rank to local rank mapping
Process group's global rank to local rank mapping.
TODO don't expose the map, expose fine grained ops
"""
global _pg_group_ranks
@ -482,7 +499,8 @@ class _World:
@property
def pg_backend_config(self) -> Dict[ProcessGroup, str]:
"""
Process group's backend config
Process group's backend config.
TODO don't expose the map, expose fine grained ops
"""
global _pg_backend_config
@ -500,9 +518,7 @@ class _World:
@group_count.setter
def group_count(self, value):
"""
Count is used when computing the name of ProcessGroups when using global synchronization.
"""
"""Use to compute the name of ProcessGroups when using global synchronization."""
global _group_count
_group_count = value
@ -527,8 +543,9 @@ class _World:
@property
def pg_config_info(self) -> List[Dict[str, Union[int, str]]]:
"""
Returns a list of dict with process groups and backends with their unique IDs
and configurations (types and ranks).
Return a list of dict with process groups and backends.
Along with their unique IDs and configurations (types and ranks).
"""
config_info = []
default_pg_size = _get_group_size(None)
@ -556,9 +573,11 @@ _world = _World()
class _WorldMeta(type):
"""
Meta class of ``group`` and ``GroupMember`` so they
can have the class property ``WORLD``.
Meta class of ``group`` and ``GroupMember``.
Allows them to have the class property ``WORLD``.
"""
# Points to the default PG once initialized.
@property
def WORLD(cls) -> Optional[ProcessGroup]:
@ -569,9 +588,13 @@ class _WorldMeta(type):
_world.default_pg = pg
class group(metaclass=_WorldMeta):
"""Group class. Placeholder."""
pass
class GroupMember(metaclass=_WorldMeta):
"""Group member class."""
NON_GROUP_MEMBER = -100
@ -582,7 +605,8 @@ STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device:
"""
Returns the device to use with ``group`` for control flow usage (object collectives, barrier).
Return the device to use with ``group`` for control flow usage (object collectives, barrier).
There are selection rules:
1. If user specifies exactly one backend in ``init_process_group`` call:
use that backend
@ -650,6 +674,8 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
@_time_logger
def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, logging_interval=timedelta(seconds=10)):
"""
Store based barrier for synchronizing processes.
Barrier based on store which is used for synchronizing processes after
``init_process_group`` or ``new_group``. Intended to be used only with
those two methods and is not a generic alternative to ``barrier()``.
@ -700,9 +726,7 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log
def _rank_not_in_group(group: ProcessGroup):
"""
Helper that checks if the current process's rank is not in a given group.
"""
"""Check if the current process's rank is not in a given group."""
if group is None:
return False
return group == GroupMember.NON_GROUP_MEMBER
@ -767,9 +791,7 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
# TODO: remove this once the ecosystem moves away from it.
def _get_global_rank(group, rank):
"""
This method is deprecated, please use get_global_rank.
"""
"""Use get_global_rank as this method is deprecated."""
warnings.warn(
"torch.distributed.distributed_c10d._get_global_rank is deprecated "
"please use torch.distributed.distributed_c10d.get_global_rank instead"
@ -790,9 +812,7 @@ def get_process_group_ranks(group: ProcessGroup):
return list(_world.pg_group_ranks[group].keys())
def _get_group_size(group):
"""
Helper that gets a given group's world size.
"""
"""Get a given group's world size."""
if group is GroupMember.WORLD or group is None:
default_pg = _get_default_group()
return default_pg.size()
@ -800,9 +820,7 @@ def _get_group_size(group):
def _check_single_tensor(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a single tensor.
"""
"""Check that the parameter ``param_name`` is a single tensor."""
if not isinstance(param, torch.Tensor):
raise TypeError(
f"Invalid function argument. Expected parameter `{param_name}` to be of type torch.Tensor."
@ -810,9 +828,7 @@ def _check_single_tensor(param, param_name):
def _check_tensor_list(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a list of tensors.
"""
"""Check that the parameter ``param_name`` is a list of tensors."""
if not isinstance(param, list) or not all(
isinstance(p, torch.Tensor) for p in param
):
@ -842,9 +858,7 @@ def _ensure_all_tensors_same_dtype(*tensors) -> None:
def _check_op(op):
"""
Helper to check that the ``op`` is either isend or irecv.
"""
"""Check that the ``op`` is either isend or irecv."""
if op not in [isend, irecv]:
raise ValueError(
"Invalid ``op``. Expected ``op`` "
@ -855,8 +869,9 @@ def _check_op(op):
def _check_p2p_op_list(p2p_op_list):
"""
Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
all ops use the same group.
Check that the ``p2p_op_list`` is a list of P2POp instances.
Also, check that all ops use the same group.
"""
if not isinstance(p2p_op_list, list) or not all(
isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
@ -872,35 +887,29 @@ def _check_p2p_op_list(p2p_op_list):
def is_mpi_available() -> bool:
"""
Checks if the MPI backend is available.
"""
"""Check if the MPI backend is available."""
return _MPI_AVAILABLE
def is_nccl_available() -> bool:
"""
Checks if the NCCL backend is available.
"""
"""Check if the NCCL backend is available."""
return _NCCL_AVAILABLE
def is_gloo_available() -> bool:
"""
Checks if the Gloo backend is available.
"""
"""Check if the Gloo backend is available."""
return _GLOO_AVAILABLE
def is_ucc_available() -> bool:
"""
Checks if the UCC backend is available.
"""
"""Check if the UCC backend is available."""
return _UCC_AVAILABLE
def is_backend_available(backend: str) -> bool:
"""
Check backend availability.
Checks if the given backend is available and supports the built-in backends or
third-party backends through function ``Backend.register_backend``.
@ -918,16 +927,15 @@ def is_backend_available(backend: str) -> bool:
def is_initialized() -> bool:
"""
Checking if the default process group has been initialized
"""
"""Check if the default process group has been initialized."""
return GroupMember.WORLD is not None
def is_torchelastic_launched() -> bool:
"""
Checks whether this process was launched with ``torch.distributed.elastic``
(aka torchelastic). The existence of ``TORCHELASTIC_RUN_ID`` environment
Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic).
The existence of ``TORCHELASTIC_RUN_ID`` environment
variable is used as a proxy to determine whether the current process
was launched with torchelastic. This is a reasonable proxy since
``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a
@ -945,9 +953,7 @@ def _is_barrier_after_init() -> int:
def _get_default_group():
"""
Getting the default process group created by init_process_group
"""
"""Get the default process group created by init_process_group."""
if not is_initialized():
raise ValueError(
"Default process group has not been initialized, "
@ -957,9 +963,7 @@ def _get_default_group():
def _get_default_store():
"""
Getting the default store created by init_process_group
"""
"""Get the default store created by init_process_group."""
if not is_initialized():
raise ValueError(
"Default process group has not been initialized, "
@ -976,6 +980,18 @@ def _update_default_pg(pg):
torch._C._distributed_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
"""
Return the backend configuration of the given process group.
Args:
group (ProcessGroup, optional): The process group to work on. The
default is the general main process group. If another specific group
is specified, the calling process must be part of :attr:`group`.
Returns:
The backend configuration of the given process group as a lower case string.
"""
if group is None:
pg = _get_default_group()
else:
@ -988,7 +1004,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
def get_backend(group: Optional[ProcessGroup] = None) -> str:
"""
Returns the backend of the given process group.
Return the backend of the given process group.
Args:
group (ProcessGroup, optional): The process group to work on. The
@ -1023,8 +1039,9 @@ def init_process_group(
pg_options: Optional[Any] = None,
):
"""
Initializes the default distributed process group, and this will also
initialize the distributed package.
Initialize the default distributed process group.
This will also initialize the distributed package.
There are 2 main ways to initialize a process group:
1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
@ -1389,7 +1406,7 @@ def _new_process_group_helper(
def destroy_process_group(group: Optional[ProcessGroup] = None):
"""
Destroy a given process group, and deinitialize the distributed package
Destroy a given process group, and deinitialize the distributed package.
Args:
group (ProcessGroup, optional): The process group to be destroyed, if
@ -1469,8 +1486,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
def get_rank(group: Optional[ProcessGroup] = None) -> int:
"""
Returns the rank of the current process in the provided ``group`` or the
default group if none was provided.
Return the rank of the current process in the provided ``group``, default otherwise.
Rank is a unique identifier assigned to each process within a distributed
process group. They are always consecutive integers ranging from 0 to
@ -1497,7 +1513,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int:
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
"""
Returns the number of processes in the current process group
Return the number of processes in the current process group.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
@ -1516,7 +1532,7 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:
def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> Work:
"""
Sends a tensor asynchronously.
Send a tensor asynchronously.
.. warning::
Modifying ``tensor`` before the request completes causes undefined
@ -1592,7 +1608,7 @@ def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proce
@_exception_logger
def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> None:
"""
Sends a tensor synchronously.
Send a tensor synchronously.
Args:
tensor (Tensor): Tensor to send.
@ -1692,7 +1708,7 @@ def _coalescing_manager(
async_ops: Optional[bool] = False,
):
"""
A context manager used to coalesce collectives or P2P operations when possible.
Context manager used to coalesce collectives or P2P operations when possible.
Args:
group (`ProcessGroup`, optional): The process group to work on. If None,
@ -1836,8 +1852,7 @@ def batch_isend_irecv(p2p_op_list):
@_exception_logger
def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0):
"""
Broadcasts the tensor to the whole group with multiple GPU tensors
per node.
Broadcasts the tensor to the whole group with multiple GPU tensors per node.
``tensor`` must have the same number of elements in all the GPUs from
all processes participating in the collective. each tensor in the list must
@ -1938,8 +1953,9 @@ def broadcast(tensor, src, group=None, async_op=False):
@_exception_logger
def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, group=None, async_op=False):
r"""
Reduces the tensor data across all machines in such a way that all get
the final result. This function reduces a number of tensors on every node,
Reduces the tensor data across all machines in a way that all get the final result.
This function reduces a number of tensors on every node,
while each tensor resides on different GPUs.
Therefore, the input tensor in the tensor list needs to be GPU tensors.
Also, each tensor in the tensor list needs to reside on a different GPU.
@ -1999,8 +2015,7 @@ def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, group=None, async_op=False
@_exception_logger
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
"""
Reduces the tensor data across all machines in such a way that all get
the final result.
Reduces the tensor data across all machines in a way that all get the final result.
After the call ``tensor`` is going to be bitwise identical in all processes.
@ -2080,6 +2095,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
"""
WARNING: at this time individual shape checking is not implemented across nodes.
For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce
operation will proceed without complaint and return erroneous outputs. This lack
@ -2144,8 +2160,9 @@ def reduce_multigpu(
tensor_list, dst, op=ReduceOp.SUM, group=None, async_op=False, dst_tensor=0
):
"""
Reduces the tensor data on multiple GPUs across all machines. Each tensor
in ``tensor_list`` should reside on a separate GPU
Reduces the tensor data on multiple GPUs across all machines.
Each tensor in ``tensor_list`` should reside on a separate GPU.
Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst``
is going to receive the final result.
@ -2252,6 +2269,7 @@ def all_gather_multigpu(
):
"""
Gathers tensors from the whole group in a list.
Each tensor in ``tensor_list`` should reside on a separate GPU
Only nccl backend is currently supported
@ -2343,9 +2361,10 @@ def _tensor_to_object(tensor, tensor_size):
@_exception_logger
def all_gather_object(object_list, obj, group=None):
"""
Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.
Gathers picklable objects from the whole group into a list.
Similar to :func:`all_gather`, but Python objects can be passed in.
Note that the object must be picklable in order to be gathered.
Args:
object_list (list[Any]): Output list. It should be correctly sized as the
@ -2436,6 +2455,7 @@ def all_gather_object(object_list, obj, group=None):
def gather_object(obj, object_gather_list=None, dst=0, group=None):
"""
Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.
@ -2545,8 +2565,9 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
@_exception_logger
def broadcast_object_list(object_list, src=0, group=None, device=None):
"""
Broadcasts picklable objects in ``object_list`` to the whole group. Similar
to :func:`broadcast`, but Python objects can be passed in.
Broadcasts picklable objects in ``object_list`` to the whole group.
Similar to :func:`broadcast`, but Python objects can be passed in.
Note that all objects in ``object_list`` must be picklable in order to be
broadcasted.
@ -2657,8 +2678,9 @@ def scatter_object_list(
scatter_object_output_list, scatter_object_input_list, src=0, group=None
):
"""
Scatters picklable objects in ``scatter_object_input_list`` to the whole
group. Similar to :func:`scatter`, but Python objects can be passed in. On
Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
Similar to :func:`scatter`, but Python objects can be passed in. On
each rank, the scattered object will be stored as the first element of
``scatter_object_output_list``. Note that all objects in
``scatter_object_input_list`` must be picklable in order to be scattered.
@ -3213,8 +3235,9 @@ def reduce_scatter_multigpu(
output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, group=None, async_op=False
):
"""
Reduce and scatter a list of tensors to the whole group. Only nccl backend
is currently supported.
Reduce and scatter a list of tensors to the whole group.
Only nccl backend is currently supported.
Each tensor in ``output_tensor_list`` should reside on a separate GPU, as
should each list of tensors in ``input_tensor_lists``.
@ -3444,9 +3467,10 @@ def all_to_all_single(
async_op=False,
):
"""
Each process splits input tensor and then scatters the split list
to all processes in a group. Then concatenate the received tensors from all
the processes in the group and return single output tensor.
Split input tensor and then scatter the split list to all processes in a group.
Later the received tensors are concatenated from all the processes in the group
and returned as a single output tensor.
Complex tensors are supported.
@ -3568,8 +3592,7 @@ def all_to_all_single(
@_exception_logger
def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
"""
Each process scatters list of input tensors to all processes in a group and
return gathered list of tensors in output list.
Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.
Complex tensors are supported.
@ -3686,9 +3709,8 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
@_exception_logger
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
"""
Synchronizes all processes.
Synchronize all processes.
This collective blocks processes until the whole group enters this function,
if async_op is False, or if async work handle is called on wait().
@ -3731,14 +3753,13 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
"""
Synchronizes all processes similar to ``torch.distributed.barrier``, but takes
a configurable timeout and is able to report ranks that did not pass this
barrier within that timeout. Specifically, for non-zero ranks, will block
until a send/recv is processed from rank 0. Rank 0 will block until all send
/recv from other ranks are processed, and will report failures for ranks
that failed to respond in time. Note that if one rank does not reach the
monitored_barrier (for example due to a hang), all other ranks would fail
in monitored_barrier.
Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.
It is able to report ranks that did not pass this barrier within the provided timeout.
Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0.
Rank 0 will block until all send /recv from other ranks are processed, and will report
failures for ranks that failed to respond in time. Note that if one rank does not reach the
monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.
This collective will block all processes/ranks in the group, until the
whole group exits the function successfully, making it useful for debugging
@ -3778,7 +3799,6 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals
>>> # indicating that ranks 1, 2, ... world_size - 1 did not call into
>>> # monitored_barrier.
"""
# Need to call rank not in group before using the group, otherwise
# "Invalid process group" error is raised.
if _rank_not_in_group(group):
@ -3834,7 +3854,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
@_time_logger
def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None, use_local_synchronization=False):
"""
Creates a new distributed group.
Create a new distributed group.
This function requires that all processes in the main group (i.e. all
processes that are part of the distributed job) enter this function, even
@ -3915,7 +3935,7 @@ def _new_group_with_tag(
use_local_synchronization=False
):
"""
This is a variant of ``new_group`` that exposes tag creation.
Variant of ``new_group`` that exposes tag creation.
:: N.B. The mechanism is experimental and tied to the functional collectives effort, see
``torch.distributed._functional_collectives`` for reference on how to use it.
@ -4018,7 +4038,9 @@ def new_subgroups(
pg_options=None,
):
"""
Creates subgroups of equal size. By default, it creates intra-machine subgroups,
Create subgroups of equal size.
By default, it creates intra-machine subgroups,
where each of which contains all the ranks of a machine, based on the assumption
that each machine has the same number of devices.
@ -4149,9 +4171,10 @@ def new_subgroups_by_enumeration(
pg_options=None,
):
"""
Creates subgroups by dividing the global world, where the division is specified by
a nested list of ranks. The subgroups cannot have overlap, and some ranks may not have
to be in any subgroup.
Create subgroups by dividing the global world.
The division is specified by a nested list of ranks. The subgroups cannot have
overlap, and some ranks may not have to be in any subgroup.
This is a convenience API that calls ``new_group`` to generate multiple subgroups.
It requires that all processes in the main group (i.e. all
@ -4288,9 +4311,7 @@ def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int)
return _new_group_with_tag(my_ranks, pg_tag=tag)
def _get_group_tag(pg: ProcessGroup) -> str:
"""
Returns the tag associated with ``pg``.
"""
"""Return the tag associated with ``pg``."""
tag = _world.pg_to_tag[pg]
if tag.startswith("user:"):
tag = tag[5:]