Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[distributed] Make rref_proxy._invoke_rpc truly async when needed. (#70206)
Summary:
From https://github.com/pytorch/pytorch/issues/67626: RRefProxy (rref.rpc_async, rref.rpc_sync, rref.remote) currently uses a blocking RPC call to the owner. This change makes the call asynchronous by chaining async calls; in the sync case we wait on the resulting Future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70206

Test Plan: I ran rpc_tests using tensorpipe_rpc_agent_test_fixture.py and had to adjust test_rref_proxy_timeout to the new behavior. I ran into test_tensorpipe_set_default_timeout failing due to the timeout being too small; it doesn't look related to this change. mrshenli

Fixes https://github.com/pytorch/pytorch/issues/67626

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Reviewed By: pritamdamania87

Differential Revision: D33243348

Pulled By: kumpera

fbshipit-source-id: e1e8c34bb3d170407c0a793e2e585357f905d3c6
This commit is contained in:
parent 6ea546ce11
commit 1ad5a7ceea
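The "chaining async calls" described in the summary amounts to fetching the RRef type through a non-blocking future and completing a separate result future once the proxied RPC finishes, which avoids handing the caller a nested Future[Future[T]]. Below is a minimal, self-contained sketch of that chaining/unwrapping pattern using torch.futures; the helper names (fetch_type_async, fake_rpc_async, invoke_async) and the simulated RPC are hypothetical stand-ins, not the actual rref_proxy code path.

# Minimal sketch of the future-chaining pattern described in the summary.
# fetch_type_async and fake_rpc_async are hypothetical stand-ins for the real
# RPC calls; only the torch.futures.Future chaining is illustrated here.
import torch
from torch.futures import Future


def fetch_type_async() -> Future:
    # Stand-in for rref._get_type(..., blocking=False): completes with a type.
    fut: Future = Future()
    fut.set_result(torch.Tensor)
    return fut


def fake_rpc_async(rref_type) -> Future:
    # Stand-in for rpc.rpc_async(...): completes with the remote call's result.
    fut: Future = Future()
    fut.set_result(f"called method on {rref_type.__name__}")
    return fut


def invoke_async() -> Future:
    # Returning fake_rpc_async(...) from inside `then` would yield Future[Future[T]],
    # so instead a dedicated result future is completed with the inner value.
    result: Future = Future()

    def _complete_op(fut):
        try:
            result.set_result(fut.value())
        except BaseException as ex:
            result.set_exception(ex)

    def _on_type(type_fut):
        try:
            fake_rpc_async(type_fut.value()).then(_complete_op)
        except BaseException as ex:
            result.set_exception(ex)

    fetch_type_async().then(_on_type)
    return result


print(invoke_async().wait())  # -> "called method on Tensor"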
@@ -1,9 +1,11 @@
 from functools import partial

 from . import functions
+from . import rpc_async
 import torch
 from .constants import UNSET_RPC_TIMEOUT
+from torch.futures import Future


 def _local_invoke(rref, func_name, args, kwargs):
     return getattr(rref.local_value(), func_name)(*args, **kwargs)
@@ -13,9 +15,8 @@ def _local_invoke_async_execution(rref, func_name, args, kwargs):
     return getattr(rref.local_value(), func_name)(*args, **kwargs)


 def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
-    # Since rref._get_type can potentially issue an RPC, it should respect the
-    # passed in timeout here.
-    rref_type = rref._get_type(timeout=timeout)
+    def _rref_type_cont(rref_fut):
+        rref_type = rref_fut.value()

         _invoke_func = _local_invoke
         # Bypass ScriptModules when checking for async function attribute.
@@ -34,6 +35,33 @@ def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
             timeout=timeout
         )

+    rref_fut = rref._get_type(timeout=timeout, blocking=False)
+
+    if rpc_api != rpc_async:
+        rref_fut.wait()
+        return _rref_type_cont(rref_fut)
+    else:
+        # A little explanation on this.
+        # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]`
+        # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]`
+        # To address that, we return a Future that is completed with the result of the async call.
+        result: Future = Future()
+
+        def _wrap_rref_type_cont(fut):
+            try:
+                _rref_type_cont(fut).then(_complete_op)
+            except BaseException as ex:
+                result.set_exception(ex)
+
+        def _complete_op(fut):
+            try:
+                result.set_result(fut.value())
+            except BaseException as ex:
+                result.set_exception(ex)
+
+        rref_fut.then(lambda fut: _wrap_rref_type_cont(fut))
+        return result
+
 # This class manages proxied RPC API calls for RRefs. It is entirely used from
 # C++ (see python_rpc_handler.cpp).
 class RRefProxy:
@@ -1144,7 +1144,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
             rref.rpc_sync().non_exist()

         with self.assertRaisesRegex(AttributeError, msg):
-            rref.rpc_async().non_exist()
+            rref.rpc_async().non_exist().wait()

         with self.assertRaisesRegex(AttributeError, msg):
             rref.remote().non_exist()
@@ -4956,7 +4956,10 @@ class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon):
         # which blocks on the RRef being created on owner node, until the
         # specified timeout.
         with self.assertRaisesRegex(RuntimeError, expected_error):
-            rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
+            result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
+            # rpc_async returns immediately and surface a timeout through wait()
+            if rref_api == slow_rref.rpc_async:
+                result.wait()

         # FIXME We wait until the remote completed creating the OwnerRRef
         # because there's currently a race if we shut down RPC before that.
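For context, a hedged sketch of the caller-visible behavior the adjusted test exercises: with the proxy now asynchronous, rref.rpc_async(timeout=...) returns a Future immediately, and a timeout or remote error only surfaces when the caller waits on it. The single-process setup, worker name, method, and timeout value below are illustrative assumptions, not taken from the test fixture.

# Hypothetical caller-side sketch (not the actual test): the rpc_async proxy
# returns a Future right away; a timeout or remote error appears only at wait().
import os
import torch
import torch.distributed.rpc as rpc

# Single-process setup just to make the sketch runnable; the real tests span workers.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

rref = rpc.remote("worker0", torch.add, args=(torch.ones(2, 2), 1))

fut = rref.rpc_async(timeout=0.5).size()  # proxy call returns a Future immediately
try:
    print(fut.wait())                     # a timeout or remote error surfaces here
except RuntimeError as err:
    print(f"RPC failed: {err}")

rpc.shutdown()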