mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[TensorPipe] Update documentation (#40222)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40222

Mention the TensorPipe agent in the RPC docs and give users the information they need to choose which agent to use.

ghstack-source-id: 106225711
Test Plan: Export to GitHub, build locally and try out the docs.
Differential Revision: D22116494
fbshipit-source-id: 30703ba8410c40f64e785f60d71dfd9faa8de4a1
This commit is contained in:
parent 8315bb2359
commit 2393bab036
@@ -82,10 +82,7 @@ RPC
 Before using RPC and distributed autograd primitives, initialization must take
 place. To initialize the RPC framework we need to use
 :meth:`~torch.distributed.rpc.init_rpc` which would initialize the RPC
-framework, RRef framework and distributed autograd. By default, this will also
-initialize the ``ProcessGroup`` (:meth:`~torch.distributed.init_process_group`)
-backend for RPC communication. The ``ProcessGroup`` backend internally uses gloo
-for communication.
+framework, RRef framework and distributed autograd.

 .. automodule:: torch.distributed.rpc
 .. autofunction:: init_rpc
@@ -109,9 +106,6 @@ and move it to the desired devices on the callee if necessary.
 .. autofunction:: shutdown
 .. autoclass:: WorkerInfo
     :members:
-.. autoclass:: ProcessGroupRpcBackendOptions
-    :members:
-    :inherited-members:


 The RPC package also provides decorators which allow applications to specify
@@ -122,8 +116,124 @@ how a given function should be treated on the callee side.

 .. autofunction:: torch.distributed.rpc.functions.async_execution

-.. _rref:
+.. _rpc-backends:
+
+Backends
+^^^^^^^^
+
+The RPC module can leverage different backends to perform the communication
+between the nodes. The backend to be used can be specified in the
+:func:`~torch.distributed.rpc.init_rpc` function, by passing a certain value of
+the :class:`~torch.distributed.rpc.BackendType` enum. Regardless of what backend
+is used, the rest of the RPC API won't change. Each backend also defines its own
+subclass of the :class:`~torch.distributed.rpc.RpcBackendOptions` class, an
+instance of which can also be passed to :func:`~torch.distributed.rpc.init_rpc`
+to configure the backend's behavior.
+
+.. autoclass:: BackendType
+
+.. autoclass:: RpcBackendOptions
+    :members:
+
+Process Group Backend
+"""""""""""""""""""""
+
+The Process Group agent, which is the default, instantiates a process group from
+the :mod:`~torch.distributed` module and utilizes its point-to-point
+communication capabilities to send RPC messages across. Internally, the process
+group uses `the Gloo library <https://github.com/facebookincubator/gloo/>`_.
+
+Gloo has been hardened by years of extensive use in PyTorch and is thus very
+reliable. However, as it was designed to perform collective communication, it
+may not always be the best fit for RPC. For example, each networking operation
+is synchronous and blocking, which means that it cannot be run in parallel with
+others. Moreover, it opens a connection between all pairs of nodes, and brings
+down all of them when one fails, thus reducing the resiliency and the elasticity
+of the system.
+
+Example::
+
+    >>> import os
+    >>> from torch.distributed import rpc
+    >>> os.environ['MASTER_ADDR'] = 'localhost'
+    >>> os.environ['MASTER_PORT'] = '29500'
+    >>>
+    >>> rpc.init_rpc(
+    >>>     "worker1",
+    >>>     rank=0,
+    >>>     world_size=2,
+    >>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
+    >>>         num_send_recv_threads=16,
+    >>>         rpc_timeout=20 # 20 second timeout
+    >>>     )
+    >>> )
+    >>>
+    >>> # omitting init_rpc invocation on worker2
+
+.. autoclass:: ProcessGroupRpcBackendOptions
+    :members:
+    :inherited-members:
+
+TensorPipe Backend
+""""""""""""""""""
+
+.. warning::
+    The TensorPipe backend is a **beta feature**.
+
+The TensorPipe agent leverages `the TensorPipe library
+<https://github.com/pytorch/tensorpipe>`_, which provides a natively
+point-to-point communication primitive specifically suited for machine learning
+that fundamentally addresses some of the limitations of Gloo. Compared to Gloo,
+it has the advantage of being asynchronous, which allows a large number of
+transfers to occur simultaneously, each at their own speed, without blocking
+each other. It will only open pipes between pairs of nodes when needed, on
+demand, and when one node fails only its incident pipes will be closed, while
+all other ones will keep working as normal. In addition, it is able to support
+multiple different transports (TCP, of course, but also shared memory, NVLink,
+InfiniBand, ...) and can automatically detect their availability and negotiate
+the best transport to use for each pipe.
+
+The TensorPipe backend has been introduced in PyTorch v1.6 and is being actively
+developed. At the moment, it only supports CPU tensors, with GPU support coming
+soon. It comes with a TCP-based transport, just like Gloo. It is also able to
+automatically chunk and multiplex large tensors over multiple sockets and
+threads in order to achieve very high bandwidths. In addition to that, it packs
+two Linux-specific transports for communication between processes on a same
+machine (one based on ringbuffers stored in shared memory, the other on the
+cross-memory attach syscalls) which can achieve lower latencies than TCP.
+The agent will be able to pick the best transport on its own, with no
+intervention required.
+
+Example::
+
+    >>> import os
+    >>> from torch.distributed import rpc
+    >>> os.environ['MASTER_ADDR'] = 'localhost'
+    >>> os.environ['MASTER_PORT'] = '29500'
+    >>>
+    >>> rpc.init_rpc(
+    >>>     "worker1",
+    >>>     rank=0,
+    >>>     world_size=2,
+    >>>     backend=rpc.BackendType.TENSORPIPE,
+    >>>     rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+    >>>         num_worker_threads=8,
+    >>>         rpc_timeout=20 # 20 second timeout
+    >>>     )
+    >>> )
+    >>>
+    >>> # omitting init_rpc invocation on worker2
+
+.. autoclass:: TensorPipeRpcBackendOptions
+    :members:
+    :inherited-members:
+
+.. _rref:

 RRef
 ----
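The chunk-and-multiplex idea mentioned for the TensorPipe backend can be illustrated with a toy pure-Python sketch: a large payload is split into fixed-size chunks, the chunks are spread round-robin across several "channels" (plain lists standing in for sockets), and the receiver reorders them by sequence number. All names here are invented for illustration; this is not TensorPipe's actual implementation.

```python
def send_multiplexed(payload: bytes, num_channels: int, chunk_size: int):
    """Split payload into chunks and distribute them round-robin."""
    channels = [[] for _ in range(num_channels)]
    for seq, start in enumerate(range(0, len(payload), chunk_size)):
        chunk = payload[start:start + chunk_size]
        channels[seq % num_channels].append((seq, chunk))
    return channels

def receive_multiplexed(channels) -> bytes:
    """Gather chunks from all channels and reassemble by sequence number."""
    tagged = [item for chan in channels for item in chan]
    tagged.sort(key=lambda pair: pair[0])  # restore original order
    return b"".join(chunk for _, chunk in tagged)

data = bytes(range(256)) * 10  # a 2560-byte payload
assert receive_multiplexed(send_multiplexed(data, 4, 100)) == data
```

In the real agent each channel would be its own socket driven by its own thread, so the chunks genuinely move in parallel; the reassembly step is what makes the transfer order-independent.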
@@ -397,25 +397,6 @@ PyObject* rpc_init(PyObject* /* unused */) {
             :meth:`~torch.distributed.rpc.rpc_async` if necessary.
         init_method (str, optional): The URL to initialize
             ``ProcessGroupGloo`` (default: ``env://``).
-
-
-    Example::
-        >>> import datetime, os
-        >>> from torch.distributed import rpc
-        >>> os.environ['MASTER_ADDR'] = 'localhost'
-        >>> os.environ['MASTER_PORT'] = '29500'
-        >>>
-        >>> rpc.init_rpc(
-        >>>     "worker1",
-        >>>     rank=0,
-        >>>     world_size=2,
-        >>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
-        >>>         num_send_recv_threads=16,
-        >>>         rpc_timeout=20 # 20 second timeout
-        >>>     )
-        >>> )
-        >>>
-        >>> # omitting init_rpc invocation on worker2
     )")
         .def(
             py::init<int, float, std::string>(),
@@ -473,7 +454,30 @@ PyObject* rpc_init(PyObject* /* unused */) {

     // Base class: torch.distributed.rpc.RpcBackendOptions.
     py::class_<TensorPipeRpcBackendOptions>(
-        module, "TensorPipeRpcBackendOptions", rpcBackendOptions)
+        module,
+        "TensorPipeRpcBackendOptions",
+        rpcBackendOptions,
+        R"(
+            The backend options for
+            :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
+            :class:`~torch.distributed.rpc.RpcBackendOptions`.
+
+            Arguments:
+                num_worker_threads (int, optional): The number of threads in the
+                    thread-pool used by
+                    :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
+                    requests (default: 16).
+                rpc_timeout (float, optional): The default timeout, in seconds,
+                    for RPC requests (default: 60 seconds). If the RPC has not
+                    completed in this timeframe, an exception indicating so will
+                    be raised. Callers can override this timeout for individual
+                    RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
+                    :meth:`~torch.distributed.rpc.rpc_async` if necessary.
+                init_method (str, optional): The URL to initialize the distributed
+                    store used for rendezvous. It takes any value accepted for the
+                    same argument of :meth:`~torch.distributed.init_process_group`
+                    (default: ``env://``).
+        )")
         .def(
             py::init<
                 int,
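The timeout semantics documented above (a default deadline that can be overridden per call, with an exception raised on expiry) can be sketched in plain Python using `concurrent.futures` as a stand-in for the RPC layer. The function name and default value below are invented for illustration; this is not `torch.distributed.rpc` code.

```python
import concurrent.futures
import time

DEFAULT_TIMEOUT = 0.5  # stand-in for the documented 60-second default

_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

def rpc_sync_sketch(fn, *args, timeout=DEFAULT_TIMEOUT):
    """Run fn "remotely" and block on the result; raise if the deadline expires."""
    return _pool.submit(fn, *args).result(timeout=timeout)

# A fast call completes well within the default deadline.
assert rpc_sync_sketch(lambda: 42) == 42

# A slow call with a per-call override raises once the deadline passes.
try:
    rpc_sync_sketch(time.sleep, 1.0, timeout=0.1)
    timed_out = False
except concurrent.futures.TimeoutError:
    timed_out = True
assert timed_out
```

Note that, as in the real agent, the timeout bounds how long the caller waits, not how long the callee runs: the slow task above keeps running in its worker thread after the exception is raised.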
@@ -487,7 +491,13 @@ PyObject* rpc_init(PyObject* /* unused */) {
             py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
             py::arg("init_method") = kDefaultInitMethod)
         .def_readwrite(
-            "num_worker_threads", &TensorPipeRpcBackendOptions::numWorkerThreads);
+            "num_worker_threads",
+            &TensorPipeRpcBackendOptions::numWorkerThreads,
+            R"(
+                The number of threads in the thread-pool used by
+                :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
+                requests.
+            )");

     module.attr("_DEFAULT_NUM_WORKER_THREADS") =
         py::cast(kDefaultNumWorkerThreads);
@@ -18,6 +18,7 @@ if is_available() and not torch._C._rpc_init():
 if is_available():
     from . import api, backend_registry, functions, _set_profiler_node_id
     from .api import * # noqa: F401
+    from .backend_registry import BackendType
     from .server_process_global_profiler import (
         _server_process_global_profile,
     )
@@ -25,7 +26,7 @@ if is_available():

     def init_rpc(
         name,
-        backend=backend_registry.BackendType.PROCESS_GROUP,
+        backend=BackendType.PROCESS_GROUP,
         rank=-1,
         world_size=None,
         rpc_backend_options=None,
@@ -38,27 +39,28 @@ if is_available():
        process ready to send and receive RPCs.

        Arguments:
-            backend (Enum): type of RPC backend implementation. Currently,
-                process group backend is the only available backend
-                implementation. (default: ``RpcBackend.PROCESS_GROUP``).
+            backend (BackendType, optional): The type of RPC backend
+                implementation. Supported values include
+                ``BackendType.PROCESS_GROUP`` (the default) and
+                ``BackendType.TENSORPIPE``. See :ref:`rpc-backends` for more
+                information.
            name (str): a globally unique name of this node. (e.g.,
                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                Name can only contain number, alphabet, underscore, and/or dash,
                and must be shorter than 128 characters.
            rank (int): a globally unique id/rank of this node.
            world_size (int): The number of workers in the group.
-            rpc_backend_options (RpcBackendOptions): The options passed to
-                RpcAgent constructor. It contains RpcAgent specific
-                initialization configurations. By default, it contains
-                ``rpc_timeout = timedelta(seconds=60)``,
-                ``init_method = "env://"``, ``num_send_recv_threads = 4`` for
-                process group agent. If using the default
-                ``rpc_backend_options``, RPC would initialize the underlying
-                process group backend using ``init_method = "env://"``,
+            rpc_backend_options (RpcBackendOptions, optional): The options
+                passed to the RpcAgent constructor. It must be an agent-specific
+                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
+                and contains agent-specific initialization configurations. By
+                default, for all agents, it sets the default timeout to 60
+                seconds and performs the rendezvous with an underlying process
+                group initialized using ``init_method = "env://"``,
                meaning that environment variables ``MASTER_ADDR`` and
                ``MASTER_PORT`` needs to be set properly. See
-                :class:`~torch.distributed.rpc.ProcessGroupRpcBackendOptions`
-                for examples.
+                :ref:`rpc-backends` for more information and find which options
+                are available.
        """

        if not rpc_backend_options:
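The docstring above describes a small class hierarchy: each agent accepts an agent-specific subclass of a common options base that carries the shared `rpc_timeout` and `init_method` fields. A hedged pure-Python sketch of that shape, using dataclasses with the documented defaults (the `*Sketch` class names are invented stand-ins, not the real `torch.distributed.rpc` types):

```python
from dataclasses import dataclass

@dataclass
class RpcBackendOptionsSketch:
    """Stand-in for the common options base shared by all agents."""
    rpc_timeout: float = 60.0    # documented default: 60 seconds
    init_method: str = "env://"  # documented default rendezvous URL

@dataclass
class TensorPipeOptionsSketch(RpcBackendOptionsSketch):
    """Stand-in for an agent-specific subclass adding its own knobs."""
    num_worker_threads: int = 16  # documented default thread-pool size

# Overriding only some fields, as in the init_rpc examples.
opts = TensorPipeOptionsSketch(num_worker_threads=8, rpc_timeout=20)
assert isinstance(opts, RpcBackendOptionsSketch)
assert (opts.num_worker_threads, opts.rpc_timeout, opts.init_method) == (8, 20, "env://")
```

Keeping the shared fields in the base is what lets `init_rpc` treat any backend's options uniformly while each agent still exposes its own settings.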
@@ -18,9 +18,18 @@ def _backend_type_repr(self):
     return "BackendType." + self.name


+_backend_type_doc = """
+    An enum class of available backends.
+
+    PyTorch ships with two builtin backends: ``BackendType.PROCESS_GROUP`` and
+    ``BackendType.TENSORPIPE``. Additional ones can be registered using the
+    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
+"""
+
 # Create an enum type, `BackendType`, with empty members.
 BackendType = enum.Enum(value="BackendType", names={})
 BackendType.__repr__ = _backend_type_repr
+BackendType.__doc__ = _backend_type_doc

 def backend_registered(backend_name):
     """
@@ -65,6 +74,7 @@ def register_backend(
     )
     BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)
     BackendType.__repr__ = _backend_type_repr
+    BackendType.__doc__ = _backend_type_doc
     return BackendType[backend_name]

