Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30217

Before this commit, RRefContext threw an error if it detected any RRef leak during shutdown. This required applications to ensure that all references to RRefs had been released in application code, which makes for a poor debugging experience in large applications. It also relies on the Python GC to free things up in time, which is not always the case. After this commit, RRefContext ignores leaked RRefs during shutdown, since shutdown is called when the application has finished training and no longer cares about local state. Hence, it should be OK to simply ignore those leaks and destroy the OwnerRRefs. If an application would like to enforce that no RRefs are leaked, set torch.distributed.rpc.api._ignore_rref_leak to False.

Test Plan: Imported from OSS

Differential Revision: D18632546

Pulled By: mrshenli

fbshipit-source-id: 2744b2401dafdd16de0e0a76cf8e07777bed0f38
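For illustration only, a minimal sketch of the opt-in enforcement described in the summary (`_ignore_rref_leak` is a private, internal flag and may change; the call pattern mirrors the tests in this file):

    import torch.distributed.rpc as rpc
    import torch.distributed.rpc.api as rpc_api

    # Enforce that no RRef is leaked at shutdown: with the flag set to False,
    # RRefContext raises instead of silently destroying leaked OwnerRRefs.
    rpc_api._ignore_rref_leak = False

    # ... training finished ...
    rpc.wait_all_workers()  # shutdown entry point exercised throughout this file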
1230 lines · 40 KiB · Python
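"""Tests for the torch.distributed.rpc Python API: synchronous, asynchronous,
and remote (RRef-returning) calls, worker info lookup, timeouts, custom
picklers, backend registration, and shutdown via wait_all_workers()."""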
from __future__ import absolute_import, division, print_function, unicode_literals

import concurrent.futures
from datetime import timedelta
import sys
import unittest
from collections import namedtuple
from unittest import mock

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
from common_utils import load_tests
import dist_utils
from dist_utils import dist_init
from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
from rpc_agent_test_fixture import RpcAgentTestFixture


def requires_process_group_agent(message=""):
    def decorator(old_func):
        return unittest.skipUnless(
            dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP", message
        )(old_func)

    return decorator


VALUE_FUTURE = concurrent.futures.Future()


def _stub_construct_rpc_backend_options_handler(
    **kwargs
):
    return mock.Mock()  # RpcBackendOptions.


def _stub_start_rpc_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    return mock.Mock()  # RpcAgent.


def set_value(value):
    VALUE_FUTURE.set_result(value)


# The classes and functions below are used to test Python user-defined
# functions, classes, and methods over RPC.
TensorClass = namedtuple("TensorClass", ["tensors"])


class MyPickleClass:
    def __init__(self):
        self.t = None

    def __getstate__(self):
        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
            PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
        )
        return (pickled_python_udf, tensors)

    def __setstate__(self, obj):
        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
        result = python_udf.func(python_udf.args[0], python_udf.args[1])
        self.t = result

    def set(self, val):
        self.t = val


class MyClass:
    def __init__(self, a):
        self.a = a

    def my_instance_method(self, b):
        return self.a + b

    @classmethod
    def my_class_method(cls, d, e):
        return d + e

    @staticmethod
    def my_static_method(f):
        return f > 10

    def increment_value(self, increment):
        self.a += increment

    def get_value(self):
        return self.a


def _call_method_on_rref(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)


def get_rref_list(values):
    return [RRef(MyClass(a)) for a in values]


def add_rref_to_value(rref, value):
    return rref.to_here() + value


def run_nested_pickle(pickle_cls_instance, tensor):
    return pickle_cls_instance.t + tensor


def build_complex_tensors():
    a = torch.ones(3, 3)
    b = [a, a]
    c = [b, b]
    d = [a, b]
    e = {a: d}
    return [a, b, c, d, e]


def my_function(a, b, c):
    return a + b + c


def my_tensor_function(a, b):
    return a + b


def my_sleep_func(seconds=1):
    import time
    time.sleep(seconds)


def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    res = list_input[0]
    for t in list_input:
        res += t
    for k, v in dict_input.items():
        res += v
    complex_tensors = tensor_class_input.tensors
    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])


def my_rref_function(rref_a, rref_b):
    return rref_a.to_here() + rref_b.to_here()


def no_result():
    print("do nothing")


def nested_rpc(dst):
    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))


def multi_layer_nested_async_rpc(dst, world_size, ttl):
    # this method returns immediately without blocking the callee, but will
    # generate additional requests.
    if ttl > 0:
        current_dst = "worker{}".format(dst)
        next_dst = (dst + 1) % world_size
        rpc.rpc_async(
            current_dst,
            multi_layer_nested_async_rpc,
            args=(next_dst, world_size, ttl - 1),
        )
        return 0


def nested_rref(dst):
    return (
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
    )


def nested_remote(dst):
    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
    return rref.to_here()


def rref_forward_chain(dst, world_size, rref, ttl):
    if ttl > 0:
        current_dst = "worker{}".format(dst)
        next_dst = (dst + 1) % world_size
        ret_rref = rpc.remote(
            current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
        )
        return [ret_rref]
    else:
        return rref.to_here()


def rpc_return_rref(dst):
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))


def light_rpc():
    return 0


def heavy_rpc(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0


def raise_func():
    raise ValueError("Expected error")


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


@unittest.skipIf(
    sys.version_info < (3, 0),
    "Pytorch distributed rpc package " "does not support python2",
)
class RpcTest(RpcAgentTestFixture):
    @dist_init
    def test_worker_id(self):
        n = self.rank + 1
        peer_rank = n % self.world_size
        self_worker_info = rpc.get_worker_info()
        peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))

        self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
        self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))

        with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
            unknown_worker_id = rpc.get_worker_info("WorkerUnknown")

    @dist_init
    def test_get_worker_infos(self):
        worker_infos = rpc.api._agent.get_worker_infos()

        worker_names = {
            worker_info.name for worker_info in worker_infos
        }
        expected_worker_names = {
            "worker{}".format(rank) for rank in range(self.world_size)
        }
        self.assertEqual(worker_names, expected_worker_names)

        worker_ids = {
            worker_info.id for worker_info in worker_infos
        }
        expected_worker_ids = {
            rank for rank in range(self.world_size)
        }
        self.assertEqual(worker_ids, expected_worker_ids)

    @dist_init
    def test_self_add(self):
        self_worker_info = rpc.get_worker_info()
        self_worker_name = "worker{}".format(self.rank)
        fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
        ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    @dist_init
    def test_self_py_udf_remote(self):
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1 + 3)

    def _test_self_remote_rref_as_rpc_arg(self, dst):
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1))
        self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1)
        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))

    @dist_init
    def test_self_remote_rref_as_rpc_arg(self):
        dst = "worker{}".format((self.rank + 1) % self.world_size)
        self._test_self_remote_rref_as_rpc_arg(dst)

    @dist_init
    def test_self_remote_rref_as_self_rpc_arg(self):
        self._test_self_remote_rref_as_rpc_arg(rpc.get_worker_info())

    def _test_self_remote_rref_as_remote_arg(self, dst):
        self_worker_info = rpc.get_worker_info()
        rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
        self.assertEqual(ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))

    @dist_init
    def test_self_remote_rref_as_remote_arg(self):
        dst = "worker{}".format((self.rank + 1) % self.world_size)
        self._test_self_remote_rref_as_remote_arg(dst)

    @dist_init
    def test_self_remote_rref_as_self_remote_arg(self):
        self._test_self_remote_rref_as_remote_arg(rpc.get_worker_info())

    @mock.patch.object(torch.distributed.autograd, "_init")
    @mock.patch.object(torch.distributed.rpc.api, "_start_rpc_agent")
    @dist_init(setup_rpc=False)
    def test_register_rpc_backend_and_start_rpc_backend(
        self, mock_rpc_agent, mock_dist_autograd_init
    ):
        backend_name = "stub_backend"

        backend = rpc.backend_registry.register_backend(
            backend_name,
            _stub_construct_rpc_backend_options_handler,
            _stub_start_rpc_backend_handler,
        )

        with self.assertRaisesRegex(
            RuntimeError, "^RPC backend .+: already registered$"
        ):
            backend = rpc.backend_registry.register_backend(
                backend_name,
                _stub_construct_rpc_backend_options_handler,
                _stub_start_rpc_backend_handler,
            )

        rpc.init_rpc(
            name="worker1",
            backend=backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
    @dist_init(setup_rpc=False)
    def test_duplicate_name(self):
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="duplicate_name",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.wait_all_workers()

    @dist_init(setup_rpc=False)
    def test_reinit(self):
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        # This is for the below `dist.barrier`.
        # For `RpcAgent` other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )
        # Wait for all init to complete.
        dist.barrier()

        with self.assertRaisesRegex(RuntimeError, "is already initialized"):
            rpc.init_rpc(
                name="worker{}".format(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.wait_all_workers()

    @dist_init(setup_rpc=False)
    def test_invalid_names(self):
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="abc*",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        base_file_name = self.file_name
        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "1"

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name=" ",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "2"
        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="",
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        # Use a different file path for FileStore to avoid rendezvous mismatch.
        self.file_name = base_file_name + "3"
        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            store, _, _ = next(torch.distributed.rendezvous(
                self.init_method, rank=self.rank, world_size=self.world_size
            ))
            rpc._init_rpc_backend(
                backend=self.rpc_backend,
                store=store,
                name="".join(["a" for i in range(500)]),
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # wait_all_workers() should not do anything as _agent is None
        rpc.wait_all_workers()
        # We need this barrier here because although init_process_group is
        # blocking, it does not guarantee that all ranks are done with
        # initialization after the call. We did run into issues with it where
        # rank 3 crashed with "connection closed by peer" RuntimeError, which is
        # caused by other ranks exiting before rank 3 is ready. This can be
        # fixed by adding a collective call to sync all processes.
        #
        # We decided not to fix this issue in init_process_group because it
        # would add extra overhead to the call, and normal use cases won't
        # create a process group and exit without doing anything. Hence, it is
        # not worth introducing the overhead just for this test case.

    @dist_init
    def test_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_add_with_id(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        worker_info = rpc.get_worker_info("worker{}".format(dst_rank))

        ret = rpc.rpc_sync(
            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_scalar_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
        )
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @dist_init
    def test_async_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @dist_init
    def test_nonzero(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @dist_init
    def test_multi_rpc(self):
        dst_rank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )
            self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init(setup_rpc=False)
    def test_wait_all_workers(self):
        # Initialize RPC.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.wait_all_workers()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # it's safe to call wait_all_workers() multiple times
        rpc.wait_all_workers()

    @dist_init
    def test_expected_src(self):
        dst_rank = (self.rank + 1) % self.world_size
        expected_src_rank = (self.rank - 1) % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
        value = VALUE_FUTURE.result()
        self.assertEqual(value, expected_src_rank)

    @dist_init
    def test_py_built_in(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
        self.assertEqual(ret, min(n, n + 1, n + 2))

    @dist_init
    def test_py_user_defined(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_function,
            kwargs={"a": n, "b": n + 1, "c": n + 2},
        )
        self.assertEqual(ret, my_function(n, n + 1, n + 2))

    @dist_init
    def test_py_class_constructor(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
        self.assertEqual(ret.a, n)

    @dist_init
    def test_py_class_instance_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), MyClass(2).my_instance_method, args=(n,)
        )
        self.assertEqual(ret, MyClass(2).my_instance_method(n))

    @dist_init
    def test_py_class_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), MyClass.my_class_method, args=(n, n + 1)
        )
        self.assertEqual(ret, MyClass.my_class_method(n, n + 1))

    @dist_init
    def test_py_class_static_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), MyClass.my_static_method, args=(n + 10,)
        )
        self.assertEqual(ret, MyClass.my_static_method(n + 10))

    @dist_init
    def test_py_multi_async_call(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
        self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
        self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))

    @dist_init
    def test_py_no_return_result(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result)
        self.assertEqual(ret, no_result())

    @dist_init
    def test_py_tensors(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_tensor_function,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))

    @dist_init
    def test_py_tensors_multi_async_call(self):
        futs = []
        n = self.rank + 1
        dst_rank = n % self.world_size
        for i in range(100):
            fut = rpc.rpc_async(
                "worker{}".format(dst_rank),
                my_tensor_function,
                args=(torch.ones(i, i), torch.ones(i, i)),
            )
            futs.append(fut)

        j = 0
        for fut in futs:
            self.assertEqual(
                fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
            )
            j += 1

    @dist_init
    def test_py_tensors_in_container(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        a = [torch.ones(n, n), torch.ones(n, n)]
        b = TensorClass(build_complex_tensors())
        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
        )
        self.assertEqual(ret, my_complex_tensor_function(a, b, c))

    @dist_init
    def test_py_nested_pickle(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            run_nested_pickle,
            args=(MyPickleClass(), torch.ones(2, 2)),
        )

        m = MyPickleClass()
        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))

    @dist_init
    def test_py_function_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        with self.assertRaisesRegex(Exception, "TypeError"):
            ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))

    @dist_init
    def test_py_raise_in_user_func(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async("worker{}".format(dst_rank), raise_func)
        with self.assertRaisesRegex(Exception, "ValueError"):
            fut.wait()

    @dist_init
    def test_nested_rpc(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            nested_rpc,
            args=("worker{}".format(self.rank),),
        )
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        import time

        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []
        tik = time.time()
        for _ in range(repeat):
            fut = rpc.rpc_async("worker{}".format(dst_rank), f, args=args)
            futs.append(fut)

        for fut in futs:
            self.assertEqual(fut.wait(), 0)
        tok = time.time()
        print(
            "Rank {} finished testing {} {} times in {} seconds.".format(
                self.rank, f.__name__, repeat, tok - tik
            )
        )

    @dist_init
    def test_stress_light_rpc(self):
        self._stress_test_rpc(light_rpc)

    @dist_init
    def test_stress_heavy_rpc(self):
        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))

    @dist_init
    def test_builtin_remote_ret(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)

    @dist_init
    def test_asymmetric_load_with_join(self):
        """Test graceful termination."""
        # worker0 drives and waits for worker1 and worker2
        # throughout the test.
        if self.rank == 0:
            assert self.world_size >= 3

            num_repeat = 100
            futs = []

            # Phase 1: Only worker1 has workload.
            dst = "worker1"
            for _ in range(num_repeat):
                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
                futs.append(fut)

            for fut in futs:
                fut.wait()
                self.assertEqual(fut.wait(), 0)

            # Phase 2: Only worker2 has workload.
            # If join is not correctly implemented,
            # worker2 should be closed by now.
            dst = "worker2"
            for _ in range(num_repeat):
                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
                futs.append(fut)

            for fut in futs:
                fut.wait()
                self.assertEqual(fut.wait(), 0)

    def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
        m = 10
        n = self.rank + 1
        dst_rank = n % self.world_size
        rrefs = []
        expected = []
        for i in range(m):
            n = n + i
            rrefs.append(
                rpc.remote(
                    "worker{}".format(dst_rank),
                    fn,
                    args=args_fn(n),
                    kwargs=kwargs_fn(n),
                )
            )
            expected.append(fn(*args_fn(n), **kwargs_fn(n)))

        for i in range(m):
            self.assertEqual(rrefs[i].to_here(), expected[i])

    @dist_init
    def test_multi_builtin_remote_ret(self):
        def args_fn(n):
            return (torch.ones(n, n), torch.ones(n, n))

        self._test_multi_remote_call(torch.add, args_fn=args_fn)

    @dist_init
    def test_py_udf_remote(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            "worker{}".format(dst_rank),
            my_function,
            kwargs={"a": n, "b": n + 1, "c": n + 2},
        )
        self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))

    @dist_init
    def test_multi_py_udf_remote(self):
        def kwargs_fn(n):
            return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}

        self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)

    @dist_init
    def test_py_rref_args(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        rref_b = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )
        rref_c = rpc.remote(
            "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

    @dist_init
    def test_py_rref_args_user_share(self):
        n = self.rank + 1
        owner_rank = n % self.world_size
        user_rank = (n + 1) % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 2, 0)
        )
        rref_b = rpc.remote(
            "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 1, 0)
        )
        rref_c = rpc.remote(
            "worker{}".format(user_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

    @dist_init
    def test_py_rpc_rref_args(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
        )
        rref_b = rpc.remote(
            "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
        )

        c = rpc.rpc_sync(
            "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )

        self.assertEqual(c, torch.ones(n, n) + 4)

    @dist_init
    def test_nested_remote(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size

        rref = rpc.remote(
            "worker{}".format(dst_rank1),
            nested_remote,
            args=("worker{}".format(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)

    @dist_init
    def test_nested_rref(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref_of_rrefs = rpc.remote(
            "worker{}".format(dst_rank1),
            nested_rref,
            args=("worker{}".format(dst_rank2),),
        )
        rrefs = rref_of_rrefs.to_here()
        self.assertEqual(len(rrefs), 2)
        self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
        self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)

    @dist_init
    def test_nested_rref_stress(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        all_rrefs = []
        for _ in range(20):
            all_rrefs.append(
                rpc.remote(
                    "worker{}".format(dst_rank1),
                    nested_rref,
                    args=("worker{}".format(dst_rank2),),
                )
            )

        for i in range(20):
            rref_of_rrefs = all_rrefs[i]
            rrefs = rref_of_rrefs.to_here()
            self.assertEqual(len(rrefs), 2)
            self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
            self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)

    @dist_init
    def test_multi_layer_nested_async_rpc(self):
        # This test will exit right away, but there will be a chain of async
        # RPCs. The termination algorithm should detect those messages properly.
        # Otherwise, some peer could exit early, leaving others to timeout
        # errors or connection closed errors.
        ttl = 20
        n = self.rank + 1
        dst_rank = n % self.world_size

        multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)

    @dist_init
    def test_remote_with_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        # check ref to other workers
        rref = rpc.remote("worker{}".format(dst_rank), raise_func)
        with self.assertRaisesRegex(Exception, "ValueError"):
            rref.to_here()
        # check ref to itself
        rref = rpc.remote("worker{}".format(self.rank), no_result, args=(10,))
        with self.assertRaisesRegex(Exception, "TypeError"):
            rref.to_here()

    @dist_init
    def test_rpc_return_rref(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref = rpc.rpc_sync(
            "worker{}".format(dst_rank1),
            rpc_return_rref,
            args=("worker{}".format(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)

    @dist_init
    def test_rref_forward_chain(self):
        ttl = 8
        n = self.rank + 1
        dst_rank = n % self.world_size

        rref = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )

        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)

        for i in range(ttl):
            self.assertEqual(len(ret_rref), 1)
            ret_rref = ret_rref[0].to_here()

        ret = ret_rref
        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))

    @dist_init
    def test_local_rref_no_fork(self):
        local_rref = RRef(35)
        self.assertEqual(local_rref.local_value(), 35)

    @dist_init
    def test_return_local_rrefs(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        rref_list = rpc.rpc_sync(
            "worker{}".format(dst_rank), get_rref_list, args=(
                [1, 2, 3], ))

        for rref in rref_list:
            rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=(
                MyClass.increment_value, rref, 10))

        rets = [
            rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=(
                MyClass.get_value, rref))
            for rref in rref_list]

        self.assertEqual(rets, [11, 12, 13])

    @dist_init
    def test_owner_equality(self):
        a = RRef(40)
        b = RRef(50)

        other_rank = (self.rank + 1) % self.world_size
        other_a = rpc.remote(
            "worker{}".format(other_rank), torch.add, args=(torch.ones(1), 1)
        )
        other_b = rpc.remote(
            "worker{}".format(other_rank), torch.add, args=(torch.ones(1), 1)
        )
        other_a.to_here()  # to ensure clean termination
        other_b.to_here()

        self.assertNotEqual(a.owner(), 23)
        self.assertEqual(other_a.owner(), other_b.owner())
        self.assertNotEqual(a.owner(), other_a.owner())
        self.assertEqual(other_a.owner(), other_a.owner())
        self.assertEqual(other_a.owner(), other_b.owner())
        self.assertEqual(a.owner(), a.owner())
        self.assertEqual(a.owner(), b.owner())
        self.assertEqual(a.owner(), rpc.get_worker_info())
        x = dict()
        x[a.owner()] = a
        x[other_a.owner()] = other_a
        self.assertEqual(x[a.owner()], a)
        self.assertEqual(x[b.owner()], a)
        self.assertEqual(x[other_a.owner()], other_a)
        self.assertEqual(x[other_b.owner()], other_a)
        self.assertEqual(len(x), 2)

    @dist_init
    def test_pass_local_rrefs(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        dst_worker = "worker{}".format(dst_rank)

        rref = RRef(40)
        self.assertEqual(
            rpc.rpc_sync(
                dst_worker, add_rref_to_value, args=(rref, 50)), 90)
        self.assertEqual(
            rpc.rpc_async(
                dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90)
        self.assertEqual(
            rpc.remote(
                dst_worker,
                add_rref_to_value,
                args=(rref, 50)).to_here(), 90)

    @dist_init
    def test_remote_same_worker(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        rref_b = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )
        rref_c = rpc.remote(
            "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

    @dist_init(setup_rpc=True)
    def test_call_method_on_rref(self):
        """
        Tests that it is possible to call an instance method on a remote object
        by using rref.owner() as the destination of the call.
        """
        vals = [10, 2, 5, 7]
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)

        # creates a remote object
        rref = rpc.remote(dst_worker, MyClass, args=(vals[0], ))

        # modifies state of the remote object
        rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=(
            MyClass.increment_value, rref, vals[1]))
        rpc.rpc_async(rref.owner(), _call_method_on_rref, args=(
            MyClass.increment_value, rref, vals[2])).wait()
        rpc.remote(rref.owner(), _call_method_on_rref, args=(
            MyClass.increment_value, rref, vals[3])).to_here()

        # queries state of the remote object
        result = rpc.rpc_sync(dst_worker, _call_method_on_rref, args=(
            MyClass.get_value, rref))

        self.assertEqual(result, sum(vals))

    def _test_rref_leak(self, ignore_leak=False):
        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )

        # This is for the below `dist.barrier`.
        # For `RpcAgent` other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )
        # Wait for all init to complete.
        dist.barrier()

        rref = rpc.remote(
            "worker{}".format((self.rank + 1) % self.world_size),
            torch.add,
            args=(torch.ones(2, 2), 1)
        )

        if ignore_leak:
            import torch.distributed.rpc.api as api
            api._ignore_rref_leak = True

        rpc.wait_all_workers()

    @dist_init(setup_rpc=False)
    def test_rref_leak(self):
        with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
            self._test_rref_leak()

    @dist_init(setup_rpc=False)
    def test_ignore_rref_leak(self):
        self._test_rref_leak(ignore_leak=True)

    @dist_init(setup_rpc=False)
    def test_get_rpc_timeout(self):
        timeout = timedelta(seconds=1)

        # A new `RpcBackendOptions` is constructed
        # when accessing `self.rpc_backend_options`.
        rpc_backend_options = self.rpc_backend_options
        rpc_backend_options.rpc_timeout = timeout

        rpc.init_rpc(
            name="worker{}".format(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=rpc_backend_options,
        )
        set_timeout = rpc.get_rpc_timeout()
        self.assertEqual(timeout, set_timeout)
        rpc.wait_all_workers()

    @dist_init
    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
    def test_rpc_timeouts(self):
        dst_rank = (self.rank + 1) % self.world_size
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        # futures should time out and be marked with an exception indicating the timeout.
        futs = [rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=()) for _ in range(10)]
        for fut in futs:
            with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
                fut.wait()

        # ensure that if a new timeout is set old futures don't time out but new ones do.
        rpc._set_rpc_timeout(timedelta(seconds=200))
        # create a longstanding RPC.
        fut1 = rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=(1,))
        # now, set a short timeout.
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        # fut2 should time out, fut1 should not.
        fut2 = rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=(1,))
        with self.assertRaises(RuntimeError):
            fut2.wait()
        fut1.wait()

        # future should run to completion if the timeout is zero.
        rpc._set_rpc_timeout(timedelta(seconds=0))
        rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=()).wait()

        # reset to default timeout so shutdown messages can process cleanly.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)

    def test_requires_process_group_agent_decorator(self):
        @requires_process_group_agent("test_func did not run")
        def test_func():
            return "expected result"

        if dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP":
            self.assertEqual(test_func(), "expected result")

    def test_dist_init_decorator(self):
        @dist_init(setup_rpc=False)
        def test_func(self):
            return "expected result"

        self.assertEqual(test_func(self), "expected result")

        @dist_init
        def test_func(self):
            return "expected result"

        self.assertEqual(test_func(self), "expected result")

    def test_use_rpc_pickler(self):
        class TestPickler():
            pass
        test_pickler = TestPickler()
        with _use_rpc_pickler(test_pickler):
            self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler)
        self.assertTrue(torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler)