mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32633 There were 2 sources of truth for the current RPC agent. - One is in the Python world, `torch.distributed.rpc.api._agent`. - The other is in the C++ world, `RpcAgent::defaultRpcAgent_`. Setting the Python `_agent` to `None` does not necessarily reset the C++ `defaultRpcAgent_` to `nullptr`, i.e. ``` torch.distributed.rpc.api._agent = None ``` does not translate to ``` RpcAgent::defaultRpcAgent_ = nullptr ``` This PR removes this ambiguity and uses the C++ pointer as the source of truth. The solution is to leverage a pybind11 behavior: it implicitly casts C++ `shared_ptr<RpcAgent>(nullptr)` to Python `None`. ghstack-source-id: 97293315 Test Plan: ``` buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_duplicate_name buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par -r test_process_group_debug_info ``` ``` buck test mode/dev-nosan //caffe2/torch/fb/distributed/pytorch/tests:test_remote_module buck test mode/dev-nosan //caffe2/torch/fb/distributed/modules/tests:test_sharded_embedding buck test mode/dev-nosan //caffe2/torch/fb/distributed/modules/tests:test_sharded_pairwise_attention_pooling buck test mode/dev-nosan //caffe2/torch/fb/distributed/pytorch/tests:test_rpc ``` Differential Revision: D5733066 fbshipit-source-id: b3e6032ee975f19ca556497edbbf40b517b25be8
112 lines
3.4 KiB
Python
112 lines
3.4 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import time
|
|
from functools import partial, wraps
|
|
|
|
import torch.distributed as dist
|
|
import torch.distributed.rpc as rpc
|
|
|
|
|
|
# Stop before any test helpers are defined when this PyTorch build reports
# that the distributed package is unavailable.
if not dist.is_available():
    # `sys` is not among the imports at the top of this file, so it is
    # imported here; without it the exit call below raises NameError instead
    # of terminating the process with status 0.
    import sys

    print("c10d not available, skipping tests")
    sys.exit(0)
|
|
|
|
|
|
class TestConfig:
    """
    Small keyword-only container for the RPC test configuration.

    Only the attributes named in `__slots__` can be stored; every keyword
    argument given to the constructor is assigned as an attribute of the
    same name.
    """

    __slots__ = ["rpc_backend_name", "build_rpc_backend_options"]

    def __init__(self, *args, **kwargs):
        # Positional arguments are rejected so every setting is named.
        assert not args, "TestConfig only takes kwargs."
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)


# Shared configuration object; it starts out empty and is populated by
# attribute assignment.
TEST_CONFIG = TestConfig()
# Template for a file-based init method URL; fill in `file_name` with
# `str.format`.
INIT_METHOD_TEMPLATE = "file://{file_name}"
|
|
|
|
|
|
def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
|
|
"""
|
|
We use this decorator for setting up and tearing down state since
|
|
MultiProcessTestCase runs each `test*` method in a separate process and
|
|
each process just runs the `test*` method without actually calling
|
|
'setUp' and 'tearDown' methods of unittest.
|
|
"""
|
|
|
|
# If we use dist_init without arguments (ex: @dist_init), old_test_method is
|
|
# appropriately set and we return the wrapper appropriately. On the other
|
|
# hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)),
|
|
# old_test_method is None and we return a functools.partial which is the real
|
|
# decorator that is used and as a result we recursively call dist_init with
|
|
# old_test_method and the rest of the arguments appropriately set.
|
|
if old_test_method is None:
|
|
return partial(
|
|
dist_init,
|
|
setup_rpc=setup_rpc,
|
|
clean_shutdown=clean_shutdown,
|
|
)
|
|
|
|
@wraps(old_test_method)
|
|
def new_test_method(self, *arg, **kwargs):
|
|
# Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
|
|
# in tests.
|
|
import torch.distributed.rpc.api as api
|
|
api._ignore_rref_leak = False
|
|
|
|
self.worker_id = self.rank
|
|
|
|
if setup_rpc:
|
|
rpc.init_rpc(
|
|
name="worker%d" % self.rank,
|
|
backend=self.rpc_backend,
|
|
rank=self.rank,
|
|
world_size=self.world_size,
|
|
rpc_backend_options=self.rpc_backend_options,
|
|
)
|
|
|
|
return_value = old_test_method(self, *arg, **kwargs)
|
|
|
|
if setup_rpc:
|
|
rpc.shutdown(graceful=clean_shutdown)
|
|
|
|
return return_value
|
|
|
|
return new_test_method
|
|
|
|
|
|
# Default configuration: PROCESS_GROUP is the RPC backend, and the backend
# options are built from the test object's own `rpc_backend` and
# `init_method` attributes.
TEST_CONFIG.rpc_backend_name = "PROCESS_GROUP"
TEST_CONFIG.build_rpc_backend_options = (
    lambda test_object: rpc.backend_registry.construct_rpc_backend_options(
        test_object.rpc_backend,
        init_method=test_object.init_method,
        # Some tests need additional threads (ex: test_trainer_ps)
        num_send_recv_threads=8,
    )
)
|
|
|
|
def noop():
    """Do nothing and return None."""
    return None
|
|
|
|
def wait_until_node_failure(rank):
    '''
    Block until a synchronous RPC to worker `rank` raises.

    A `noop` call is sent to "worker<rank>" every half second; the first
    exception of any kind ends the wait. Unit tests use this to indicate
    that the node has failed.
    '''
    while True:
        try:
            rpc.rpc_sync("worker{}".format(rank), noop, args=())
        except Exception:
            # Any failure of the probe is the signal being waited for.
            return
        # The probe succeeded, so pause before trying again.
        time.sleep(0.5)
|
|
|
|
def initialize_pg(init_method, rank, world_size):
    """
    Create the default "gloo" process group unless one already exists.

    This is for tests using `dist.barrier`. For `RpcAgent` other than
    `ProcessGroupAgent`, no `_default_pg` is initialized.

    Arguments:
        init_method: forwarded to `dist.init_process_group`.
        rank: forwarded to `dist.init_process_group`.
        world_size: forwarded to `dist.init_process_group`.
    """
    if dist.is_initialized():
        # A process group is already in place; nothing to do.
        return

    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
|