mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30610 Test Plan: Imported from OSS Differential Revision: D18763592 Pulled By: mrshenli fbshipit-source-id: ad8854bdb6250c29eaa0f582d66cfd31394312e5
1321 lines
43 KiB
Python
1321 lines
43 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import concurrent.futures
|
|
from datetime import timedelta
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from collections import namedtuple
|
|
from unittest import mock
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.distributed.rpc as rpc
|
|
from torch.distributed.rpc import RRef
|
|
from common_utils import load_tests
|
|
import dist_utils
|
|
from dist_utils import dist_init
|
|
from torch.distributed.rpc.api import _use_rpc_pickler
|
|
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
|
|
from rpc_agent_test_fixture import RpcAgentTestFixture
|
|
|
|
# Completion flags indexed by rank distance: _set_rpc_done() sets a slot to
# True and _check_rpc_done() polls a slot until it becomes True.
rpc_done = [False, False, False, False]
|
|
|
|
# TODO: dedupe this with the code in dist_autograd_test.py.
|
|
# Send rpc done info and context_id to
|
|
# dst_rank = (self.rank + rank_distance) % self.world_size
|
|
# we don't need a lock here since the GIL is held while executing remote
|
|
# python UDFs, so access is serialized across several workers.
|
|
def _set_rpc_done(rank_distance):
|
|
global rpc_done
|
|
rpc_done[rank_distance] = True
|
|
|
|
def _check_rpc_done(rank_distance):
|
|
while not rpc_done[rank_distance]:
|
|
# yield control to other threads
|
|
time.sleep(0)
|
|
|
|
def requires_process_group_agent(message=""):
    """Build a decorator that skips a test unless the backend is PROCESS_GROUP.

    Args:
        message: skip reason handed to ``unittest.skipUnless``.

    Returns:
        A decorator; the backend name in ``dist_utils.TEST_CONFIG`` is read
        when the decorator is applied to a function.
    """

    def decorator(old_func):
        backend_matches = (
            dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP"
        )
        skip_unless_matching = unittest.skipUnless(backend_matches, message)
        return skip_unless_matching(old_func)

    return decorator
|
|
|
|
|
|
# Module-level future completed by set_value(); test_expected_src reads its
# result to learn which value was delivered to this process.
VALUE_FUTURE = concurrent.futures.Future()
|
|
|
|
|
|
def _stub_construct_rpc_backend_options_handler(
|
|
**kwargs
|
|
):
|
|
return mock.Mock() # RpcBackendOptions.
|
|
|
|
|
|
def _stub_start_rpc_backend_handler(
|
|
store, name, rank, world_size, rpc_backend_options
|
|
):
|
|
return mock.Mock() # RpcAgent.
|
|
|
|
|
|
def set_value(value):
    """Complete the module-level ``VALUE_FUTURE`` with ``value``."""
    future = VALUE_FUTURE
    future.set_result(value)
|
|
|
|
|
|
# The definitions below are used to test Python user-defined functions over
# RPC; the classes and functions are used to test Python user-defined classes
# and their methods over RPC.
|
|
# Namedtuple with a single field, ``tensors``.
TensorClass = namedtuple("TensorClass", ["tensors"])
|
|
|
|
|
|
class MyPickleClass:
    """Object whose pickling hooks go through ``_internal_rpc_pickler``.

    ``__getstate__`` serializes a ``PythonUDF`` wrapping ``my_tensor_function``
    and ``__setstate__`` deserializes it, calls it, and stores the result in
    ``self.t``.
    """

    def __init__(self):
        # ``t`` stays None until ``__setstate__`` or ``set`` assigns it.
        self.t = None

    def __getstate__(self):
        udf = PythonUDF(
            my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None
        )
        pickled_python_udf, tensors = _internal_rpc_pickler.serialize(udf)
        return (pickled_python_udf, tensors)

    def __setstate__(self, obj):
        # ``obj`` is the 2-tuple produced by ``__getstate__``.
        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
        self.t = python_udf.func(python_udf.args[0], python_udf.args[1])

    def set(self, val):
        """Overwrite ``self.t`` with ``val``."""
        self.t = val
|
|
|
|
|
|
class MyClass:
    """Small value holder exposing instance, class and static methods."""

    def __init__(self, a):
        self.a = a

    def my_instance_method(self, b):
        """Return the stored value plus ``b``."""
        total = self.a + b
        return total

    @classmethod
    def my_class_method(cls, d, e):
        """Return ``d + e``; the class itself is not used."""
        total = d + e
        return total

    @staticmethod
    def my_static_method(f):
        """Return whether ``f`` is greater than 10."""
        is_large = f > 10
        return is_large

    def increment_value(self, increment):
        """Add ``increment`` to the stored value in place; returns nothing."""
        self.a += increment

    def get_value(self):
        """Return the stored value."""
        current = self.a
        return current
|
|
|
|
|
|
def _call_method_on_rref(method, rref, *args, **kwargs):
|
|
return method(rref.local_value(), *args, **kwargs)
|
|
|
|
|
|
def get_rref_list(values):
    """Return one ``RRef`` wrapping a new ``MyClass`` per entry of ``values``."""
    rrefs = []
    for a in values:
        rrefs.append(RRef(MyClass(a)))
    return rrefs
|
|
|
|
|
|
def add_rref_to_value(rref, value):
    """Fetch the RRef's value with ``to_here()`` and return it plus ``value``."""
    fetched = rref.to_here()
    return fetched + value
|
|
|
|
|
|
def run_nested_pickle(pickle_cls_instance, tensor):
    """Return ``pickle_cls_instance.t + tensor``."""
    stored = pickle_cls_instance.t
    return stored + tensor
|
|
|
|
|
|
def build_complex_tensors():
    """Build nested containers that all share a single 3x3 tensor of ones.

    Returns:
        A five-element list: the tensor, a list holding it twice, a list
        holding that list twice, a list of the tensor and the first list, and
        a dict mapping the tensor to that last list.
    """
    base = torch.ones(3, 3)
    pair = [base, base]
    nested = [pair, pair]
    mixed = [base, pair]
    keyed = {base: mixed}
    return [base, pair, nested, mixed, keyed]
|
|
|
|
|
|
def my_function(a, b, c):
    """Return ``a + b + c``, adding left to right."""
    partial = a + b
    return partial + c
|
|
|
|
|
|
def my_tensor_function(a, b):
    """Return the sum ``a + b``."""
    total = a + b
    return total
|
|
|
|
def my_sleep_func(seconds=1):
    """Sleep for ``seconds`` (default 1) and return nothing."""
    duration = seconds
    time.sleep(duration)
|
|
|
|
|
|
def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    """Sum every element of ``list_input`` and every value of ``dict_input``.

    The accumulator starts as ``list_input[0]`` and then adds each list
    element (including the first one again) followed by each dict value.

    Args:
        list_input: non-empty sequence of addable values.
        tensor_class_input: object whose ``tensors`` attribute is indexable
            with at least three entries.
        dict_input: mapping whose values are added to the accumulator.

    Returns:
        A 4-tuple of the accumulated sum and the first three entries of
        ``tensor_class_input.tensors``.

    Raises:
        IndexError: if ``list_input`` is empty or ``tensors`` has fewer than
            three entries.
    """
    res = list_input[0]
    # Accumulate out of place: an in-place ``+=`` here would write through
    # the alias into ``list_input[0]`` and modify the caller's argument.
    for t in list_input:
        res = res + t
    for v in dict_input.values():
        res = res + v
    complex_tensors = tensor_class_input.tensors
    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
|
|
|
|
|
|
def my_rref_function(rref_a, rref_b):
    """Fetch both RRef values with ``to_here()`` and return their sum."""
    value_a = rref_a.to_here()
    value_b = rref_b.to_here()
    return value_a + value_b
|
|
|
|
|
|
def no_result():
    """Print a fixed message and return nothing."""
    message = "do nothing"
    print(message)
|
|
|
|
|
|
def nested_rpc(dst):
    """Issue a synchronous ``torch.add(ones(2, 2), 1)`` RPC to ``dst`` and return its result."""
    add_args = (torch.ones(2, 2), 1)
    result = rpc.rpc_sync(dst, torch.add, args=add_args)
    return result
|
|
|
|
|
|
def multi_layer_nested_async_rpc(dst, world_size, ttl):
    """Start a chain of asynchronous RPCs that hops from worker to worker.

    Args:
        dst: rank of the worker that receives the next hop.
        world_size: number of workers; used to wrap the rank after ``dst``.
        ttl: remaining hops; nothing is sent once it is not positive.

    Returns:
        0 after dispatching a hop, ``None`` when ``ttl`` is exhausted.
        NOTE(review): the original indentation was lost; ``return 0`` is
        assumed to belong to the ``ttl > 0`` branch - confirm.
    """
    # this method returns immediately without blocking the callee, but will
    # generate additional requests.
    if not ttl > 0:
        return None

    hop_target = "worker{}".format(dst)
    following_rank = (dst + 1) % world_size
    rpc.rpc_async(
        hop_target,
        multi_layer_nested_async_rpc,
        args=(following_rank, world_size, ttl - 1),
    )
    return 0
|
|
|
|
|
|
def nested_rref(dst):
    """Create two remote ``torch.add`` results on ``dst`` and return both RRefs.

    Returns:
        A 2-tuple of RRefs for ``ones(2, 2) + 1`` and ``ones(2, 2) + 2``, in
        that order.
    """
    plus_one = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
    plus_two = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2))
    return (plus_one, plus_two)
|
|
|
|
|
|
def nested_remote(dst):
    """Run ``torch.add(ones(2, 2), 3)`` remotely on ``dst`` and return the fetched value."""
    add_args = (torch.ones(2, 2), 3)
    rref = rpc.remote(dst, torch.add, args=add_args)
    fetched = rref.to_here()
    return fetched
|
|
|
|
|
|
def rref_forward_chain(dst, world_size, rref, ttl):
    """Forward ``rref`` along a chain of workers, one hop per unit of ``ttl``.

    Args:
        dst: rank of the worker that receives the next hop.
        world_size: number of workers; used to wrap the rank after ``dst``.
        rref: the RRef being passed along.
        ttl: remaining hops.

    Returns:
        While ``ttl`` is positive, a one-element list holding the RRef of the
        next hop; otherwise the value obtained from ``rref.to_here()``.
    """
    if not ttl > 0:
        return rref.to_here()

    hop_target = "worker{}".format(dst)
    following_rank = (dst + 1) % world_size
    hop_args = (following_rank, world_size, rref, ttl - 1)
    ret_rref = rpc.remote(hop_target, rref_forward_chain, args=hop_args)
    return [ret_rref]
|
|
|
|
|
|
def rpc_return_rref(dst):
    """Return an RRef for ``torch.add(ones(2, 2), 1)`` created remotely on ``dst``."""
    add_args = (torch.ones(2, 2), 1)
    rref = rpc.remote(dst, torch.add, args=add_args)
    return rref
|
|
|
|
|
|
def light_rpc():
    """Do no work and return 0."""
    status = 0
    return status
|
|
|
|
|
|
def heavy_rpc(tensor):
    """Apply 99 in-place multiply/divide pairs to ``tensor`` and return 0.

    For ``i`` from 1 to 99 the argument is multiplied in place by ``i`` and
    then divided in place by ``i + 1``.
    """
    for divisor in range(2, 101):
        tensor *= divisor - 1
        tensor /= divisor
    return 0
|
|
|
|
|
|
def raise_func():
    """Always raise ``ValueError("Expected error")``."""
    error = ValueError("Expected error")
    raise error
|
|
|
|
# Module-level slot assigned by set_global_rref() and reset to None by
# clear_global_rref().
global_rref = None
|
|
|
|
def set_global_rref(rref):
    """Store ``rref`` in the module-level ``global_rref`` variable."""
    global global_rref
    new_value = rref
    global_rref = new_value
|
|
|
|
def clear_global_rref():
    """Reset the module-level ``global_rref`` variable to ``None``."""
    global global_rref
    cleared = None
    global_rref = cleared
|
|
|
|
# load_tests from common_utils is used to automatically filter tests for
|
|
# sharding on sandcastle. This line silences flake warnings
|
|
load_tests = load_tests  # self-assignment so the imported name counts as used
|
|
|
|
|
|
@unittest.skipIf(
|
|
sys.version_info < (3, 0),
|
|
"Pytorch distributed rpc package " "does not support python2",
|
|
)
|
|
class RpcTest(RpcAgentTestFixture):
|
|
@dist_init
def test_worker_id(self):
    # Worker-info lookup must report this worker's and the next worker's
    # names, and must reject a name that was never registered.
    peer_rank = (self.rank + 1) % self.world_size
    own_name = "worker{}".format(self.rank)
    peer_name = "worker{}".format(peer_rank)

    self_worker_info = rpc.get_worker_info()
    peer_worker_info = rpc.get_worker_info(peer_name)

    self.assertEqual(self_worker_info.name, own_name)
    self.assertEqual(peer_worker_info.name, peer_name)

    with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
        rpc.get_worker_info("WorkerUnknown")
|
|
|
|
@dist_init
def test_get_worker_infos(self):
    # The agent must report exactly one name and one id per rank.
    worker_infos = rpc.api._agent.get_worker_infos()
    all_ranks = range(self.world_size)

    worker_names = set(info.name for info in worker_infos)
    expected_worker_names = set("worker{}".format(rank) for rank in all_ranks)
    self.assertEqual(worker_names, expected_worker_names)

    worker_ids = set(info.id for info in worker_infos)
    expected_worker_ids = set(all_ranks)
    self.assertEqual(worker_ids, expected_worker_ids)
|
|
|
|
@dist_init
def test_self_add(self):
    # Send torch.add to this worker itself, once async and once sync.
    me = rpc.get_worker_info()
    fut = rpc.rpc_async(me, torch.add, args=(torch.ones(2, 2), 1))
    ret = rpc.rpc_sync(me, torch.add, args=(torch.ones(2, 2), 1))

    expected = torch.ones(2, 2) + 1
    self.assertEqual(fut.wait(), expected)
    self.assertEqual(ret, expected)
|
|
|
|
@dist_init
def test_self_py_udf_remote(self):
    # rpc.remote with a Python function, targeting this worker itself.
    me = rpc.get_worker_info()
    udf_args = (torch.ones(2, 2), 1, 3)
    rref = rpc.remote(me, my_function, args=udf_args)
    expected = torch.ones(2, 2) + 1 + 3
    self.assertEqual(rref.to_here(), expected)
|
|
|
|
def _test_self_remote_rref_as_rpc_arg(self, dst):
    """Pass an RRef owned by this worker as an argument of RPCs sent to ``dst``.

    Args:
        dst: destination of the ``add_rref_to_value`` calls.
    """
    me = rpc.get_worker_info()
    rref = rpc.remote(me, my_function, args=(torch.ones(2, 2), 1, 3))

    fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
    sync_addend = torch.ones(2, 2) + 1
    ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, sync_addend))

    # Value held by the RRef: my_function(ones, 1, 3).
    rref_value = torch.ones(2, 2) + 1 + 3
    self.assertEqual(ret, rref_value + torch.ones(2, 2) + 1)
    self.assertEqual(fut.wait(), rref_value + torch.ones(2, 2))
|
|
|
|
@dist_init
def test_self_remote_rref_as_rpc_arg(self):
    # Destination is the next rank, wrapping around at world_size.
    next_rank = (self.rank + 1) % self.world_size
    self._test_self_remote_rref_as_rpc_arg("worker{}".format(next_rank))
|
|
|
|
@dist_init
def test_self_remote_rref_as_self_rpc_arg(self):
    # Destination is this worker itself.
    me = rpc.get_worker_info()
    self._test_self_remote_rref_as_rpc_arg(me)
|
|
|
|
def _test_self_remote_rref_as_remote_arg(self, dst):
    """Pass an RRef owned by this worker as an argument of ``rpc.remote`` to ``dst``.

    Args:
        dst: destination of the ``add_rref_to_value`` remote call.
    """
    me = rpc.get_worker_info()
    rref = rpc.remote(me, my_function, args=(torch.ones(2, 2), 1, 3))
    ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))

    expected = torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)
    self.assertEqual(ret_rref.to_here(), expected)
|
|
|
|
@dist_init
def test_self_remote_rref_as_remote_arg(self):
    # Destination is the next rank, wrapping around at world_size.
    next_rank = (self.rank + 1) % self.world_size
    self._test_self_remote_rref_as_remote_arg("worker{}".format(next_rank))
|
|
|
|
@dist_init
def test_self_remote_rref_as_self_remote_arg(self):
    # Destination is this worker itself.
    me = rpc.get_worker_info()
    self._test_self_remote_rref_as_remote_arg(me)
|
|
|
|
@mock.patch.object(torch.distributed.autograd, "_init")
@mock.patch.object(torch.distributed.rpc.api, "_start_rpc_agent")
@dist_init(setup_rpc=False)
def test_register_rpc_backend_and_start_rpc_backend(
    self, mock_rpc_agent, mock_dist_autograd_init
):
    # Register a stub backend, check that registering the same name again is
    # rejected, then initialise RPC with the stub backend.
    backend_name = "stub_backend"
    registration_args = (
        backend_name,
        _stub_construct_rpc_backend_options_handler,
        _stub_start_rpc_backend_handler,
    )

    backend = rpc.backend_registry.register_backend(*registration_args)

    with self.assertRaisesRegex(
        RuntimeError, "^RPC backend .+: already registered$"
    ):
        rpc.backend_registry.register_backend(*registration_args)

    rpc.init_rpc(
        name="worker1",
        backend=backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
|
|
|
|
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@dist_init(setup_rpc=False)
def test_duplicate_name(self):
    # Initialising the backend under the fixed name "duplicate_name" is
    # expected to fail with a uniqueness error.
    with self.assertRaisesRegex(RuntimeError, "is not unique"):
        rendezvous_iter = torch.distributed.rendezvous(
            self.init_method, rank=self.rank, world_size=self.world_size
        )
        store, _, _ = next(rendezvous_iter)
        rpc._init_rpc_backend(
            backend=self.rpc_backend,
            store=store,
            name="duplicate_name",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
    rpc.shutdown()
|
|
|
|
@dist_init(setup_rpc=False)
def test_reinit(self):
    # A second init_rpc call after a successful one must be rejected.

    def init_this_worker():
        # Arguments are re-read from ``self`` on every call.
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    init_this_worker()

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Wait for all init to complete.
    dist.barrier()

    with self.assertRaisesRegex(RuntimeError, "is already initialized"):
        init_this_worker()
    rpc.shutdown()
|
|
|
|
@dist_init(setup_rpc=False)
def test_invalid_names(self):
    # Each invalid worker name must be rejected with the matching message.
    from torch.distributed.rpc import WorkerInfo

    worker_id = 0
    invalid_cases = [
        ("Worker name must match", "abc*"),
        ("Worker name must match", " "),
        ("must be non-empty", ""),
        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        ("shorter than 128", "a" * 500),
    ]
    for expected_message, bad_name in invalid_cases:
        with self.assertRaisesRegex(RuntimeError, expected_message):
            WorkerInfo(bad_name, worker_id)
|
|
|
|
@dist_init
def test_add(self):
    # Synchronous torch.add of two n x n tensors on the next worker.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    operands = (torch.ones(n, n), torch.ones(n, n))
    ret = rpc.rpc_sync(dst, torch.add, args=operands)
    self.assertEqual(ret, torch.ones(n, n) * 2)
|
|
|
|
@dist_init
def test_add_with_id(self):
    # Same as test_add, but the destination is given as a worker-info
    # object instead of a name.
    n = self.rank + 1
    dst_name = "worker{}".format(n % self.world_size)
    worker_info = rpc.get_worker_info(dst_name)

    operands = (torch.ones(n, n), torch.ones(n, n))
    ret = rpc.rpc_sync(worker_info, torch.add, args=operands)
    self.assertEqual(ret, torch.ones(n, n) * 2)
|
|
|
|
@dist_init
def test_scalar_add(self):
    # torch.add of a tensor and a Python scalar on the next worker.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(n, n), n))
    expected = torch.ones(n, n) + n
    self.assertEqual(ret, expected)
|
|
|
|
@dist_init
def test_async_add(self):
    # Asynchronous torch.add on the next worker; the future is awaited.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    operands = (torch.ones(n, n), torch.ones(n, n))
    fut = rpc.rpc_async(dst, torch.add, args=operands)
    self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
|
|
|
@dist_init
def test_nonzero(self):
    # torch.nonzero run remotely must match the local result.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    x = torch.ones(self.world_size, self.world_size)
    # Zero one diagonal entry so the result differs per rank.
    x[self.rank][self.rank] = 0
    ret = rpc.rpc_sync(dst, torch.nonzero, args=(x,))
    self.assertEqual(ret, x.nonzero())
|
|
|
|
@dist_init
def test_multi_rpc(self):
    # Twenty sequential synchronous adds with growing tensor sizes.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    first_size = self.rank + 1
    for n in range(first_size, first_size + 20):
        operands = (torch.ones(n, n), torch.ones(n, n))
        ret = rpc.rpc_sync(dst, torch.add, args=operands)
        self.assertEqual(ret, torch.ones(n, n) * 2)
|
|
|
|
@dist_init(setup_rpc=False)
def test_shutdown(self):
    # RPC works after init_rpc, fails after shutdown, and shutdown may be
    # called again.

    # Initialize RPC.
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    def add_on_peer():
        return rpc.rpc_sync(
            dst, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )

    ret = add_on_peer()
    self.assertEqual(ret, torch.ones(n, n) * 2)
    rpc.shutdown()

    with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
        add_on_peer()

    # it's safe to call shutdown() multiple times
    rpc.shutdown()
|
|
|
|
@dist_init
def test_expected_src(self):
    # This rank sends its own rank to the next worker via set_value, and
    # expects the previous rank's value to arrive in the local VALUE_FUTURE.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    expected_src_rank = (self.rank - 1) % self.world_size
    rpc.rpc_sync(dst, set_value, args=(self.rank,))
    received = VALUE_FUTURE.result()
    self.assertEqual(received, expected_src_rank)
|
|
|
|
@dist_init
def test_py_built_in(self):
    # Run the Python builtin ``min`` over RPC.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    candidates = (n, n + 1, n + 2)
    ret = rpc.rpc_sync(dst, min, args=candidates)
    self.assertEqual(ret, min(n, n + 1, n + 2))
|
|
|
|
@dist_init
def test_py_user_defined(self):
    # Run a module-level Python function over RPC using keyword arguments.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    call_kwargs = {"a": n, "b": n + 1, "c": n + 2}
    ret = rpc.rpc_sync(dst, my_function, kwargs=call_kwargs)
    self.assertEqual(ret, my_function(n, n + 1, n + 2))
|
|
|
|
@dist_init
def test_py_class_constructor(self):
    # Use a class as the RPC callable; the constructed instance comes back.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    instance = rpc.rpc_sync(dst, MyClass, args=(n,))
    self.assertEqual(instance.a, n)
|
|
|
|
@dist_init
def test_py_class_instance_method(self):
    # Use a bound instance method as the RPC callable.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    bound_method = MyClass(2).my_instance_method
    ret = rpc.rpc_sync(dst, bound_method, args=(n,))
    self.assertEqual(ret, MyClass(2).my_instance_method(n))
|
|
|
|
@dist_init
def test_py_class_method(self):
    # Use a classmethod as the RPC callable.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(dst, MyClass.my_class_method, args=(n, n + 1))
    expected = MyClass.my_class_method(n, n + 1)
    self.assertEqual(ret, expected)
|
|
|
|
@dist_init
def test_py_class_static_method(self):
    # Use a staticmethod as the RPC callable.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(dst, MyClass.my_static_method, args=(n + 10,))
    expected = MyClass.my_static_method(n + 10)
    self.assertEqual(ret, expected)
|
|
|
|
@dist_init
def test_py_multi_async_call(self):
    # Two async calls are issued before either one is awaited.
    n = self.rank + 1
    dst_name = "worker{}".format(n % self.world_size)
    dst_worker_info = rpc.get_worker_info(dst_name)

    static_fut = rpc.rpc_async(
        dst_worker_info, MyClass.my_static_method, args=(n + 10,)
    )
    min_fut = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))

    self.assertEqual(static_fut.wait(), MyClass.my_static_method(n + 10))
    self.assertEqual(min_fut.wait(), min(n, n + 1, n + 2))
|
|
|
|
@dist_init
def test_py_no_return_result(self):
    # A function without a return value yields the same result remotely and
    # locally.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    remote_result = rpc.rpc_sync(dst, no_result)
    self.assertEqual(remote_result, no_result())
|
|
|
|
@dist_init
def test_py_tensors(self):
    # Python function taking tensor arguments, run over RPC.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    ret = rpc.rpc_sync(
        dst, my_tensor_function, args=(torch.ones(n, n), torch.ones(n, n))
    )
    expected = my_tensor_function(torch.ones(n, n), torch.ones(n, n))
    self.assertEqual(ret, expected)
|
|
|
|
@dist_init
def test_py_tensors_multi_async_call(self):
    # Issue 100 async calls with tensor sizes 0..99, then await them in
    # issue order.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    futs = [
        rpc.rpc_async(
            dst, my_tensor_function, args=(torch.ones(i, i), torch.ones(i, i))
        )
        for i in range(100)
    ]

    for j, fut in enumerate(futs):
        expected = my_tensor_function(torch.ones(j, j), torch.ones(j, j))
        self.assertEqual(fut.wait(), expected)
|
|
|
|
@dist_init
def test_py_tensors_in_container(self):
    # Tensors nested in a list, a namedtuple and a dict are sent over RPC.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    tensor_list = [torch.ones(n, n), torch.ones(n, n)]
    tensor_tuple = TensorClass(build_complex_tensors())
    tensor_dict = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
    call_args = (tensor_list, tensor_tuple, tensor_dict)

    ret = rpc.rpc_sync(dst, my_complex_tensor_function, args=call_args)
    self.assertEqual(
        ret, my_complex_tensor_function(tensor_list, tensor_tuple, tensor_dict)
    )
|
|
|
|
@dist_init
def test_py_nested_pickle(self):
    # Send an object whose pickling hooks themselves use the RPC pickler.
    dst = "worker{}".format((self.rank + 1) % self.world_size)

    ret = rpc.rpc_sync(
        dst, run_nested_pickle, args=(MyPickleClass(), torch.ones(2, 2))
    )

    # Local reference: set ``t`` to what __setstate__ computes.
    local_obj = MyPickleClass()
    local_obj.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
    self.assertEqual(ret, run_nested_pickle(local_obj, torch.ones(2, 2)))
|
|
|
|
@dist_init
def test_py_function_exception(self):
    # no_result takes no parameters, so passing one must surface a TypeError.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    with self.assertRaisesRegex(Exception, "TypeError"):
        rpc.rpc_sync(dst, no_result, args=(10,))
|
|
|
|
@dist_init
def test_py_raise_in_user_func(self):
    # The error raised by raise_func is reported when the future is awaited.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    fut = rpc.rpc_async(dst, raise_func)
    with self.assertRaisesRegex(Exception, "ValueError"):
        fut.wait()
|
|
|
|
@dist_init
def test_nested_rpc(self):
    # The next worker runs nested_rpc, which calls back into this worker.
    dst = "worker{}".format((self.rank + 1) % self.world_size)
    callback_target = "worker{}".format(self.rank)
    ret = rpc.rpc_sync(dst, nested_rpc, args=(callback_target,))
    self.assertEqual(ret, torch.ones(2, 2) + 1)
|
|
|
|
def _stress_test_rpc(self, f, repeat=1000, args=()):
    """Send ``f`` to the next worker ``repeat`` times and report the elapsed time.

    Args:
        f: callable to run remotely; every call is expected to return 0.
        repeat: number of asynchronous calls to issue.
        args: positional arguments passed to every call.
    """
    dst = "worker{}".format((self.rank + 1) % self.world_size)

    tik = time.time()
    futs = [rpc.rpc_async(dst, f, args=args) for _ in range(repeat)]
    for fut in futs:
        self.assertEqual(fut.wait(), 0)
    tok = time.time()

    elapsed = tok - tik
    print(
        "Rank {} finished testing {} {} times in {} seconds.".format(
            self.rank, f.__name__, repeat, elapsed
        )
    )
|
|
|
|
@dist_init
def test_stress_light_rpc(self):
    # Stress run with the no-op workload and the default repeat count.
    workload = light_rpc
    self._stress_test_rpc(workload)
|
|
|
|
@dist_init
def test_stress_heavy_rpc(self):
    # Stress run with the tensor workload, repeated 20 times.
    payload = (torch.ones(100, 100),)
    self._stress_test_rpc(heavy_rpc, repeat=20, args=payload)
|
|
|
|
@dist_init
def test_builtin_remote_ret(self):
    # rpc.remote with a builtin operator; the value is fetched via to_here().
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    operands = (torch.ones(n, n), torch.ones(n, n))
    rref = rpc.remote(dst, torch.add, args=operands)
    self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
|
|
|
|
@dist_init
def test_asymmetric_load_with_join(self):
    """Test graceful termination.

    Rank 0 sends a batch of ``heavy_rpc`` calls to worker1 and, once those
    have completed, a second batch to worker2. All other ranks do nothing in
    the test body.
    """
    # worker0 drives and waits for worker1 and worker2
    # throughout the test.
    if self.rank != 0:
        return

    # Use a unittest assertion rather than a bare ``assert`` so the check is
    # not stripped when Python runs with -O.
    self.assertGreaterEqual(self.world_size, 3)

    num_repeat = 100

    # Phase 1: Only worker1 has workload.
    # Phase 2: Only worker2 has workload.
    # If join is not correctly implemented,
    # worker2 should be closed by now.
    for dst in ("worker1", "worker2"):
        # A fresh list per phase: phase 2 must not wait again on the futures
        # that phase 1 already consumed.
        futs = [
            rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
            for _ in range(num_repeat)
        ]
        for fut in futs:
            # One wait() per future both blocks and yields the result.
            self.assertEqual(fut.wait(), 0)
|
|
|
|
def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
    """Issue ten ``rpc.remote`` calls of ``fn`` and compare with local results.

    Args:
        fn: callable run both remotely and locally.
        args_fn: maps the current size ``n`` to the positional arguments.
        kwargs_fn: maps the current size ``n`` to the keyword arguments.
    """
    m = 10
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    rrefs = []
    expected = []
    for i in range(m):
        # ``n`` accumulates the loop index across iterations.
        n += i
        remote_rref = rpc.remote(dst, fn, args=args_fn(n), kwargs=kwargs_fn(n))
        rrefs.append(remote_rref)
        expected.append(fn(*args_fn(n), **kwargs_fn(n)))

    for idx in range(m):
        self.assertEqual(rrefs[idx].to_here(), expected[idx])
|
|
|
|
@dist_init
def test_multi_builtin_remote_ret(self):
    # Repeated rpc.remote calls of torch.add on pairs of n x n tensors.
    def make_operands(n):
        ones = torch.ones(n, n)
        return (ones, torch.ones(n, n))

    self._test_multi_remote_call(torch.add, args_fn=make_operands)
|
|
|
|
@dist_init
def test_py_udf_remote(self):
    # rpc.remote with a Python function and keyword arguments.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)
    call_kwargs = {"a": n, "b": n + 1, "c": n + 2}
    rref = rpc.remote(dst, my_function, kwargs=call_kwargs)
    self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
|
|
|
|
@dist_init
def test_multi_py_udf_remote(self):
    # Repeated rpc.remote calls of my_function with tensor keyword arguments.
    def make_kwargs(n):
        return {key: torch.ones(n, n) for key in ("a", "b", "c")}

    self._test_multi_remote_call(my_function, kwargs_fn=make_kwargs)
|
|
|
|
@dist_init
def test_py_rref_args(self):
    # Two RRefs created on the next worker are passed to a third remote
    # call on that same worker.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    rref_a = rpc.remote(dst, torch.add, args=(torch.ones(n, n), 2))
    rref_b = rpc.remote(dst, torch.add, args=(torch.ones(n, n), 1))
    rref_c = rpc.remote(dst, my_rref_function, args=(rref_a, rref_b))

    expected = torch.ones(n, n) + 4
    self.assertEqual(rref_c.to_here(), expected)
|
|
|
|
@dist_init
def test_py_rref_args_user_share(self):
    # RRefs are created on one worker and consumed by a remote call sent to
    # the worker after it.
    n = self.rank + 1
    owner = "worker{}".format(n % self.world_size)
    user = "worker{}".format((n + 1) % self.world_size)

    rref_a = rpc.remote(owner, my_function, args=(torch.ones(n, n), 2, 0))
    rref_b = rpc.remote(owner, my_function, args=(torch.ones(n, n), 1, 0))
    rref_c = rpc.remote(user, my_rref_function, args=(rref_a, rref_b))

    expected = torch.ones(n, n) + 4
    self.assertEqual(rref_c.to_here(), expected)
|
|
|
|
@dist_init
def test_py_rpc_rref_args(self):
    # RRefs are passed as arguments of a synchronous RPC.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    rref_a = rpc.remote(dst, my_function, args=(torch.ones(n, n), 2, 0))
    rref_b = rpc.remote(dst, my_function, args=(torch.ones(n, n), 1, 0))

    total = rpc.rpc_sync(dst, my_rref_function, args=(rref_a, rref_b))

    self.assertEqual(total, torch.ones(n, n) + 4)
|
|
|
|
@dist_init
def test_nested_remote(self):
    # The next worker runs nested_remote, which targets the worker after it.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)

    rref = rpc.remote(first_hop, nested_remote, args=(second_hop,))
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
|
|
|
|
@dist_init
def test_nested_rref(self):
    # The fetched value of the outer RRef is itself a pair of RRefs.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)

    rref_of_rrefs = rpc.remote(first_hop, nested_rref, args=(second_hop,))
    inner_rrefs = rref_of_rrefs.to_here()

    self.assertEqual(len(inner_rrefs), 2)
    self.assertEqual(inner_rrefs[0].to_here(), torch.ones(2, 2) + 1)
    self.assertEqual(inner_rrefs[1].to_here(), torch.ones(2, 2) + 2)
|
|
|
|
@dist_init
def test_nested_rref_stress(self):
    # Twenty nested_rref calls are issued first and checked afterwards.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)

    all_rrefs = [
        rpc.remote(first_hop, nested_rref, args=(second_hop,))
        for _ in range(20)
    ]

    for rref_of_rrefs in all_rrefs:
        inner_rrefs = rref_of_rrefs.to_here()
        self.assertEqual(len(inner_rrefs), 2)
        self.assertEqual(inner_rrefs[0].to_here(), torch.ones(2, 2) + 1)
        self.assertEqual(inner_rrefs[1].to_here(), torch.ones(2, 2) + 2)
|
|
|
|
@dist_init
def test_multi_layer_nested_async_rpc(self):
    # This test will exit right away, but there will be a chain of async
    # RPCs. The termination algorithm should detect those messages properly.
    # Otherwise, some peer could exit early, leaving others to timeout
    # errors or connection closed errors.
    ttl = 20
    first_dst_rank = (self.rank + 1) % self.world_size

    multi_layer_nested_async_rpc(first_dst_rank, self.world_size, ttl)
|
|
|
|
@dist_init
def test_remote_with_exception(self):
    # Errors from rpc.remote surface when to_here() is called.
    peer = "worker{}".format((self.rank + 1) % self.world_size)
    me = "worker{}".format(self.rank)

    # check ref to other workers
    rref = rpc.remote(peer, raise_func)
    with self.assertRaisesRegex(Exception, "ValueError"):
        rref.to_here()

    # check ref to itself
    rref = rpc.remote(me, no_result, args=(10,))
    with self.assertRaisesRegex(Exception, "TypeError"):
        rref.to_here()
|
|
|
|
@dist_init
def test_rpc_return_rref(self):
    # A synchronous RPC returns an RRef created on a further worker.
    n = self.rank + 1
    first_hop = "worker{}".format(n % self.world_size)
    second_hop = "worker{}".format((n + 1) % self.world_size)

    rref = rpc.rpc_sync(first_hop, rpc_return_rref, args=(second_hop,))
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
|
|
|
|
@dist_init
def test_rref_forward_chain(self):
    # Forward one RRef through ``ttl`` hops, then unwrap the nested results.
    ttl = 8
    n = self.rank + 1
    dst_rank = n % self.world_size

    rref = rpc.remote(
        "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
    )

    current = rref_forward_chain(dst_rank, self.world_size, rref, ttl)

    # Each hop returns a one-element list holding the next hop's RRef.
    for _ in range(ttl):
        self.assertEqual(len(current), 1)
        current = current[0].to_here()

    self.assertEqual(current, torch.add(torch.ones(n, n), 1))
|
|
|
|
@dist_init
def test_local_rref_no_fork(self):
    # A locally constructed RRef exposes its value through local_value().
    wrapped = 35
    local_rref = RRef(wrapped)
    self.assertEqual(local_rref.local_value(), 35)
|
|
|
|
@dist_init
def test_return_local_rrefs(self):
    # RRefs created on the peer are returned here, then used to call methods
    # on their owner.
    dst = "worker{}".format((self.rank + 1) % self.world_size)

    rref_list = rpc.rpc_sync(dst, get_rref_list, args=([1, 2, 3],))

    for rref in rref_list:
        increment_args = (MyClass.increment_value, rref, 10)
        rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=increment_args)

    rets = []
    for rref in rref_list:
        query_args = (MyClass.get_value, rref)
        rets.append(
            rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=query_args)
        )

    self.assertEqual(rets, [11, 12, 13])
|
|
|
|
@dist_init
def test_owner_equality(self):
    # Owner objects compare equal per worker and can be used as dict keys.
    a = RRef(40)
    b = RRef(50)

    other = "worker{}".format((self.rank + 1) % self.world_size)
    other_a = rpc.remote(other, torch.add, args=(torch.ones(1), 1))
    other_b = rpc.remote(other, torch.add, args=(torch.ones(1), 1))
    other_a.to_here()  # to ensure clean termination
    other_b.to_here()

    self.assertNotEqual(a.owner(), 23)
    self.assertEqual(other_a.owner(), other_b.owner())
    self.assertNotEqual(a.owner(), other_a.owner())
    self.assertEqual(other_a.owner(), other_a.owner())
    self.assertEqual(other_a.owner(), other_b.owner())
    self.assertEqual(a.owner(), a.owner())
    self.assertEqual(a.owner(), b.owner())
    self.assertEqual(a.owner(), rpc.get_worker_info())

    by_owner = {}
    by_owner[a.owner()] = a
    by_owner[other_a.owner()] = other_a
    self.assertEqual(by_owner[a.owner()], a)
    self.assertEqual(by_owner[b.owner()], a)
    self.assertEqual(by_owner[other_a.owner()], other_a)
    self.assertEqual(by_owner[other_b.owner()], other_a)
    self.assertEqual(len(by_owner), 2)
|
|
|
|
@dist_init
def test_pass_local_rrefs(self):
    # A locally constructed RRef is passed through each of the three call
    # styles: rpc_sync, rpc_async and remote.
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    rref = RRef(40)
    call_args = (rref, 50)

    sync_ret = rpc.rpc_sync(dst_worker, add_rref_to_value, args=call_args)
    self.assertEqual(sync_ret, 90)

    async_ret = rpc.rpc_async(dst_worker, add_rref_to_value, args=call_args).wait()
    self.assertEqual(async_ret, 90)

    remote_ret = rpc.remote(dst_worker, add_rref_to_value, args=call_args).to_here()
    self.assertEqual(remote_ret, 90)
|
|
|
|
@dist_init
def test_remote_same_worker(self):
    # All three remote calls, including the one consuming the RRefs, target
    # the same worker.
    n = self.rank + 1
    dst = "worker{}".format(n % self.world_size)

    rref_a = rpc.remote(dst, torch.add, args=(torch.ones(n, n), 2))
    rref_b = rpc.remote(dst, torch.add, args=(torch.ones(n, n), 1))
    rref_c = rpc.remote(dst, my_rref_function, args=(rref_a, rref_b))

    expected = torch.ones(n, n) + 4
    self.assertEqual(rref_c.to_here(), expected)
|
|
|
|
@dist_init(setup_rpc=True)
def test_call_method_on_rref(self):
    """
    Tests that it is possible to call an instance method on a remote object
    by using rref.owner() as destination of the call.
    """
    vals = [10, 2, 5, 7]
    initial, sync_inc, async_inc, remote_inc = vals
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)

    # creates a remote object
    rref = rpc.remote(dst_worker, MyClass, args=(initial,))

    # modifies state of the remote object
    rpc.rpc_sync(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, sync_inc),
    )
    rpc.rpc_async(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, async_inc),
    ).wait()
    rpc.remote(
        rref.owner(),
        _call_method_on_rref,
        args=(MyClass.increment_value, rref, remote_inc),
    ).to_here()

    # queries state of the remote object
    result = rpc.rpc_sync(
        dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
    )

    self.assertEqual(result, sum(vals))
|
|
|
|
def _test_rref_leak(self, ignore_leak=False):
    """Start RPC, create a remote RRef, and shut down while still holding it.

    Args:
        ignore_leak: when True, ``torch.distributed.rpc.api._ignore_rref_leak``
            is set to True before ``rpc.shutdown()`` is called.
    """
    own_name = "worker{}".format(self.rank)
    rpc.init_rpc(
        name=own_name,
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    # This is for the below `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )
    # Wait for all init to complete.
    dist.barrier()

    peer_name = "worker{}".format((self.rank + 1) % self.world_size)
    # NOTE(review): the reference is deliberately kept in a local so the
    # RRef is presumably still alive at shutdown -- confirm before removing.
    rref = rpc.remote(peer_name, torch.add, args=(torch.ones(2, 2), 1))

    if ignore_leak:
        import torch.distributed.rpc.api as api
        api._ignore_rref_leak = True

    rpc.shutdown()
|
|
|
|
@dist_init(setup_rpc=False)
def test_rref_leak(self):
    """Expect the leak scenario to raise a RuntimeError matching "Leaking RRef"."""
    leak_error = self.assertRaisesRegex(RuntimeError, "Leaking RRef")
    with leak_error:
        self._test_rref_leak()
|
|
|
|
@dist_init(setup_rpc=False)
def test_ignore_rref_leak(self):
    """Run the leak scenario with ``ignore_leak=True``; no error is expected."""
    self._test_rref_leak(ignore_leak=True)
|
|
|
|
@dist_init
def test_rref_str(self):
    """Check the string representation of an owner RRef and of a user RRef."""
    id_class = "GloballyUniqueId"

    # The locally created RRef must come first: the expected ids below
    # (0, then 1 and 2) follow the creation order.
    owner_rref = RRef(self.rank)
    expected_owner_str = "OwnerRRef({}({}, 0))".format(id_class, self.rank)
    self.assertEqual(expected_owner_str, str(owner_rref))

    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    user_rref = rpc.remote(dst_worker, torch.add, args=(torch.ones(2, 2), 1))
    expected_user_str = (
        "UserRRef(RRefId = {0}({1}, 1), ForkId = {0}({1}, 2))".format(
            id_class, self.rank
        )
    )
    self.assertEqual(str(user_rref), expected_user_str)
|
|
|
|
@dist_init
def test_rref_context_debug_info(self):
    """Track ``num_owner_rrefs`` in the debug info as RRefs are shared.

    The count is expected to be "0" for a local RRef that has not been sent
    over RPC, "1" once it has been shared with a peer, and "2" after the
    previous worker has created two remote RRefs owned by this worker.
    """
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=self.init_method,
            rank=self.rank,
            world_size=self.world_size,
        )

    from torch.distributed.rpc import _get_debug_info

    def owner_rref_count():
        # Fetch a fresh snapshot and make sure the key is reported at all.
        info = _get_debug_info()
        self.assertIn("num_owner_rrefs", info)
        return info["num_owner_rrefs"]

    rref1 = RRef(self.rank)
    # RRef on local value is not added to context until shared across RPC
    self.assertEqual("0", owner_rref_count())

    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    rpc.rpc_sync(dst_worker, set_global_rref, args=(rref1,))
    self.assertEqual("1", owner_rref_count())
    rpc.rpc_sync(dst_worker, clear_global_rref)

    rref2 = rpc.remote(dst_worker, torch.add, args=(torch.ones(2, 2), 1))
    rref3 = rpc.remote(dst_worker, torch.add, args=(torch.ones(2, 2), 1))
    rref2.to_here()
    rref3.to_here()

    # Use a barrier to make sure that OwnerRRefs are created on this worker
    # before checking debug info
    dist.barrier()
    self.assertEqual("2", owner_rref_count())

    # Use another barrier to make sure that UserRRefs are only deleted after
    # checking debug info
    dist.barrier()
|
|
|
|
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_local_shutdown(self):
    """Start RPC and immediately shut down locally without sending messages."""
    backend_type = rpc.backend_registry.BackendType[
        dist_utils.TEST_CONFIG.rpc_backend_name
    ]
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=backend_type,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
    # graceful=False: do not wait for the other workers.
    rpc.shutdown(graceful=False)
|
|
|
|
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_local_shutdown_with_rpc(self):
    """Start RPC, send RPCs to the next worker, then shut down locally.

    After the ``torch.add`` RPC, each worker marks its successor as done
    via ``_set_rpc_done`` and spins in ``_check_rpc_done`` until its own
    flag has been set, so no worker shuts down before it has been served.
    """
    # test that we can start RPC, send RPCs, and then run local shutdown.
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=rpc.backend_registry.BackendType[
            dist_utils.TEST_CONFIG.rpc_backend_name
        ],
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )
    n = self.rank + 1
    dst_rank = n % self.world_size
    # Only successful completion of the call matters here; the returned
    # tensor is intentionally not bound to a name.
    rpc.rpc_sync(
        "worker{}".format(dst_rank),
        torch.add,
        args=(torch.ones(n, n), torch.ones(n, n)),
    )
    # wait for RPCs to be done, so that some workers don't try to shut down
    # too early.
    rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(1,))
    _check_rpc_done(1)
    # pass in graceful=False to ensure that we don't wait for other workers.
    rpc.shutdown(graceful=False)
|
|
|
|
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_wait_all_workers_and_shutdown(self):
    """Call ``_wait_all_workers()`` followed by ``rpc.shutdown()``.

    Both calls must succeed; neither may raise because shutdown work is
    attempted more than once.
    """
    backend_type = rpc.backend_registry.BackendType[
        dist_utils.TEST_CONFIG.rpc_backend_name
    ]
    rpc.init_rpc(
        name="worker%d" % self.rank,
        backend=backend_type,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=self.rpc_backend_options,
    )

    from torch.distributed.rpc.api import _wait_all_workers

    # The internal helper is invoked on purpose, ahead of the public API.
    _wait_all_workers()
    rpc.shutdown()
|
|
|
|
@dist_init(setup_rpc=False)
def test_get_rpc_timeout(self):
    """Verify ``rpc.get_rpc_timeout()`` returns the timeout given at init."""
    expected_timeout = timedelta(seconds=1)

    # Accessing `self.rpc_backend_options` constructs a new
    # `RpcBackendOptions` each time, so configure and pass one instance.
    options = self.rpc_backend_options
    options.rpc_timeout = expected_timeout

    rpc.init_rpc(
        name="worker{}".format(self.rank),
        backend=self.rpc_backend,
        rank=self.rank,
        world_size=self.world_size,
        rpc_backend_options=options,
    )
    self.assertEqual(expected_timeout, rpc.get_rpc_timeout())
    rpc.shutdown()
|
|
|
|
@dist_init
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_rpc_timeouts(self):
    """Exercise ``rpc._set_rpc_timeout`` with short, long and zero timeouts.

    Checks that futures created under a 1 ms timeout fail with a timeout
    error, that changing the timeout only affects futures created
    afterwards, and that a zero timeout lets a future run to completion.
    The default timeout is restored in a ``finally`` clause so that a
    failing assertion cannot leave a tiny timeout in place for the
    shutdown that follows.
    """
    dst_rank = (self.rank + 1) % self.world_size
    dst_worker = "worker{}".format(dst_rank)
    try:
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        # futures should time out and be marked with an exception indicating it as such.
        futs = [
            rpc.rpc_async(dst_worker, my_sleep_func, args=())
            for _ in range(10)
        ]
        for fut in futs:
            with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
                fut.wait()

        # ensure that if a new timeout is set old futures don't time out but new ones do.
        rpc._set_rpc_timeout(timedelta(seconds=200))
        # create a longstanding RPC.
        fut1 = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
        # now, set a short timeout.
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        # fut2 should time out, fut1 should not.
        fut2 = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
        with self.assertRaises(RuntimeError):
            fut2.wait()
        fut1.wait()

        # future should run to completion if the timeout is zero.
        rpc._set_rpc_timeout(timedelta(seconds=0))
        rpc.rpc_async(dst_worker, my_sleep_func, args=()).wait()
    finally:
        # reset to default timeout so shutdown messages can process cleanly,
        # including when one of the checks above has failed.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)
|
|
|
|
def test_requires_process_group_agent_decorator(self):
    """Check both outcomes of the ``requires_process_group_agent`` decorator.

    With the PROCESS_GROUP backend the decorated function must run and
    return its value. With any other backend, calling it must raise
    ``unittest.SkipTest`` carrying the supplied message, because the
    decorator is built on ``unittest.skipUnless``.
    """
    @requires_process_group_agent("test_func did not run")
    def test_func():
        return "expected result"

    if dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP":
        self.assertEqual(test_func(), "expected result")
    else:
        # Previously this case was not checked at all, so the test passed
        # vacuously on every non-PROCESS_GROUP backend.
        with self.assertRaisesRegex(unittest.SkipTest, "test_func did not run"):
            test_func()
|
|
|
|
def test_dist_init_decorator(self):
    """Check ``dist_init`` works both called with arguments and applied bare.

    In each form the wrapped function's return value must be passed through.
    """
    expected = "expected result"

    # Form 1: decorator factory invoked with keyword arguments.
    @dist_init(setup_rpc=False)
    def called_with_arguments(self):
        return "expected result"

    self.assertEqual(called_with_arguments(self), expected)

    # Form 2: decorator applied directly, without parentheses.
    @dist_init
    def applied_bare(self):
        return "expected result"

    self.assertEqual(applied_bare(self), expected)
|
|
|
|
def test_use_rpc_pickler(self):
    """Check ``_use_rpc_pickler`` swaps the default pickler and restores it.

    Inside the ``with`` block ``torch.distributed.rpc.api._default_pickler``
    must be the object passed in; after the block it must again be
    ``_internal_rpc_pickler``.
    """
    class TestPickler():
        pass

    test_pickler = TestPickler()
    # assertIs reports both operands on failure, unlike assertTrue(a is b).
    with _use_rpc_pickler(test_pickler):
        self.assertIs(torch.distributed.rpc.api._default_pickler, test_pickler)
    self.assertIs(
        torch.distributed.rpc.api._default_pickler, _internal_rpc_pickler
    )
|