pytorch/torch/testing/_internal/dist_utils.py
Shihao Xu 5c8535d5b0 Make C++ RpcAgent::currentRPCAgent_ the source of truth of current RPC Agent (#32633)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32633

There were 2 sources of current RPC agent.

- One is in Python world, `torch.distributed.rpc.api._agent`.
- The other is in C++ world, `RpcAgent::defaultRpcAgent_`

Setting Python `_agent` to `None` does not necessarily reset the C++ `defaultRpcAgent_` to `nullptr`.

i.e.
```
torch.distributed.rpc.api._agent = None
```
does not translate to
```
RpcAgent::defaultRpcAgent_ = nullptr
```

This PR is to remove this ambiguity, and use the C++ pointer as source of truth.

The solution is to leverage pybind11's behavior of implicitly casting a C++ `shared_ptr<RpcAgent>(nullptr)` to Python `None`.
ghstack-source-id: 97293315

Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_duplicate_name

buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork

buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par -r test_process_group_debug_info
```

```
buck test mode/dev-nosan //caffe2/torch/fb/distributed/pytorch/tests:test_remote_module

buck test mode/dev-nosan //caffe2/torch/fb/distributed/modules/tests:test_sharded_embedding

buck test mode/dev-nosan //caffe2/torch/fb/distributed/modules/tests:test_sharded_pairwise_attention_pooling

buck test mode/dev-nosan //caffe2/torch/fb/distributed/pytorch/tests:test_rpc
```

Differential Revision: D5733066

fbshipit-source-id: b3e6032ee975f19ca556497edbbf40b517b25be8
2020-01-27 19:34:12 -08:00

112 lines
3.4 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import time
from functools import partial, wraps

import torch.distributed as dist
import torch.distributed.rpc as rpc
# When torch.distributed reports that it is unavailable there is nothing to
# test: print a notice and exit with status 0 so the run counts as a skip
# rather than a failure.
if not dist.is_available():
    print("c10d not available, skipping tests")
    sys.exit(0)
class TestConfig:
    """Keyword-only container for the RPC test configuration.

    Only the attributes named in ``__slots__`` can be stored; assigning any
    other name raises ``AttributeError``. Attributes that were not supplied
    at construction time stay unset until assigned later.
    """

    __slots__ = ["rpc_backend_name", "build_rpc_backend_options"]

    def __init__(self, *args, **kwargs):
        # Positional arguments are rejected outright; every setting must be
        # passed by name.
        assert not args, "TestConfig only takes kwargs."
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
# Module-level configuration instance. It is created empty here; its
# `rpc_backend_name` and `build_rpc_backend_options` fields are assigned
# later in this module.
TEST_CONFIG = TestConfig()
# Template for a file-based init-method URL with a `{file_name}` placeholder
# (presumably filled in via `str.format` by the test harness; usage is not
# shown in this module).
INIT_METHOD_TEMPLATE = "file://{file_name}"
def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
    """
    Decorator that brings RPC up before a test method and shuts it down after.

    It exists because MultiProcessTestCase runs each `test*` method in its own
    process and that process only runs the `test*` method, without actually
    calling unittest's 'setUp' and 'tearDown' methods.

    Supports both the bare form (`@dist_init`) and the parameterized form
    (`@dist_init(clean_shutdown=False)`).

    Arguments:
        old_test_method: the test method being wrapped, or None when the
            decorator is used with arguments.
        setup_rpc: when true, call `rpc.init_rpc` before the test and
            `rpc.shutdown` after it returns.
        clean_shutdown: forwarded to `rpc.shutdown` as `graceful`.

    Returns:
        The wrapped test method, or a `functools.partial` of this decorator
        when `old_test_method` is None.
    """
    # Parameterized use: no method has been handed over yet, so give back a
    # partial that remembers the options. Python then calls that partial with
    # the test method, which re-enters dist_init with everything set.
    if old_test_method is None:
        return partial(dist_init, setup_rpc=setup_rpc, clean_shutdown=clean_shutdown)

    @wraps(old_test_method)
    def wrapped_test(self, *args, **kwargs):
        # Turn off the "ignore RRef leak" switch so that tests verify
        # OwnerRRefs are properly deleted.
        import torch.distributed.rpc.api as rpc_api

        rpc_api._ignore_rref_leak = False

        self.worker_id = self.rank

        if setup_rpc:
            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )

        result = old_test_method(self, *args, **kwargs)

        # Shutdown only happens when the test body returned normally; an
        # exception raised by the test propagates without reaching this point.
        if setup_rpc:
            rpc.shutdown(graceful=clean_shutdown)

        return result

    return wrapped_test
# Set PROCESS_GROUP as the default RPC backend.
TEST_CONFIG.rpc_backend_name = "PROCESS_GROUP"


def _build_default_rpc_backend_options(test_object):
    """Build RPC backend options from a test object's backend and init method."""
    return rpc.backend_registry.construct_rpc_backend_options(
        test_object.rpc_backend,
        init_method=test_object.init_method,
        # Some tests need additional threads (ex: test_trainer_ps)
        num_send_recv_threads=8,
    )


TEST_CONFIG.build_rpc_backend_options = _build_default_rpc_backend_options
def noop():
    """Do nothing and return None; serves as a trivial callable to send over RPC."""
    return None
def wait_until_node_failure(rank):
    """
    Block until an RPC to the worker with the given rank fails.

    Unit tests use this to detect that a node has failed: a no-op RPC is sent
    to "worker<rank>" repeatedly, pausing half a second after each success,
    and the function returns as soon as an attempt raises an exception.
    """
    while True:
        try:
            target_worker = "worker{}".format(rank)
            rpc.rpc_sync(target_worker, noop, args=())
            time.sleep(0.5)
        except Exception:
            # Any failure is taken as the signal that the node is gone.
            return
def initialize_pg(init_method, rank, world_size):
    """
    Initialize the default process group with the "gloo" backend if needed.

    Does nothing when `dist.is_initialized()` already reports true.

    Arguments:
        init_method: forwarded to `dist.init_process_group`.
        rank: rank of the calling process.
        world_size: total number of participating processes.
    """
    # This is for tests using `dist.barrier`.
    # For `RpcAgent` other than `ProcessGroupAgent`,
    # no `_default_pg` is initialized.
    if dist.is_initialized():
        return

    process_group_args = dict(
        backend="gloo",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
    dist.init_process_group(**process_group_args)