mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40222 Mention the TensorPipe agent in the RPC docs and give users the information they need to choose which agent to use. ghstack-source-id: 106225711 Test Plan: Export to GitHub, build locally and try out the docs. Differential Revision: D22116494 fbshipit-source-id: 30703ba8410c40f64e785f60d71dfd9faa8de4a1
102 lines
4.1 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import numbers
|
|
import sys
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
|
|
def is_available():
    """Report whether this PyTorch build ships the RPC C extension.

    The RPC API can only be used when the ``_rpc_init`` hook was compiled
    into ``torch._C``; everything below is gated on this check.
    """
    has_rpc_extension = hasattr(torch._C, "_rpc_init")
    return has_rpc_extension
|
|
|
|
|
|
# Eagerly initialize the RPC C++ layer at import time so a broken build
# fails loudly here instead of on the first RPC call.
if is_available() and not torch._C._rpc_init():
    raise RuntimeError("Failed to initialize torch.distributed.rpc")
|
|
|
|
|
|
if is_available():
    # Only expose the Python-side RPC API when the C extension is present;
    # otherwise these relative imports would fail at module import time.
    from . import api, backend_registry, functions, _set_profiler_node_id
    from .api import *  # noqa: F401
    from .backend_registry import BackendType
    from .server_process_global_profiler import (
        _server_process_global_profile,
    )
    import torch.distributed.autograd as dist_autograd
|
|
|
|
def init_rpc(
|
|
name,
|
|
backend=BackendType.PROCESS_GROUP,
|
|
rank=-1,
|
|
world_size=None,
|
|
rpc_backend_options=None,
|
|
):
|
|
r"""
|
|
Initializes RPC primitives such as the local RPC agent
|
|
and distributed autograd.
|
|
|
|
Initializes the local RPC agent which immediately makes the current
|
|
process ready to send and receive RPCs.
|
|
|
|
Arguments:
|
|
backend (BackendType, optional): The type of RPC backend
|
|
implementation. Supported values include
|
|
``BackendType.PROCESS_GROUP`` (the default) and
|
|
``BackendType.TENSORPIPE``. See :ref:`rpc-backends` for more
|
|
information.
|
|
name (str): a globally unique name of this node. (e.g.,
|
|
``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
|
|
Name can only contain number, alphabet, underscore, and/or dash,
|
|
and must be shorter than 128 characters.
|
|
rank (int): a globally unique id/rank of this node.
|
|
world_size (int): The number of workers in the group.
|
|
rpc_backend_options (RpcBackendOptions, optional): The options
|
|
passed to the RpcAgent constructor. It must be an agent-specific
|
|
subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
|
|
and contains agent-specific initialization configurations. By
|
|
default, for all agents, it sets the default timeout to 60
|
|
seconds and performs the rendezvous with an underlying process
|
|
group initialized using ``init_method = "env://"``,
|
|
meaning that environment variables ``MASTER_ADDR`` and
|
|
``MASTER_PORT`` needs to be set properly. See
|
|
:ref:`rpc-backends` for more information and find which options
|
|
are available.
|
|
"""
|
|
|
|
if not rpc_backend_options:
|
|
# default construct a set of RPC backend options.
|
|
rpc_backend_options = backend_registry.construct_rpc_backend_options(
|
|
backend
|
|
)
|
|
|
|
# Rendezvous.
|
|
# This rendezvous state sometimes is destroyed before all processes
|
|
# finishing handshaking. To avoid that issue, we make it global to
|
|
# keep it alive.
|
|
global rendezvous_iterator
|
|
rendezvous_iterator = torch.distributed.rendezvous(
|
|
rpc_backend_options.init_method, rank=rank, world_size=world_size
|
|
)
|
|
store, _, _ = next(rendezvous_iterator)
|
|
|
|
# Initialize autograd before RPC since _init_rpc_backend guarantees all
|
|
# processes sync via the store. If we initialize autograd after RPC,
|
|
# there could be a race where some nodes might have initialized autograd
|
|
# and others might not have. As a result, a node calling
|
|
# torch.distributed.autograd.backward() would run into errors since
|
|
# other nodes might not have been initialized.
|
|
dist_autograd._init(rank)
|
|
|
|
_set_profiler_node_id(rank)
|
|
# Initialize RPC.
|
|
api._init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
|
|
|
|
|
|
@api._require_initialized
|
|
def _get_debug_info():
|
|
from . import _rref_context_get_debug_info
|
|
info = _rref_context_get_debug_info()
|
|
info.update(api._get_current_rpc_agent().get_debug_info())
|
|
info.update(dist_autograd._get_debug_info())
|
|
return info
|