Move RPC API to torch.distributed.rpc

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27290

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D17808212

Pulled By: pietern

fbshipit-source-id: c79907940fe4888b2ceaaa1cda0078e39c89b454
Authored by Pieter Noordhuis on 2019-10-08 11:22:18 -07:00; committed by Facebook Github Bot
parent a6d26ce135
commit b4ce922b58
6 changed files with 142 additions and 130 deletions
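The change is mechanical: every RPC call site moves from the torch.distributed namespace to the new torch.distributed.rpc module, while process-group setup stays on torch.distributed. A minimal before/after sketch of a caller, modeled on the docstring examples further down in this diff (the rank, world size, and env:// rendezvous are illustrative assumptions, not part of the commit):

    import torch
    import torch.distributed as dist
    import torch.distributed.rpc as rpc  # new home of the RPC API

    # Process-group setup is unchanged and still lives on torch.distributed.
    dist.init_process_group(backend="gloo", rank=0, world_size=2, init_method="env://")

    # Before this commit the three calls below were dist.init_model_parallel,
    # dist.rpc_sync, and dist.join_rpc.
    rpc.init_model_parallel(self_name="worker0", self_rank=0)
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2, 2), 1))
    rpc.join_rpc()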


@@ -4,8 +4,8 @@ import time
 import unittest
 import torch
-import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
 from dist_utils import INIT_METHOD_TEMPLATE, dist_init
@@ -66,8 +66,8 @@ class DistAutogradTest(object):
 with dist_autograd.context() as context_id:
 t1 = torch.ones(3, 3, requires_grad=True)
 t2 = torch.zeros(3, 3, requires_grad=True)
-ret = dist.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
-dist.rpc_sync(
+ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
+rpc.rpc_sync(
 "worker{}".format(dst_rank), _set_rpc_done, args=(context_id,)
 )
@@ -142,7 +142,7 @@ class DistAutogradTest(object):
 tensors = []
 for i in range(num_tensors):
 tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank), torch.stack, args=(tensors,)
 )
 self.assertEqual(torch.stack(tensors), ret)


@@ -4,7 +4,8 @@ from functools import wraps
 from os import getenv
 import torch.distributed as dist
-from torch.distributed.rpc_api import RpcBackend
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc.api import RpcBackend
 if not dist.is_available():
@@ -37,13 +38,13 @@ def dist_init(test_method):
 def wrapper(self, *arg, **kwargs):
 self.worker_id = self.rank
 dist.init_process_group(backend="gloo", init_method=self.init_method)
-dist.init_model_parallel(
+rpc.init_model_parallel(
 self_name="worker%d" % self.rank,
 backend=TEST_CONFIG.backend,
 self_rank=self.rank,
 init_method=self.init_method,
 )
 test_method(self, *arg, **kwargs)
-dist.join_rpc()
+rpc.join_rpc()
 return wrapper
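The dist_init wrapper above sets up the process group and RPC before each test method and joins RPC afterwards. A hypothetical test written against it might look like the sketch below; the class name and test body are illustrative, and self.rank, self.world_size, and the assertEqual helper are assumed to come from the harness that mixes this class into a unittest.TestCase, as the suites in this diff do.

    import torch
    import torch.distributed.rpc as rpc
    from dist_utils import dist_init

    class ExampleRpcTest(object):
        @dist_init  # initializes the process group and RPC, joins RPC after the test
        def test_remote_add(self):
            dst = "worker{}".format((self.rank + 1) % self.world_size)
            ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
            self.assertEqual(ret, torch.ones(2, 2) * 2)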


@@ -13,12 +13,12 @@ import torch.distributed.rpc as rpc
 from common_utils import load_tests
 from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init
 from torch.distributed import ProcessGroupAgent
+from torch.distributed.rpc import RpcBackend
 from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
-from torch.distributed.rpc_api import RpcBackend
 def requires_process_group_agent(func):
-from torch.distributed.rpc_api import _agent
+from torch.distributed.rpc.api import _agent
 return unittest.skipUnless(
 isinstance(_agent, ProcessGroupAgent),
@@ -118,7 +118,7 @@ def no_result():
 def nested_rpc(dst):
-return dist.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
 def multi_layer_nested_async_rpc(dst, world_size, ttl):
@@ -127,7 +127,7 @@ def multi_layer_nested_async_rpc(dst, world_size, ttl):
 if ttl > 0:
 current_dst = "worker{}".format(dst)
 next_dst = (dst + 1) % world_size
-dist.rpc_async(
+rpc.rpc_async(
 current_dst,
 multi_layer_nested_async_rpc,
 args=(next_dst, world_size, ttl - 1),
@@ -202,32 +202,32 @@ class RpcTest(object):
 def test_worker_id(self):
 n = self.rank + 1
 peer_rank = n % self.world_size
-self_worker_info = dist.get_worker_info()
-peer_worker_info = dist.get_worker_info("worker{}".format(peer_rank))
+self_worker_info = rpc.get_worker_info()
+peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))
 self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
 self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))
 with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
-unknown_worker_id = dist.get_worker_info("WorkerUnknown")
+unknown_worker_id = rpc.get_worker_info("WorkerUnknown")
 @dist_init
 def test_self_add(self):
-self_worker_info = dist.get_worker_info()
+self_worker_info = rpc.get_worker_info()
 self_worker_name = "worker{}".format(self.rank)
 with self.assertRaisesRegex(
 RuntimeError, "does not support making RPC calls to self"
 ):
-dist.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
+rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
 with self.assertRaisesRegex(
 RuntimeError, "does not support making RPC calls to self"
 ):
-dist.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
+rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
 @mock.patch.object(torch.distributed.autograd, "_init")
-@mock.patch.object(torch.distributed.rpc_api, "_init_rref_context")
+@mock.patch.object(torch.distributed.rpc.api, "_init_rref_context")
 def test_register_rpc_backend_and_init_rpc_backend(
 self, mock_init_rref_context, mock_dist_autograd_init
 ):
@@ -235,7 +235,7 @@
 rpc.register_backend(
 backend_name, stub_init_rpc_backend_handler
 )
-dist.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1)
+rpc.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1)
 @unittest.skipIf(
 TEST_CONFIG.backend != RpcBackend.PROCESS_GROUP,
@@ -244,34 +244,34 @@
 def test_duplicate_name(self):
 dist.init_process_group(backend="gloo", init_method=self.init_method)
 with self.assertRaisesRegex(RuntimeError, "is not unique"):
-dist.init_model_parallel(
+rpc.init_model_parallel(
 self_name="duplicate_name",
 backend=TEST_CONFIG.backend,
 self_rank=self.rank,
 init_method=self.init_method,
 )
-dist.join_rpc()
+rpc.join_rpc()
 def test_reinit(self):
 dist.init_process_group(backend="gloo", init_method=self.init_method)
-dist.init_model_parallel(
+rpc.init_model_parallel(
 self_name="worker{}".format(self.rank),
 backend=TEST_CONFIG.backend,
 self_rank=self.rank,
 init_method=self.init_method,
 )
 with self.assertRaisesRegex(RuntimeError, "is already initialized"):
-dist.init_model_parallel(
+rpc.init_model_parallel(
 self_name="worker{}".format(self.rank),
 backend=TEST_CONFIG.backend,
 self_rank=self.rank,
 init_method=self.init_method,
 )
-dist.join_rpc()
+rpc.join_rpc()
 def test_init_invalid_backend(self):
 with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"):
-dist.init_model_parallel(
+rpc.init_model_parallel(
 self_name="worker{}".format(self.rank),
 backend="invalid",
 self_rank=self.rank,
@@ -283,30 +283,30 @@
 dist.init_process_group(backend="gloo", init_method=self.init_method)
 with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
-dist.init_model_parallel(self_name="abc*")
+rpc.init_model_parallel(self_name="abc*")
 with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
-dist.init_model_parallel(self_name=" ")
+rpc.init_model_parallel(self_name=" ")
 with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
-dist.init_model_parallel(self_name="")
+rpc.init_model_parallel(self_name="")
 # If the number in the message does not match, it is likely that the
 # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
 with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
-dist.init_model_parallel(
+rpc.init_model_parallel(
 self_name="".join(["a" for _ in range(500)]),
 backend=TEST_CONFIG.backend,
 self_rank=self.rank,
 init_method=self.init_method,
 )
-dist.join_rpc()
+rpc.join_rpc()
 @dist_init
 def test_add(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
@@ -317,9 +317,9 @@
 def test_add_with_id(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-workder_info = dist.get_worker_info("worker{}".format(dst_rank))
+workder_info = rpc.get_worker_info("worker{}".format(dst_rank))
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
 )
 self.assertEqual(ret, torch.ones(n, n) * 2)
@@ -328,7 +328,7 @@
 def test_scalar_add(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
 )
 self.assertEqual(ret, (torch.ones(n, n) + n))
@@ -337,7 +337,7 @@
 def test_async_add(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-fut = dist.rpc_async(
+fut = rpc.rpc_async(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
@@ -350,7 +350,7 @@
 dst_rank = n % self.world_size
 x = torch.ones(self.world_size, self.world_size)
 x[self.rank][self.rank] = 0
-ret = dist.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
+ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
 self.assertEqual(ret, x.nonzero())
 @dist_init
@@ -358,7 +358,7 @@
 dst_rank = (self.rank + 1) % self.world_size
 for i in range(20):
 n = i + self.rank + 1
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
@@ -369,18 +369,18 @@
 def test_sync_rpc(self):
 dst_rank = (self.rank + 1) % self.world_size
 for i in range(20):
-dist.sync_rpc()
+rpc.sync_rpc()
 n = i + self.rank + 1
-ret1 = dist.rpc_sync(
+ret1 = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
 )
-dist.sync_rpc()
-ret2 = dist.rpc_sync(
+rpc.sync_rpc()
+ret2 = rpc.rpc_sync(
 "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
 )
-dist.sync_rpc()
+rpc.sync_rpc()
 self.assertEqual(ret1, torch.ones(n, n) * 2)
 self.assertEqual(ret2, torch.ones(n, n) * 3)
@@ -388,29 +388,29 @@
 def test_join_rpc(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
 )
 self.assertEqual(ret, torch.ones(n, n) * 2)
-dist.join_rpc()
+rpc.join_rpc()
 with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
-dist.rpc_sync(
+rpc.rpc_sync(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
 )
 # it's safe to call join_rpc() multiple times
-dist.join_rpc()
+rpc.join_rpc()
 @dist_init
 def test_expected_src(self):
 dst_rank = (self.rank + 1) % self.world_size
 expected_src_rank = (self.rank - 1) % self.world_size
-ret = dist.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
+ret = rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
 value = VALUE_FUTURE.result()
 self.assertEqual(value, expected_src_rank)
@@ -418,14 +418,14 @@
 def test_py_built_in(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
+ret = rpc.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
 self.assertEqual(ret, min(n, n + 1, n + 2))
 @dist_init
 def test_py_user_defined(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 my_function,
 kwargs={"a": n, "b": n + 1, "c": n + 2},
@@ -436,14 +436,14 @@
 def test_py_class_constructor(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
+ret = rpc.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
 self.assertEqual(ret.a, n)
 @dist_init
 def test_py_class_instance_method(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank), MyClass(2).my_instance_method, args=(n,)
 )
 self.assertEqual(ret, MyClass(2).my_instance_method(n))
@@ -452,7 +452,7 @@
 def test_py_class_method(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank), MyClass.my_class_method, args=(n, n + 1)
 )
 self.assertEqual(ret, MyClass.my_class_method(n, n + 1))
@@ -461,7 +461,7 @@
 def test_py_class_static_method(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank), MyClass.my_static_method, args=(n + 10,)
 )
 self.assertEqual(ret, MyClass.my_static_method(n + 10))
@@ -470,9 +470,9 @@
 def test_py_multi_async_call(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-dst_worker_info = dist.get_worker_info("worker{}".format(dst_rank))
-fut1 = dist.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
-fut2 = dist.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
+dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
+fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
+fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
 self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
 self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
@@ -480,14 +480,14 @@
 def test_py_no_return_result(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync("worker{}".format(dst_rank), no_result)
+ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result)
 self.assertEqual(ret, no_result())
 @dist_init
 def test_py_tensors(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 my_tensor_function,
 args=(torch.ones(n, n), torch.ones(n, n)),
@@ -500,7 +500,7 @@
 n = self.rank + 1
 dst_rank = n % self.world_size
 for i in range(100):
-fut = dist.rpc_async(
+fut = rpc.rpc_async(
 "worker{}".format(dst_rank),
 my_tensor_function,
 args=(torch.ones(i, i), torch.ones(i, i)),
@@ -521,7 +521,7 @@
 a = [torch.ones(n, n), torch.ones(n, n)]
 b = TensorClass(build_complex_tensors())
 c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
 )
 self.assertEqual(ret, my_complex_tensor_function(a, b, c))
@@ -531,7 +531,7 @@
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 run_nested_pickle,
 args=(MyPickleClass(), torch.ones(2, 2)),
@@ -546,13 +546,13 @@
 n = self.rank + 1
 dst_rank = n % self.world_size
 with self.assertRaisesRegex(Exception, "TypeError"):
-ret = dist.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
+ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
 @dist_init
 def test_py_raise_in_user_func(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-fut = dist.rpc_async("worker{}".format(dst_rank), raise_func)
+fut = rpc.rpc_async("worker{}".format(dst_rank), raise_func)
 with self.assertRaisesRegex(Exception, "ValueError"):
 fut.wait()
@@ -560,7 +560,7 @@
 def test_nested_rpc(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-ret = dist.rpc_sync(
+ret = rpc.rpc_sync(
 "worker{}".format(dst_rank),
 nested_rpc,
 args=("worker{}".format(self.rank),),
@@ -575,7 +575,7 @@
 futs = []
 tik = time.time()
 for _ in range(repeat):
-fut = dist.rpc_async("worker{}".format(dst_rank), f, args=args)
+fut = rpc.rpc_async("worker{}".format(dst_rank), f, args=args)
 futs.append(fut)
 for fut in futs:
@@ -599,7 +599,7 @@
 def test_builtin_remote_ret(self):
 n = self.rank + 1
 dst_rank = n % self.world_size
-rref = dist.remote(
+rref = rpc.remote(
 "worker{}".format(dst_rank),
 torch.add,
 args=(torch.ones(n, n), torch.ones(n, n)),
@@ -615,7 +615,7 @@
 for i in range(m):
 n = n + i
 rrefs.append(
-dist.remote(
+rpc.remote(
 "worker{}".format(dst_rank),
 fn,
 args=args_fn(n),


@@ -1,7 +1,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import torch
-import sys
 def is_available():
@@ -19,39 +18,3 @@ if is_available():
 # See the comment in `distributed_c10d.py` above `_backend` on why we expose
 # this.
 from .distributed_c10d import _backend # noqa: F401
-if sys.version_info >= (3, 0):
-from .rpc_api import _init_rpc
-from .rpc_api import * # noqa: F401
-def init_model_parallel(self_name,
-backend=RpcBackend.PROCESS_GROUP,
-self_rank=-1,
-init_method=None,
-num_send_recv_threads=4):
-r"""
-Initializes model parallel primitives such as the local rpc agent
-and distributed autograd.
-Initializes the local RPC agent which immediately makes the current
-process ready to send and receive RPCs. The caller needs to make
-sure the specified backend is properly initialized before calling
-this method. For example, to use ``pg`` (ProcessGroup) backend,
-``init_process_group`` must be invoked prior to this method.
-Arguments:
-backend (Enum): type of RPC backend implementation.
-Currently, process group backend is the only
-available backend implementation. (default:
-``RpcBackend.PROCESS_GROUP``).
-self_name (str): a globally unique name of this node. (e.g.,
-``Trainer3``, ``ParameterServer2``, ``Master``,
-``Worker1``) Name can only contain number, alphabet,
-underscore, and/or dash, and must be shorter than
-128 characters.
-self_rank (int): a globally unique id/rank of this node.
-init_method (str): backend specific init arguments.
-num_send_recv_threads (int): Number of threads for send/recv work.
-"""
-_init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads)
-from .rpc_api import _agent
-autograd._init(_agent.get_worker_info().id)


@@ -1 +1,44 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+import sys
 from .backend_registry import * # noqa: F401
+if sys.version_info >= (3, 0):
+from .api import _init_rpc
+from .api import * # noqa: F401
+import torch.distributed.autograd
+def init_model_parallel(self_name,
+backend=RpcBackend.PROCESS_GROUP,
+self_rank=-1,
+init_method=None,
+num_send_recv_threads=4):
+r"""
+Initializes model parallel primitives such as the local rpc agent
+and distributed autograd.
+Initializes the local RPC agent which immediately makes the current
+process ready to send and receive RPCs. The caller needs to make
+sure the specified backend is properly initialized before calling
+this method. For example, to use ``pg`` (ProcessGroup) backend,
+``init_process_group`` must be invoked prior to this method.
+Arguments:
+backend (Enum): type of RPC backend implementation.
+Currently, process group backend is the only
+available backend implementation. (default:
+``RpcBackend.PROCESS_GROUP``).
+self_name (str): a globally unique name of this node. (e.g.,
+``Trainer3``, ``ParameterServer2``, ``Master``,
+``Worker1``) Name can only contain number, alphabet,
+underscore, and/or dash, and must be shorter than
+128 characters.
+self_rank (int): a globally unique id/rank of this node.
+init_method (str): backend specific init arguments.
+num_send_recv_threads (int): Number of threads for send/recv work.
+"""
+_init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads)
+from .api import _agent
+torch.distributed.autograd._init(_agent.get_worker_info().id)
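As a reference for the docstring above, a minimal sketch of the documented initialization sequence with every parameter spelled out; the worker name, rank, world size, and TCP address are illustrative assumptions:

    import torch.distributed as dist
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RpcBackend

    # The ProcessGroup backend requires init_process_group to have run first.
    dist.init_process_group(backend="gloo", rank=0, world_size=2,
                            init_method="tcp://127.0.0.1:29500")
    rpc.init_model_parallel(
        self_name="worker0",               # letters, digits, underscores, dashes; shorter than 128 chars
        backend=RpcBackend.PROCESS_GROUP,  # currently the only backend implementation
        self_rank=0,
        init_method="tcp://127.0.0.1:29500",
        num_send_recv_threads=4,           # default
    )
    # ... rpc.rpc_sync / rpc.rpc_async / rpc.remote calls go here ...
    rpc.join_rpc()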


@@ -1,10 +1,10 @@
-from . import invoke_rpc_builtin, invoke_rpc_python_udf
-from . import invoke_remote_builtin, invoke_remote_python_udf
-from . import _init_rref_context, _destroy_rref_context
-from . import ProcessGroupAgent
-from . import WorkerInfo
-from .rpc import is_backend_registered, init_backend
-from .rpc.internal import _internal_rpc_pickler, PythonUDF
+from torch.distributed import invoke_rpc_builtin, invoke_rpc_python_udf
+from torch.distributed import invoke_remote_builtin, invoke_remote_python_udf
+from torch.distributed import _init_rref_context, _destroy_rref_context
+from torch.distributed import ProcessGroupAgent
+from torch.distributed import WorkerInfo
+from .backend_registry import is_backend_registered, init_backend
+from .internal import _internal_rpc_pickler, PythonUDF
 import functools
 import sys
@@ -69,7 +69,7 @@ def _init_rpc(backend=RpcBackend.PROCESS_GROUP,
 raise RuntimeError("RPC is already initialized")
 if backend == RpcBackend.PROCESS_GROUP:
-from .distributed_c10d import _get_default_group
+from torch.distributed.distributed_c10d import _get_default_group
 group = _get_default_group()
 if (self_rank != -1) and (self_rank != group.rank()):
@@ -140,19 +140,20 @@ def remote(to, func, args=None, kwargs=None):
 On worker 0:
 >>> import torch.distributed as dist
+>>> import torch.distributed.rpc as rpc
 >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
->>> dist.init_rpc("worker0")
->>> worker1 = dist.get_worker_info("worker1")
->>> rref1 = dist.remote(worker1, torch.add, args=(torch.ones(2), 3))
->>> rref2 = dist.remote(worker1, torch.add, args=(torch.ones(2), 1))
+>>> rpc.init_rpc("worker0")
+>>> worker1 = rpc.get_worker_info("worker1")
+>>> rref1 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 3))
+>>> rref2 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 1))
 >>> x = rref1.to_here() + rref2.to_here()
->>> dist.join_rpc()
+>>> rpc.join_rpc()
 On worker 1:
 >>> import torch.distributed as dist
 >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
 >>> dist.init_rpc("worker1")
->>> dist.join_rpc()
+>>> rpc.join_rpc()
 """
 qualified_name = torch.jit._find_builtin(func)
@@ -213,16 +214,18 @@ def rpc_sync(to, func, args=None, kwargs=None):
 Example::
 On worker 0:
 >>> import torch.distributed as dist
+>>> import torch.distributed.rpc as rpc
 >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
->>> dist.init_model_parallel("worker0")
->>> ret = dist.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
->>> dist.join_rpc()
+>>> rpc.init_model_parallel("worker0")
+>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
+>>> rpc.join_rpc()
 On worker 1:
 >>> import torch.distributed as dist
+>>> import torch.distributed.rpc as rpc
 >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
->>> dist.init_model_parallel("worker1")
->>> dist.join_rpc()
+>>> rpc.init_model_parallel("worker1")
+>>> rpc.join_rpc()
 """
 fut = _invoke_rpc(to, func, args, kwargs)
 return fut.wait()
@@ -253,19 +256,21 @@ def rpc_async(to, func, args=None, kwargs=None):
 On worker 0:
 >>> import torch.distributed as dist
+>>> import torch.distributed.rpc as rpc
 >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
->>> dist.init_model_parallel("worker0")
->>> worker1 = dist.get_worker_id("worker1")
->>> fut1 = dist.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
->>> fut2 = dist.rpc_async(worker1, min, args=(1, 2))
+>>> rpc.init_model_parallel("worker0")
+>>> worker1 = rpc.get_worker_id("worker1")
+>>> fut1 = rpc.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
+>>> fut2 = rpc.rpc_async(worker1, min, args=(1, 2))
 >>> result = fut1.wait() + fut2.wait()
->>> dist.join_rpc()
+>>> rpc.join_rpc()
 On worker 1:
 >>> import torch.distributed as dist
+>>> import torch.distributed.rpc as rpc
 >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
->>> dist.init_model_parallel("worker1")
->>> dist.join_rpc()
+>>> rpc.init_model_parallel("worker1")
+>>> rpc.join_rpc()
 """
 fut = _invoke_rpc(to, func, args, kwargs)
 return fut