from __future__ import absolute_import, division, print_function, unicode_literals

# Helpers shared by the distributed RPC tests: a small configuration holder,
# the `dist_init` decorator that brings RPC up and down around a test method,
# and a few utilities for waiting on node failure / creating a process group.

import sys
import time
from functools import partial, wraps

import torch.distributed as dist
import torch.distributed.rpc as rpc


if not dist.is_available():
    print("c10d not available, skipping tests")
    # `sys` must be imported for this early exit to work; without it the
    # intended clean skip turns into a NameError.
    sys.exit(0)


class TestConfig:
    """Holder for the RPC test configuration.

    Only the attributes named in ``__slots__`` can be set. Values are
    supplied as keyword arguments; positional arguments are rejected.
    """

    __slots__ = ["rpc_backend_name", "build_rpc_backend_options"]

    def __init__(self, *args, **kwargs):
        assert len(args) == 0, "TestConfig only takes kwargs."
        for k, v in kwargs.items():
            setattr(self, k, v)


TEST_CONFIG = TestConfig()
INIT_METHOD_TEMPLATE = "file://{file_name}"


def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
    """
    We use this decorator for setting up and tearing down state since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process just runs the `test*` method without actually calling
    'setUp' and 'tearDown' methods of unittest.

    Args:
        old_test_method: the test method being decorated, or ``None`` when the
            decorator is used with arguments.
        setup_rpc: when true, ``rpc.init_rpc`` is called before the test body
            and ``rpc.shutdown`` after it returns.
        clean_shutdown: forwarded to ``rpc.shutdown`` as ``graceful``.

    Returns:
        The wrapped test method, or a ``functools.partial`` of ``dist_init``
        when ``old_test_method`` is ``None``.

    Note:
        ``rpc.shutdown`` is only reached when the test body returns normally;
        an exception raised by the test propagates without shutting RPC down.
    """
    # If we use dist_init without arguments (ex: @dist_init), old_test_method is
    # appropriately set and we return the wrapper appropriately. On the other
    # hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)),
    # old_test_method is None and we return a functools.partial which is the real
    # decorator that is used and as a result we recursively call dist_init with
    # old_test_method and the rest of the arguments appropriately set.
    if old_test_method is None:
        return partial(
            dist_init,
            setup_rpc=setup_rpc,
            clean_shutdown=clean_shutdown,
        )

    @wraps(old_test_method)
    def new_test_method(self, *arg, **kwargs):
        # Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
        # in tests.
        import torch.distributed.rpc.api as api

        api._ignore_rref_leak = False

        self.worker_id = self.rank

        if setup_rpc:
            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        return_value = old_test_method(self, *arg, **kwargs)

        if setup_rpc:
            rpc.shutdown(graceful=clean_shutdown)

        return return_value

    return new_test_method


# Set PROCESS_GROUP as the default RPC backend.
TEST_CONFIG.rpc_backend_name = "PROCESS_GROUP"
TEST_CONFIG.build_rpc_backend_options = lambda test_object: rpc.backend_registry.construct_rpc_backend_options(
    test_object.rpc_backend,
    init_method=test_object.init_method,
    # Some tests need additional threads (ex: test_trainer_ps)
    num_send_recv_threads=8,
)


def noop():
    """Do nothing; used as the target of the probe RPC below."""
    pass


def wait_until_node_failure(rank):
    """
    Loops until an RPC to the given rank fails. This is used to
    indicate that the node has failed in unit tests.

    Args:
        rank: rank of the worker to probe; the RPC targets "worker<rank>".
    """
    while True:
        try:
            rpc.rpc_sync("worker{}".format(rank), noop, args=())
            time.sleep(0.5)
        except Exception:
            # Any failure of the probe RPC is the signal we are waiting for.
            break


def initialize_pg(init_method, rank, world_size):
    """Initialize the default "gloo" process group if none exists yet.

    Args:
        init_method: passed through to ``dist.init_process_group``.
        rank: rank of the calling process.
        world_size: total number of processes in the group.
    """
    # This is for tests using `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )