pytorch/torch/distributed/rpc/rref_proxy.py
Rodrigo Kumpera 1ad5a7ceea [distributed] Make rref_proxy._invoke_rpc truly async when needed. (#70206)
Summary:
From https://github.com/pytorch/pytorch/issues/67626: RRefProxy (rref.rpc_async, rref.rpc_sync, rref.remote) currently uses a blocking RPC call to the owner.

This change removes the blocking call by chaining async calls. In the sync
case we wait on the resulting Future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70206

Test Plan:
I ran rpc_tests using tensorpipe_rpc_agent_test_fixture.py and had to
adjust test_rref_proxy_timeout to the new behavior.

I ran into test_tensorpipe_set_default_timeout failing due to the
timeout being too small. Doesn't look related to this change.
mrshenli
Fixes https://github.com/pytorch/pytorch/issues/67626

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Reviewed By: pritamdamania87

Differential Revision: D33243348

Pulled By: kumpera

fbshipit-source-id: e1e8c34bb3d170407c0a793e2e585357f905d3c6
2022-01-19 15:23:20 -08:00

75 lines
2.6 KiB
Python

from functools import partial
from . import functions
from . import rpc_async
import torch
from .constants import UNSET_RPC_TIMEOUT
from torch.futures import Future
def _local_invoke(rref, func_name, args, kwargs):
    """Call the method named ``func_name`` on the value held by ``rref``.

    Args:
        rref: object exposing ``local_value()``; the method is looked up on
            whatever that call returns.
        func_name: name of the attribute to fetch and call.
        args: positional arguments, unpacked into the call.
        kwargs: keyword arguments, unpacked into the call.

    Returns:
        Whatever the looked-up callable returns.

    Raises:
        AttributeError: if the local value has no attribute ``func_name``.
    """
    target = rref.local_value()
    method = getattr(target, func_name)
    return method(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Variant of the local invoker wrapped with ``functions.async_execution``.

    The body is the same lookup-and-call as the plain local invoker: fetch
    ``func_name`` from ``rref.local_value()`` and call it with ``args`` and
    ``kwargs``. Only the decorator differs.

    Args:
        rref: object exposing ``local_value()``.
        func_name: name of the attribute to fetch and call.
        args: positional arguments, unpacked into the call.
        kwargs: keyword arguments, unpacked into the call.

    Returns:
        Whatever the looked-up callable returns.
    """
    target = rref.local_value()
    method = getattr(target, func_name)
    return method(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Run ``func_name`` on the object behind ``rref`` through ``rpc_api``.

    The RRef's type is requested first (``blocking=False``) so the matching
    local invoker can be chosen: methods carrying the
    ``_wrapped_async_rpc_function`` attribute go through
    ``_local_invoke_async_execution``, everything else through
    ``_local_invoke``. ScriptModule types skip that attribute check.

    Args:
        rref: the RRef to operate on; its ``owner()`` is the RPC destination.
        rpc_api: callable used to issue the RPC. It is compared against
            ``rpc_async`` to pick the blocking or the chained flow.
        func_name: name of the method to invoke on the referenced object.
        timeout: forwarded both to the type lookup and to ``rpc_api``.
        *args: positional arguments forwarded to the method.
        **kwargs: keyword arguments forwarded to the method.

    Returns:
        For ``rpc_api`` other than ``rpc_async``: whatever ``rpc_api``
        returns, after waiting on the type future. For ``rpc_async``: a
        ``Future`` completed with the value (or exception) of the chained call.
    """

    def _rref_type_cont(rref_fut):
        # Continuation run once the RRef type is known: pick the invoker and
        # issue the actual RPC.
        rref_type = rref_fut.value()

        _invoke_func = _local_invoke
        # Bypass ScriptModules when checking for async function attribute.
        bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
            rref_type, torch._C.ScriptModule
        )
        if not bypass_type:
            func = getattr(rref_type, func_name)
            if hasattr(func, "_wrapped_async_rpc_function"):
                _invoke_func = _local_invoke_async_execution

        return rpc_api(
            rref.owner(),
            _invoke_func,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    rref_fut = rref._get_type(timeout=timeout, blocking=False)

    if rpc_api != rpc_async:
        # Not the async API: wait for the type here, then issue the RPC
        # directly and hand back its result.
        rref_fut.wait()
        return _rref_type_cont(rref_fut)

    # A little explanation on this.
    # rpc_async returns a Future pointing to the return value of `func_name`,
    # i.e. a `Future[T]`. Calling _rref_type_cont from a `then` callback causes
    # Future wrapping: `then` would return a `Future[Future[T]]`.
    # To address that, we return a Future that is completed with the result of
    # the async call.
    result: Future = Future()

    def _complete_op(fut):
        # Copy the outcome of the inner RPC future into `result`.
        try:
            result.set_result(fut.value())
        except BaseException as ex:
            result.set_exception(ex)

    def _wrap_rref_type_cont(fut):
        # Any failure while starting the RPC (including one raised by
        # fut.value()) is routed into `result` rather than raised from the
        # callback.
        try:
            _rref_type_cont(fut).then(_complete_op)
        except BaseException as ex:
            result.set_exception(ex)

    rref_fut.then(_wrap_rref_type_cont)
    return result
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
    """Proxy whose attribute lookups produce RPC-invoking callables.

    Accessing ``proxy.some_method`` returns a ``functools.partial`` of
    ``_invoke_rpc`` bound to the stored RRef, RPC API, the attribute name and
    the stored timeout. Calling it with arguments forwards them to
    ``_invoke_rpc``.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        """Store the RRef, the RPC API callable and the timeout.

        Args:
            rref: the RRef that proxied calls operate on.
            rpc_api: callable passed to ``_invoke_rpc`` as its ``rpc_api``.
            timeout: timeout passed to ``_invoke_rpc``; defaults to
                ``UNSET_RPC_TIMEOUT``.
        """
        self.rref = rref
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout

    def __getattr__(self, func_name):
        """Return a callable that invokes ``func_name`` through ``_invoke_rpc``.

        Python only calls ``__getattr__`` for names that normal attribute
        lookup did not find, so ``rref``, ``rpc_api`` and ``rpc_timeout`` set
        in ``__init__`` are returned directly and never reach this method.
        """
        return partial(
            _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout
        )