mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation
* make sure the RpcAgent API is sufficient for future ThriftAgent and
  TensorPipeAgent implementations
  * For the TensorPipe implementation, it might allocate multiple underlying
    communication channels with different types, and might also use streaming
    serialization/deserialization for large tensors. To support this
    requirement, the current implementation only converts a BuiltinOp into a
    Message which contains a byte vector and a tensor table. It is up to the
    RpcAgent implementation to determine how it would like to serialize a
    Message object.
  * For ThriftAgent, as Thrift has its own request/response matching solution,
    the Message.id is no longer necessary. Hence the id can be dropped during
    serialization. All it needs to do is to pass the response Message object to
    the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueueing the `(from, request, RpcAgent&)`
    tuple and using a different thread to process them. That is why there is an
    `RpcAgent&` arg in the param list.

We are not exporting this diff until we finalize the distributed autograd
design and publish the API review publicly. https://fb.quip.com/FabTAZKVgQpf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23228
ghstack-source-id: 87816717

Reviewed By: zhaojuanmao

Differential Revision: D15194693

fbshipit-source-id: 7adb600796613cde6073db6c227451b89940ecaf
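As a rough illustration of the non-blocking callback described above, here is a minimal Python sketch of the enqueue-and-process pattern. The `handle_request` helper and the `agent.send(from_worker, response)` signature are hypothetical stand-ins, not part of this diff:

```python
import queue
import threading

# Hypothetical sketch: the RPC receive thread only enqueues
# (from_worker, request, agent) and returns immediately; a separate worker
# thread dequeues each tuple, executes the request, and sends the response
# through the agent reference captured in the tuple. This is why the
# callback carries an RpcAgent& argument in its param list.
_requests = queue.Queue()

def handle_request(request):
    # hypothetical: execute the builtin op contained in the request
    return request

def non_blocking_callback(from_worker, request, agent):
    # return before the response is sent out
    _requests.put((from_worker, request, agent))

def _process_loop():
    while True:
        from_worker, request, agent = _requests.get()
        response = handle_request(request)  # hypothetical op execution
        agent.send(from_worker, response)   # respond from the worker thread

threading.Thread(target=_process_loop, daemon=True).start()
```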
175 lines
6.2 KiB
Python
from . import invoke_rpc
from . import ProcessGroupAgent

import array
import sys

import torch


_agent = None


def _collect_worker_names(name, group):
    from . import all_gather
    from . import get_world_size

    # Names have different lengths, so gather the lengths first, pad every
    # name buffer to the max length, and slice the padding off on decode.

    # collect name length
    ws = get_world_size(group)
    name_bytes = name if sys.version_info < (3, 0) else bytes(name, 'utf8')
    name_bytes = list(array.array('B', name_bytes))
    name_len = len(name_bytes)
    len_input = torch.ones(1, dtype=torch.int64) * name_len
    len_outputs = [torch.empty(1, dtype=torch.int64) for _ in range(ws)]
    all_gather(len_outputs, len_input, group=group)

    # collect name value
    max_len = torch.stack(len_outputs).max().item()
    name_input = torch.empty(max_len, dtype=torch.uint8)
    name_input[:name_len] = torch.tensor(name_bytes, dtype=torch.uint8)
    name_outputs = [torch.empty(max_len, dtype=torch.uint8) for _ in range(ws)]
    all_gather(name_outputs, name_input, group=group)

    names = []
    for i in range(ws):
        name_tensor = name_outputs[i][:len_outputs[i]]
        names.append(bytearray(name_tensor.tolist()).decode('utf8'))

    return names


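# Illustration (not part of the original file): with world_size=2 and names
# "w0" and "worker1", the first all_gather yields lengths [2, 7]; both byte
# buffers are then padded to max_len=7 before the second all_gather, and
# each peer's padding is sliced off before decoding:
#
#   >>> list(array.array('B', b'w0'))
#   [119, 48]
#   >>> bytearray(torch.tensor([119, 48]).tolist()).decode('utf8')
#   'w0'

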
def join_rpc():
    r"""
    Block until all local and remote RPC processes reach this method, process
    (send and receive) all pending messages, and then destroy the local RPC
    agent. Every RPC process must call this method before exit.
    """
    global _agent

    if _agent:
        _agent.join()
        _agent = None


def sync_rpc():
    r"""
    Block until all local and remote RPC processes reach this method and
    finish sending all pending RPCs. As this method synchronizes at the
    process level, if multiple threads are spawned, only one of them should
    call this method at a time.
    """
    if _agent is None:
        raise RuntimeError("RPC has not been initialized. "
                           "Call init_rpc(name) first.")

    _agent.sync()


# TODO: add a context manager to wrap init_rpc and join_rpc
def init_rpc(name, backend='pg'):
    r"""
    Initialize the local RPC agent, which immediately makes the current
    process ready to send and receive RPCs. The caller needs to make sure the
    specified backend is properly initialized before calling this method. For
    example, to use the ``pg`` (ProcessGroup) backend, ``init_process_group``
    must be invoked prior to this method.

    Arguments:
        name (str): a globally unique name of the local RPC agent. (e.g.,
                    ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
        backend (str): type of RPC backend implementation. Currently,
                       the process group backend ``"pg"`` is the only available
                       backend implementation. (default: ``"pg"``)
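
    Example::

        Assuming two processes and a ``gloo`` process group (this mirrors
        the usage shown in the ``rpc`` docstring below):

        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")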
    """
    global _agent

    if _agent:
        raise RuntimeError("RPC is already initialized")

    if backend == 'pg':
        from .distributed_c10d import _get_default_group
        group = _get_default_group()
        # TODO: issue #23232
        names = _collect_worker_names(name, group)
        name_dict = {names[r]: r for r in range(len(names))}
        _agent = ProcessGroupAgent(name, name_dict, group)
    else:
        raise RuntimeError("Unrecognized RPC backend: {}".format(backend))


def rpc(to, func, args=None, kwargs=None, async_call=False):
    r"""
    Make an RPC call to run function ``func`` on worker ``to``. By default, it
    blocks until the return value is locally available. RPC messages are sent
    and received in parallel to execution of Python code. This method is
    thread-safe.

    Arguments:
        to (str): name of the destination worker.
        func (callable): a builtin function (e.g., ``torch.add``).
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        async_call (bool): If set to ``True``, this will be an asynchronous
                           RPC that returns a
                           ``torch.distributed.FutureMessage`` object
                           immediately. Otherwise, this RPC will block until
                           the return value is locally available.
                           (default: ``False``)

    Returns:
        If ``async_call`` is ``False``, returns the result of running ``func``
        on ``args`` and ``kwargs``. If ``async_call`` is ``True``, returns a
        ``torch.distributed.FutureMessage`` object that can be waited on. When
        completed, the return value of ``func`` on ``args`` and ``kwargs`` can
        be retrieved from the ``FutureMessage`` object.

    Example::

        Synchronous example:

        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> ret = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3))
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()

        Asynchronous example:

        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> fut1 = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3), async_call=True)
        >>> fut2 = dist.rpc("worker1", torch.add, args=(torch.ones(2), 2), async_call=True)
        >>> result = fut1.wait() + fut2.wait()
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()
    """
    if _agent is None:
        raise RuntimeError("RPC has not been initialized. "
                           "Call init_rpc(name) first.")

    # only builtin operators with a known JIT qualified name can be invoked
    qualified_name = torch.jit._find_builtin(func)
    if qualified_name is None:
        raise RuntimeError("unknown builtin function %s." % func)

    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    fut = invoke_rpc(_agent, to, qualified_name, *args, **kwargs)

    if async_call:
        return fut
    else:
        return fut.wait()