pytorch/torch/distributed/rpc/rref_proxy.py
Rohan Varma d64184ef4c [RPC] Support timeout for RRef proxy functions (#50499)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50499

Adds a timeout API to the following functions:
```
rref.rpc_sync()
rref.rpc_async()
rref.remote()
```
so that RPCs initiated by these proxy calls can be appropriately timed out similar to the regular RPC APIs. Timeouts are supported in the following use cases:

1. rpc.remote() finishes in time and successfully, but the function run by rref.rpc_async() is slow and times out. A timeout error will be raised.
2. The rref.rpc_async() function is fast, but rpc.remote() is slow/hanging. Then, when rref.rpc_async() is called, it will still time out with the passed-in timeout (and won't block waiting for rpc.remote() to succeed, which is what happens currently). However, the timeout will occur during the future creation itself (and not during the wait), since it calls `rref._get_type`, which blocks. We can consider making this nonblocking by modifying rref._get_type to return a future, although that is likely a larger change.

Test Plan: Added UT

Reviewed By: wanchaol

Differential Revision: D25897495

fbshipit-source-id: f9ad5b8f75121f50537677056a5ab16cf262847e
2021-01-15 13:23:23 -08:00

47 lines
1.6 KiB
Python

from functools import partial
from . import functions
import torch
from .constants import UNSET_RPC_TIMEOUT
def _local_invoke(rref, func_name, args, kwargs):
    """Call the method named ``func_name`` on the value held locally by ``rref``.

    Args:
        rref: object exposing ``local_value()``; the method is looked up on
            whatever that call returns.
        func_name: name of the attribute to fetch and call.
        args: positional arguments, unpacked into the call.
        kwargs: keyword arguments, unpacked into the call.

    Returns:
        Whatever the looked-up method returns.
    """
    target = rref.local_value()
    bound_method = getattr(target, func_name)
    return bound_method(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Variant of the local invoke helper wrapped in ``functions.async_execution``.

    The body is the same lookup-and-call as the plain helper: fetch the value
    held locally by ``rref``, look up ``func_name`` on it, and call it with
    the unpacked ``args`` and ``kwargs``. The only difference is the
    decorator applied to this function.

    Args:
        rref: object exposing ``local_value()``.
        func_name: name of the attribute to fetch and call.
        args: positional arguments, unpacked into the call.
        kwargs: keyword arguments, unpacked into the call.

    Returns:
        Whatever the looked-up method returns.
    """
    target = rref.local_value()
    bound_method = getattr(target, func_name)
    return bound_method(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Issue ``rpc_api`` against the owner of ``rref`` to run ``func_name``.

    Args:
        rref: the RRef whose owner should execute the call. Must provide
            ``_get_type(timeout=...)`` and ``owner()``.
        rpc_api: callable invoked as
            ``rpc_api(owner, func, args=..., timeout=...)``.
        func_name: name of the method to run on the value held by ``rref``.
        timeout: forwarded both to ``rref._get_type`` and to ``rpc_api``.
        *args: positional arguments for the target method.
        **kwargs: keyword arguments for the target method.

    Returns:
        Whatever ``rpc_api`` returns.
    """
    # Resolving the type can itself issue an RPC, so it has to honour the
    # caller's timeout as well.
    owner_type = rref._get_type(timeout=timeout)

    # ScriptModules are bypassed when probing for the async function marker.
    script_module_types = (torch.jit.ScriptModule, torch._C.ScriptModule)
    if issubclass(owner_type, script_module_types):
        invoker = _local_invoke
    else:
        method = getattr(owner_type, func_name)
        if hasattr(method, "_wrapped_async_rpc_function"):
            # The target carries the async marker, so dispatch through the
            # async_execution-decorated helper.
            invoker = _local_invoke_async_execution
        else:
            invoker = _local_invoke

    call_args = (rref, func_name, args, kwargs)
    return rpc_api(rref.owner(), invoker, args=call_args, timeout=timeout)
# Proxy object that manages proxied RPC API calls for RRefs. It is entirely
# used from C++ (see python_rpc_handler.cpp).
class RRefProxy:
    """Expose every attribute lookup as a pre-bound ``_invoke_rpc`` call.

    Looking up a name that is not a regular attribute of the proxy returns a
    ``functools.partial`` of ``_invoke_rpc`` with the stored RRef, RPC API,
    the looked-up name, and the stored timeout already bound. Calling that
    partial with the method's own arguments performs the proxied call.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        """Store the RRef, the RPC API callable and the timeout to use.

        Args:
            rref: the RRef that calls are proxied to.
            rpc_api: the RPC API callable handed to ``_invoke_rpc``.
            timeout: timeout handed to ``_invoke_rpc``; defaults to
                ``UNSET_RPC_TIMEOUT``.
        """
        self.rref, self.rpc_api, self.rpc_timeout = rref, rpc_api, timeout

    def __getattr__(self, func_name):
        """Return a callable that invokes ``func_name`` through ``_invoke_rpc``."""
        # Only reached when normal attribute lookup fails, so rref, rpc_api
        # and rpc_timeout set in __init__ resolve without re-entering here.
        bound_args = (self.rref, self.rpc_api, func_name, self.rpc_timeout)
        return partial(_invoke_rpc, *bound_args)