mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50499 Adds a timeout API to the following functions: `rref.rpc_sync()`, `rref.rpc_async()`, and `rref.remote()`, so that RPCs initiated by these proxy calls can be timed out in the same way as the regular RPC APIs. Timeouts are supported in the following use cases: 1. `rpc.remote()` finishes successfully and in time, but the function run by `rref.rpc_async()` is slow and times out; a timeout error is raised. 2. The function run by `rref.rpc_async()` is fast, but `rpc.remote()` is slow or hanging. When `rref.rpc_async()` is called, it still times out with the passed-in timeout instead of blocking until `rpc.remote()` succeeds (which is the current behavior). Note, however, that in this case the timeout occurs while the future is being created, not during the wait, because creating it calls `rref._get_type`, which blocks. We could make this non-blocking by changing `rref._get_type` to return a future, although that is likely a larger change. Test Plan: Added unit tests (UT) Reviewed By: wanchaol Differential Revision: D25897495 fbshipit-source-id: f9ad5b8f75121f50537677056a5ab16cf262847e
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
from functools import partial
|
|
|
|
from . import functions
|
|
|
|
import torch
|
|
from .constants import UNSET_RPC_TIMEOUT
|
|
|
|
def _local_invoke(rref, func_name, args, kwargs):
    """Call method ``func_name`` on the value held locally by ``rref``.

    ``args`` is unpacked as positional arguments and ``kwargs`` as keyword
    arguments. The method's return value is returned unchanged.
    """
    target = rref.local_value()
    method = getattr(target, func_name)
    return method(*args, **kwargs)
|
|
|
|
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Async-execution counterpart of ``_local_invoke``.

    Looks up ``func_name`` on ``rref.local_value()`` and calls it with
    ``*args`` / ``**kwargs``, returning its result. ``_invoke_rpc`` selects
    this variant when the target method carries a
    ``_wrapped_async_rpc_function`` attribute.

    NOTE(review): that marker is presumably set by
    ``functions.async_execution`` on the user's method, so the value
    returned here is expected to be a Future the RPC layer completes
    asynchronously -- confirm against ``functions.async_execution``.
    """
    held_value = rref.local_value()
    bound_method = getattr(held_value, func_name)
    return bound_method(*args, **kwargs)
|
|
|
|
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Run method ``func_name`` on the owner of ``rref`` through ``rpc_api``.

    Args:
        rref: the RRef whose held value the method is invoked on.
        rpc_api: RPC entry point, called as
            ``rpc_api(owner, func, args=..., timeout=...)``.
        func_name: name of the method to call on the held value.
        timeout: forwarded to ``rref._get_type`` and to ``rpc_api``.
        *args, **kwargs: forwarded to the method on the owner.

    Returns:
        Whatever ``rpc_api`` returns.

    Raises:
        AttributeError: if the RRef's type is not a ScriptModule and has no
            attribute ``func_name``; raised before ``rpc_api`` is called.

    NOTE(review): ``rref``, ``rpc_api``, ``func_name`` and ``timeout`` are
    ordinary named parameters, so a forwarded keyword argument with one of
    those names collides with them and raises ``TypeError`` before this
    body runs.
    """
    # Fetching the type may require an RPC to the owner, so it is bounded by
    # the same timeout as the call itself.
    target_type = rref._get_type(timeout=timeout)

    # ScriptModules are not inspected for the async marker; other types run
    # through the async-execution invoker only when the target method carries
    # ``_wrapped_async_rpc_function``.
    is_script_module = issubclass(
        target_type, (torch.jit.ScriptModule, torch._C.ScriptModule)
    )
    wants_async = not is_script_module and hasattr(
        getattr(target_type, func_name), "_wrapped_async_rpc_function"
    )
    invoker = _local_invoke_async_execution if wants_async else _local_invoke

    return rpc_api(
        rref.owner(),
        invoker,
        args=(rref, func_name, args, kwargs),
        timeout=timeout,
    )
|
|
|
|
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
    """Turn attribute access on an RRef into an RPC on the RRef's owner.

    ``proxy.name`` returns a callable. Calling it with ``(*args, **kwargs)``
    runs ``_invoke_rpc(rref, rpc_api, "name", timeout, *args, **kwargs)``,
    which invokes method ``name`` on the value held by ``rref`` via
    ``rpc_api``.

    Args:
        rref: the RRef whose held value is the call target.
        rpc_api: RPC function passed on to ``_invoke_rpc``. It is presumably
            one of ``rpc_sync`` / ``rpc_async`` / ``remote``, chosen by the C++
            caller -- confirm in python_rpc_handler.cpp.
        timeout: timeout forwarded to every proxied call. Defaults to
            ``UNSET_RPC_TIMEOUT``; see ``.constants`` for its meaning.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        self.rref = rref
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout

    def __getattr__(self, func_name):
        """Return a callable that runs ``func_name`` remotely via ``_invoke_rpc``.

        Raises:
            AttributeError: if ``func_name`` is one of the proxy's own
                attributes (``rref``, ``rpc_api``, ``rpc_timeout``). Python
                only calls ``__getattr__`` when normal lookup fails, so this
                path is reached for those names only if ``__init__`` never ran
                (e.g. an instance created with ``RRefProxy.__new__`` while
                copying or unpickling). Without this guard, reading
                ``self.rref`` below would re-enter ``__getattr__`` and recurse
                until ``RecursionError``.
        """
        if func_name in ("rref", "rpc_api", "rpc_timeout"):
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, func_name)
            )
        return partial(
            _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout
        )
|