Move RPC API to torch.distributed.rpc
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27290
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D17808212
Pulled By: pietern
fbshipit-source-id: c79907940fe4888b2ceaaa1cda0078e39c89b454
This commit is contained in:
parent
a6d26ce135
commit
b4ce922b58
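For callers, the move is mechanical: the public RPC entry points (rpc_sync, rpc_async, remote, sync_rpc, join_rpc, get_worker_info, register_backend, init_model_parallel) leave the top-level torch.distributed namespace and live in the torch.distributed.rpc subpackage, and the private module torch.distributed.rpc_api becomes torch.distributed.rpc.api. A minimal migration sketch, assuming the process group and the RPC layer are already initialized on both workers as in the tests below (the worker name is illustrative):

import torch
import torch.distributed.rpc as rpc

# Before this commit, the same calls were reached through torch.distributed:
#   import torch.distributed as dist
#   ret = dist.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))

ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))   # blocking call
fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))  # returns a future
assert torch.equal(ret, fut.wait())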
test/dist_autograd_test.py

@@ -4,8 +4,8 @@ import time
 import unittest
 
 import torch
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
 from dist_utils import INIT_METHOD_TEMPLATE, dist_init
 

@@ -66,8 +66,8 @@ class DistAutogradTest(object):
         with dist_autograd.context() as context_id:
             t1 = torch.ones(3, 3, requires_grad=True)
             t2 = torch.zeros(3, 3, requires_grad=True)
-            ret = dist.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
-            dist.rpc_sync(
+            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
+            rpc.rpc_sync(
                 "worker{}".format(dst_rank), _set_rpc_done, args=(context_id,)
             )

@@ -142,7 +142,7 @@ class DistAutogradTest(object):
         tensors = []
         for i in range(num_tensors):
             tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank), torch.stack, args=(tensors,)
         )
         self.assertEqual(torch.stack(tensors), ret)
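These hunks touch only call sites; the distributed autograd machinery is unchanged, and RPCs issued inside a dist_autograd.context() are still tracked under that context. A rough sketch of the pattern the updated test exercises, assuming RPC is initialized and a peer follows the tests' "worker{rank}" naming:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

with dist_autograd.context() as context_id:
    t1 = torch.ones(3, 3, requires_grad=True)
    t2 = torch.zeros(3, 3, requires_grad=True)
    # The remote add is recorded under context_id for the distributed
    # backward pass; only the namespace of the call changed.
    ret = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))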
test/dist_utils.py

@@ -4,7 +4,8 @@ from functools import wraps
 from os import getenv
 
 import torch.distributed as dist
-from torch.distributed.rpc_api import RpcBackend
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc.api import RpcBackend
 
 
 if not dist.is_available():

@@ -37,13 +38,13 @@ def dist_init(test_method):
     def wrapper(self, *arg, **kwargs):
         self.worker_id = self.rank
         dist.init_process_group(backend="gloo", init_method=self.init_method)
-        dist.init_model_parallel(
+        rpc.init_model_parallel(
            self_name="worker%d" % self.rank,
            backend=TEST_CONFIG.backend,
            self_rank=self.rank,
            init_method=self.init_method,
        )
         test_method(self, *arg, **kwargs)
-        dist.join_rpc()
+        rpc.join_rpc()
 
     return wrapper
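With the helper updated, individual tests change only their call sites. A hypothetical test written against this decorator would look roughly like the following (the class and test body are illustrative; self.rank, self.world_size, and assertEqual come from the multiprocess test harness the real suite mixes in):

import torch
import torch.distributed.rpc as rpc
from dist_utils import dist_init

class ExampleRpcTest(object):
    @dist_init  # sets up the process group and RPC, then calls rpc.join_rpc()
    def test_remote_add(self):
        dst = "worker{}".format((self.rank + 1) % self.world_size)
        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
        self.assertEqual(ret, torch.ones(2, 2) * 2)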
test/rpc_test.py (120 changed lines)

@@ -13,12 +13,12 @@ import torch.distributed.rpc as rpc
 from common_utils import load_tests
 from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init
 from torch.distributed import ProcessGroupAgent
+from torch.distributed.rpc import RpcBackend
 from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
-from torch.distributed.rpc_api import RpcBackend
 
 
 def requires_process_group_agent(func):
-    from torch.distributed.rpc_api import _agent
+    from torch.distributed.rpc.api import _agent
 
     return unittest.skipUnless(
         isinstance(_agent, ProcessGroupAgent),

@@ -118,7 +118,7 @@ def no_result():
 
 
 def nested_rpc(dst):
-    return dist.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
 
 
 def multi_layer_nested_async_rpc(dst, world_size, ttl):

@@ -127,7 +127,7 @@ def multi_layer_nested_async_rpc(dst, world_size, ttl):
     if ttl > 0:
         current_dst = "worker{}".format(dst)
         next_dst = (dst + 1) % world_size
-        dist.rpc_async(
+        rpc.rpc_async(
             current_dst,
             multi_layer_nested_async_rpc,
             args=(next_dst, world_size, ttl - 1),

@@ -202,32 +202,32 @@ class RpcTest(object):
     def test_worker_id(self):
         n = self.rank + 1
         peer_rank = n % self.world_size
-        self_worker_info = dist.get_worker_info()
-        peer_worker_info = dist.get_worker_info("worker{}".format(peer_rank))
+        self_worker_info = rpc.get_worker_info()
+        peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))
 
         self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
         self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))
 
         with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
-            unknown_worker_id = dist.get_worker_info("WorkerUnknown")
+            unknown_worker_id = rpc.get_worker_info("WorkerUnknown")
 
     @dist_init
     def test_self_add(self):
-        self_worker_info = dist.get_worker_info()
+        self_worker_info = rpc.get_worker_info()
         self_worker_name = "worker{}".format(self.rank)
 
         with self.assertRaisesRegex(
             RuntimeError, "does not support making RPC calls to self"
         ):
-            dist.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
+            rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
 
         with self.assertRaisesRegex(
             RuntimeError, "does not support making RPC calls to self"
         ):
-            dist.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
+            rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
 
     @mock.patch.object(torch.distributed.autograd, "_init")
-    @mock.patch.object(torch.distributed.rpc_api, "_init_rref_context")
+    @mock.patch.object(torch.distributed.rpc.api, "_init_rref_context")
     def test_register_rpc_backend_and_init_rpc_backend(
         self, mock_init_rref_context, mock_dist_autograd_init
     ):

@@ -235,7 +235,7 @@ class RpcTest(object):
         rpc.register_backend(
             backend_name, stub_init_rpc_backend_handler
         )
-        dist.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1)
+        rpc.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1)
 
     @unittest.skipIf(
         TEST_CONFIG.backend != RpcBackend.PROCESS_GROUP,

@@ -244,34 +244,34 @@ class RpcTest(object):
     def test_duplicate_name(self):
         dist.init_process_group(backend="gloo", init_method=self.init_method)
         with self.assertRaisesRegex(RuntimeError, "is not unique"):
-            dist.init_model_parallel(
+            rpc.init_model_parallel(
                 self_name="duplicate_name",
                 backend=TEST_CONFIG.backend,
                 self_rank=self.rank,
                 init_method=self.init_method,
             )
-        dist.join_rpc()
+        rpc.join_rpc()
 
     def test_reinit(self):
         dist.init_process_group(backend="gloo", init_method=self.init_method)
-        dist.init_model_parallel(
+        rpc.init_model_parallel(
             self_name="worker{}".format(self.rank),
             backend=TEST_CONFIG.backend,
             self_rank=self.rank,
             init_method=self.init_method,
         )
         with self.assertRaisesRegex(RuntimeError, "is already initialized"):
-            dist.init_model_parallel(
+            rpc.init_model_parallel(
                 self_name="worker{}".format(self.rank),
                 backend=TEST_CONFIG.backend,
                 self_rank=self.rank,
                 init_method=self.init_method,
             )
-        dist.join_rpc()
+        rpc.join_rpc()
 
     def test_init_invalid_backend(self):
         with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"):
-            dist.init_model_parallel(
+            rpc.init_model_parallel(
                 self_name="worker{}".format(self.rank),
                 backend="invalid",
                 self_rank=self.rank,

@@ -283,30 +283,30 @@ class RpcTest(object):
         dist.init_process_group(backend="gloo", init_method=self.init_method)
 
         with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
-            dist.init_model_parallel(self_name="abc*")
+            rpc.init_model_parallel(self_name="abc*")
 
         with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
-            dist.init_model_parallel(self_name=" ")
+            rpc.init_model_parallel(self_name=" ")
 
         with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
-            dist.init_model_parallel(self_name="")
+            rpc.init_model_parallel(self_name="")
 
         # If the number in the message does not match, it is likely that the
         # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
         with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
-            dist.init_model_parallel(
+            rpc.init_model_parallel(
                 self_name="".join(["a" for _ in range(500)]),
                 backend=TEST_CONFIG.backend,
                 self_rank=self.rank,
                 init_method=self.init_method,
             )
-        dist.join_rpc()
+        rpc.join_rpc()
 
     @dist_init
     def test_add(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank),
             torch.add,
             args=(torch.ones(n, n), torch.ones(n, n)),

@@ -317,9 +317,9 @@ class RpcTest(object):
     def test_add_with_id(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        workder_info = dist.get_worker_info("worker{}".format(dst_rank))
+        workder_info = rpc.get_worker_info("worker{}".format(dst_rank))
 
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
         )
         self.assertEqual(ret, torch.ones(n, n) * 2)

@@ -328,7 +328,7 @@ class RpcTest(object):
     def test_scalar_add(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
         )
         self.assertEqual(ret, (torch.ones(n, n) + n))

@@ -337,7 +337,7 @@ class RpcTest(object):
     def test_async_add(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        fut = dist.rpc_async(
+        fut = rpc.rpc_async(
             "worker{}".format(dst_rank),
             torch.add,
             args=(torch.ones(n, n), torch.ones(n, n)),

@@ -350,7 +350,7 @@ class RpcTest(object):
         dst_rank = n % self.world_size
         x = torch.ones(self.world_size, self.world_size)
         x[self.rank][self.rank] = 0
-        ret = dist.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
+        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
         self.assertEqual(ret, x.nonzero())
 
     @dist_init

@@ -358,7 +358,7 @@ class RpcTest(object):
         dst_rank = (self.rank + 1) % self.world_size
         for i in range(20):
             n = i + self.rank + 1
-            ret = dist.rpc_sync(
+            ret = rpc.rpc_sync(
                 "worker{}".format(dst_rank),
                 torch.add,
                 args=(torch.ones(n, n), torch.ones(n, n)),

@@ -369,18 +369,18 @@ class RpcTest(object):
     def test_sync_rpc(self):
         dst_rank = (self.rank + 1) % self.world_size
         for i in range(20):
-            dist.sync_rpc()
+            rpc.sync_rpc()
             n = i + self.rank + 1
-            ret1 = dist.rpc_sync(
+            ret1 = rpc.rpc_sync(
                 "worker{}".format(dst_rank),
                 torch.add,
                 args=(torch.ones(n, n), torch.ones(n, n)),
             )
-            dist.sync_rpc()
-            ret2 = dist.rpc_sync(
+            rpc.sync_rpc()
+            ret2 = rpc.rpc_sync(
                 "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
             )
-            dist.sync_rpc()
+            rpc.sync_rpc()
             self.assertEqual(ret1, torch.ones(n, n) * 2)
             self.assertEqual(ret2, torch.ones(n, n) * 3)
 

@@ -388,29 +388,29 @@ class RpcTest(object):
     def test_join_rpc(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank),
             torch.add,
             args=(torch.ones(n, n), torch.ones(n, n)),
         )
         self.assertEqual(ret, torch.ones(n, n) * 2)
-        dist.join_rpc()
+        rpc.join_rpc()
 
         with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
-            dist.rpc_sync(
+            rpc.rpc_sync(
                 "worker{}".format(dst_rank),
                 torch.add,
                 args=(torch.ones(n, n), torch.ones(n, n)),
             )
 
         # it's safe to call join_rpc() multiple times
-        dist.join_rpc()
+        rpc.join_rpc()
 
     @dist_init
     def test_expected_src(self):
         dst_rank = (self.rank + 1) % self.world_size
         expected_src_rank = (self.rank - 1) % self.world_size
-        ret = dist.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
+        ret = rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
         value = VALUE_FUTURE.result()
         self.assertEqual(value, expected_src_rank)
 

@@ -418,14 +418,14 @@ class RpcTest(object):
     def test_py_built_in(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
+        ret = rpc.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
         self.assertEqual(ret, min(n, n + 1, n + 2))
 
     @dist_init
     def test_py_user_defined(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank),
             my_function,
             kwargs={"a": n, "b": n + 1, "c": n + 2},

@@ -436,14 +436,14 @@ class RpcTest(object):
     def test_py_class_constructor(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
+        ret = rpc.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
         self.assertEqual(ret.a, n)
 
     @dist_init
     def test_py_class_instance_method(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank), MyClass(2).my_instance_method, args=(n,)
         )
         self.assertEqual(ret, MyClass(2).my_instance_method(n))

@@ -452,7 +452,7 @@ class RpcTest(object):
     def test_py_class_method(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank), MyClass.my_class_method, args=(n, n + 1)
         )
         self.assertEqual(ret, MyClass.my_class_method(n, n + 1))

@@ -461,7 +461,7 @@ class RpcTest(object):
     def test_py_class_static_method(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank), MyClass.my_static_method, args=(n + 10,)
         )
         self.assertEqual(ret, MyClass.my_static_method(n + 10))

@@ -470,9 +470,9 @@ class RpcTest(object):
     def test_py_multi_async_call(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        dst_worker_info = dist.get_worker_info("worker{}".format(dst_rank))
-        fut1 = dist.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
-        fut2 = dist.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
+        dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
+        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
+        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
         self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
         self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
 

@@ -480,14 +480,14 @@ class RpcTest(object):
     def test_py_no_return_result(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync("worker{}".format(dst_rank), no_result)
+        ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result)
         self.assertEqual(ret, no_result())
 
     @dist_init
     def test_py_tensors(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank),
             my_tensor_function,
             args=(torch.ones(n, n), torch.ones(n, n)),

@@ -500,7 +500,7 @@ class RpcTest(object):
         n = self.rank + 1
         dst_rank = n % self.world_size
         for i in range(100):
-            fut = dist.rpc_async(
+            fut = rpc.rpc_async(
                 "worker{}".format(dst_rank),
                 my_tensor_function,
                 args=(torch.ones(i, i), torch.ones(i, i)),

@@ -521,7 +521,7 @@ class RpcTest(object):
         a = [torch.ones(n, n), torch.ones(n, n)]
         b = TensorClass(build_complex_tensors())
         c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
         )
         self.assertEqual(ret, my_complex_tensor_function(a, b, c))

@@ -531,7 +531,7 @@ class RpcTest(object):
         n = self.rank + 1
         dst_rank = n % self.world_size
 
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank),
             run_nested_pickle,
             args=(MyPickleClass(), torch.ones(2, 2)),

@@ -546,13 +546,13 @@ class RpcTest(object):
         n = self.rank + 1
         dst_rank = n % self.world_size
         with self.assertRaisesRegex(Exception, "TypeError"):
-            ret = dist.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
+            ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
 
     @dist_init
     def test_py_raise_in_user_func(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        fut = dist.rpc_async("worker{}".format(dst_rank), raise_func)
+        fut = rpc.rpc_async("worker{}".format(dst_rank), raise_func)
         with self.assertRaisesRegex(Exception, "ValueError"):
             fut.wait()
 

@@ -560,7 +560,7 @@ class RpcTest(object):
     def test_nested_rpc(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        ret = dist.rpc_sync(
+        ret = rpc.rpc_sync(
             "worker{}".format(dst_rank),
             nested_rpc,
             args=("worker{}".format(self.rank),),

@@ -575,7 +575,7 @@
         futs = []
         tik = time.time()
         for _ in range(repeat):
-            fut = dist.rpc_async("worker{}".format(dst_rank), f, args=args)
+            fut = rpc.rpc_async("worker{}".format(dst_rank), f, args=args)
             futs.append(fut)
 
         for fut in futs:

@@ -599,7 +599,7 @@ class RpcTest(object):
     def test_builtin_remote_ret(self):
         n = self.rank + 1
         dst_rank = n % self.world_size
-        rref = dist.remote(
+        rref = rpc.remote(
             "worker{}".format(dst_rank),
             torch.add,
             args=(torch.ones(n, n), torch.ones(n, n)),

@@ -615,7 +615,7 @@ class RpcTest(object):
         for i in range(m):
             n = n + i
             rrefs.append(
-                dist.remote(
+                rpc.remote(
                     "worker{}".format(dst_rank),
                     fn,
                     args=args_fn(n),
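One of the tests above, test_register_rpc_backend_and_init_rpc_backend, shows that backend registration also lives on the new namespace: rpc.register_backend installs a named backend and rpc.init_model_parallel(backend=...) selects it. A sketch of that flow; the backend name and stub handler mirror the test, and a real handler would have to construct and return a working RPC agent rather than a placeholder:

import torch.distributed.rpc as rpc

def stub_init_rpc_backend_handler(*args, **kwargs):
    # Placeholder only: the test mocks out _init_rref_context and the
    # distributed autograd init, so it never uses the returned agent.
    return None

rpc.register_backend("stub_backend", stub_init_rpc_backend_handler)
rpc.init_model_parallel(self_name="worker1", backend="stub_backend", self_rank=1)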
torch/distributed/__init__.py

@@ -1,7 +1,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import torch
-import sys
 
 
 def is_available():

@@ -19,39 +18,3 @@ if is_available():
     # See the comment in `distributed_c10d.py` above `_backend` on why we expose
     # this.
     from .distributed_c10d import _backend  # noqa: F401
-    if sys.version_info >= (3, 0):
-        from .rpc_api import _init_rpc
-        from .rpc_api import *  # noqa: F401
-
-        def init_model_parallel(self_name,
-                                backend=RpcBackend.PROCESS_GROUP,
-                                self_rank=-1,
-                                init_method=None,
-                                num_send_recv_threads=4):
-            r"""
-            Initializes model parallel primitives such as the local rpc agent
-            and distributed autograd.
-
-            Initializes the local RPC agent which immediately makes the current
-            process ready to send and receive RPCs. The caller needs to make
-            sure the specified backend is properly intialized before calling
-            this method. For example, to use ``pg`` (ProcessGroup) backend,
-            ``init_process_group`` must be invoked prior to this method.
-
-            Arguments:
-                backend (Enum): type of RPC backend implementation.
-                                Currently, process group backend is the only
-                                available backend implementation. (default:
-                                ``RpcBackend.PROCESS_GROUP``).
-                self_name (str): a globally unique name of this node. (e.g.,
-                                 ``Trainer3``, ``ParameterServer2``, ``Master``,
-                                 ``Worker1``) Name can only contain number, alphabet,
-                                 underscore, and/or dash, and must be shorter than
-                                 128 characters.
-                self_rank (int): a globally unique id/rank of this node.
-                init_method(str): backend specific init arguments.
-                num_send_recv_threads(int): Number of threads for send/recv work.
-            """
-            _init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads)
-            from .rpc_api import _agent
-            autograd._init(_agent.get_worker_info().id)
torch/distributed/rpc/__init__.py

@@ -1 +1,44 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
+
+import sys
+
+from .backend_registry import *  # noqa: F401
+
+
+if sys.version_info >= (3, 0):
+    from .api import _init_rpc
+    from .api import *  # noqa: F401
+    import torch.distributed.autograd
+
+    def init_model_parallel(self_name,
+                            backend=RpcBackend.PROCESS_GROUP,
+                            self_rank=-1,
+                            init_method=None,
+                            num_send_recv_threads=4):
+        r"""
+        Initializes model parallel primitives such as the local rpc agent
+        and distributed autograd.
+
+        Initializes the local RPC agent which immediately makes the current
+        process ready to send and receive RPCs. The caller needs to make
+        sure the specified backend is properly intialized before calling
+        this method. For example, to use ``pg`` (ProcessGroup) backend,
+        ``init_process_group`` must be invoked prior to this method.
+
+        Arguments:
+            backend (Enum): type of RPC backend implementation.
+                            Currently, process group backend is the only
+                            available backend implementation. (default:
+                            ``RpcBackend.PROCESS_GROUP``).
+            self_name (str): a globally unique name of this node. (e.g.,
+                             ``Trainer3``, ``ParameterServer2``, ``Master``,
+                             ``Worker1``) Name can only contain number, alphabet,
+                             underscore, and/or dash, and must be shorter than
+                             128 characters.
+            self_rank (int): a globally unique id/rank of this node.
+            init_method(str): backend specific init arguments.
+            num_send_recv_threads(int): Number of threads for send/recv work.
+        """
+        _init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads)
+        from .api import _agent
+        torch.distributed.autograd._init(_agent.get_worker_info().id)
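Per the docstring, the process-group backend requires torch.distributed.init_process_group before init_model_parallel. A minimal two-process sketch following that ordering (the rendezvous endpoint and worker names are illustrative, and RANK is assumed to be provided by the launcher):

import os
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

rank = int(os.environ["RANK"])  # assumed to be set by the launcher
init_method = "tcp://127.0.0.1:29500"
dist.init_process_group(backend="gloo", init_method=init_method,
                        rank=rank, world_size=2)
rpc.init_model_parallel(self_name="worker{}".format(rank),
                        self_rank=rank, init_method=init_method)

if rank == 0:
    print(rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3)))

rpc.join_rpc()  # safe to call more than once, per test_join_rpc above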
torch/distributed/rpc/api.py

@@ -1,10 +1,10 @@
-from . import invoke_rpc_builtin, invoke_rpc_python_udf
-from . import invoke_remote_builtin, invoke_remote_python_udf
-from . import _init_rref_context, _destroy_rref_context
-from . import ProcessGroupAgent
-from . import WorkerInfo
-from .rpc import is_backend_registered, init_backend
-from .rpc.internal import _internal_rpc_pickler, PythonUDF
+from torch.distributed import invoke_rpc_builtin, invoke_rpc_python_udf
+from torch.distributed import invoke_remote_builtin, invoke_remote_python_udf
+from torch.distributed import _init_rref_context, _destroy_rref_context
+from torch.distributed import ProcessGroupAgent
+from torch.distributed import WorkerInfo
+from .backend_registry import is_backend_registered, init_backend
+from .internal import _internal_rpc_pickler, PythonUDF
 
 import functools
 import sys

@@ -69,7 +69,7 @@ def _init_rpc(backend=RpcBackend.PROCESS_GROUP,
         raise RuntimeError("RPC is already initialized")
 
     if backend == RpcBackend.PROCESS_GROUP:
-        from .distributed_c10d import _get_default_group
+        from torch.distributed.distributed_c10d import _get_default_group
 
         group = _get_default_group()
         if (self_rank != -1) and (self_rank != group.rank()):
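The second hunk sits inside the rank-consistency check of _init_rpc: with the process-group backend, an explicit self_rank must agree with the rank of the default process group. A sketch of the misuse this guards against; the body of the check is truncated in this view, so the exact error raised is an assumption:

import torch.distributed as dist
import torch.distributed.rpc as rpc

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
# self_rank disagrees with the default group's rank (0), tripping the
# (self_rank != -1) and (self_rank != group.rank()) check above.
rpc.init_model_parallel(self_name="worker0", self_rank=1)  # expected to fail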
@@ -140,19 +140,20 @@ def remote(to, func, args=None, kwargs=None):
 
         On worker 0:
         >>> import torch.distributed as dist
+        >>> import torch.distributed.rpc as rpc
         >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
-        >>> dist.init_rpc("worker0")
-        >>> worker1 = dist.get_worker_info("worker1")
-        >>> rref1 = dist.remote(worker1, torch.add, args=(torch.ones(2), 3))
-        >>> rref2 = dist.remote(worker1, torch.add, args=(torch.ones(2), 1))
+        >>> rpc.init_rpc("worker0")
+        >>> worker1 = rpc.get_worker_info("worker1")
+        >>> rref1 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 3))
+        >>> rref2 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 1))
         >>> x = rref1.to_here() + rref2.to_here()
-        >>> dist.join_rpc()
+        >>> rpc.join_rpc()
 
         On worker 1:
         >>> import torch.distributed as dist
         >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
-        >>> dist.init_rpc("worker1")
-        >>> dist.join_rpc()
+        >>> rpc.init_rpc("worker1")
+        >>> rpc.join_rpc()
     """
     qualified_name = torch.jit._find_builtin(func)
 

@@ -213,16 +214,18 @@ def rpc_sync(to, func, args=None, kwargs=None):
     Example::
         On worker 0:
         >>> import torch.distributed as dist
+        >>> import torch.distributed.rpc as rpc
         >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
-        >>> dist.init_model_parallel("worker0")
-        >>> ret = dist.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
-        >>> dist.join_rpc()
+        >>> rpc.init_model_parallel("worker0")
+        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
+        >>> rpc.join_rpc()
 
         On worker 1:
         >>> import torch.distributed as dist
+        >>> import torch.distributed.rpc as rpc
         >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
-        >>> dist.init_model_parallel("worker1")
-        >>> dist.join_rpc()
+        >>> rpc.init_model_parallel("worker1")
+        >>> rpc.join_rpc()
     """
     fut = _invoke_rpc(to, func, args, kwargs)
     return fut.wait()

@@ -253,19 +256,21 @@ def rpc_async(to, func, args=None, kwargs=None):
 
         On worker 0:
         >>> import torch.distributed as dist
+        >>> import torch.distributed.rpc as rpc
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
-        >>> dist.init_model_parallel("worker0")
-        >>> worker1 = dist.get_worker_id("worker1")
-        >>> fut1 = dist.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
-        >>> fut2 = dist.rpc_async(worker1, min, args=(1, 2))
+        >>> rpc.init_model_parallel("worker0")
+        >>> worker1 = rpc.get_worker_id("worker1")
+        >>> fut1 = rpc.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
+        >>> fut2 = rpc.rpc_async(worker1, min, args=(1, 2))
         >>> result = fut1.wait() + fut2.wait()
-        >>> dist.join_rpc()
+        >>> rpc.join_rpc()
 
         On worker 1:
         >>> import torch.distributed as dist
+        >>> import torch.distributed.rpc as rpc
         >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
-        >>> dist.init_model_parallel("worker1")
-        >>> dist.join_rpc()
+        >>> rpc.init_model_parallel("worker1")
+        >>> rpc.join_rpc()
     """
     fut = _invoke_rpc(to, func, args, kwargs)
     return fut