pytorch/torch/distributed/rpc/rref_proxy.py
Aaron Gokaslan b7b2178204 [BE]: Remove useless lambdas (#113602)
Applies PLW0108, which removes useless lambda calls in Python. The rule is still in preview, so it is not ready to be enabled by default yet. These are the autofixes from the rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
2023-11-14 20:06:48 +00:00

75 lines
2.6 KiB
Python

from functools import partial
from . import functions
from . import rpc_async
import torch
from .constants import UNSET_RPC_TIMEOUT
from torch.futures import Future
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Run method ``func_name`` on the value owned by ``rref`` through ``rpc_api``.

    First, the RRef's type is fetched without blocking. That type decides which
    local invoker runs on the owner. A method carrying a
    ``_wrapped_async_rpc_function`` attribute (presumably set by
    ``functions.async_execution``) is dispatched through
    ``_local_invoke_async_execution``. Every other method, and every
    ScriptModule type, uses ``_local_invoke``.

    Args:
        rref: the RRef whose owner receives the call.
        rpc_api: RPC entry point used to issue the call. If it is
            ``rpc_async``, this function does not block.
        func_name: name of the method to call on the RRef's local value.
        timeout: passed to both the type lookup and the RPC call.
        *args, **kwargs: forwarded to the method on the owner.

    Returns:
        For ``rpc_async``: a ``Future`` completed with the remote call's
        result or exception. Otherwise: whatever ``rpc_api`` returns.
    """

    def _pick_invoker(rref_type):
        # ScriptModule types are bypassed when checking for the async
        # function attribute.
        is_script_module = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
            rref_type, torch._C.ScriptModule
        )
        if is_script_module:
            return _local_invoke
        method = getattr(rref_type, func_name)
        if hasattr(method, "_wrapped_async_rpc_function"):
            return _local_invoke_async_execution
        return _local_invoke

    def _dispatch(type_fut):
        # Read the resolved type, choose the invoker, and issue the RPC.
        invoker = _pick_invoker(type_fut.value())
        return rpc_api(
            rref.owner(),
            invoker,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    type_fut = rref._get_type(timeout=timeout, blocking=False)

    if rpc_api != rpc_async:
        # Any API other than rpc_async: wait for the type lookup, then call directly.
        type_fut.wait()
        return _dispatch(type_fut)

    # rpc_async itself returns a Future[T]. Chaining _dispatch via ``then``
    # would therefore produce a Future[Future[T]]. Instead, return a separate
    # Future and complete it with the inner call's result or exception.
    outer: Future = Future()

    def _forward_result(call_fut):
        try:
            outer.set_result(call_fut.value())
        except BaseException as err:
            outer.set_exception(err)

    def _on_type_ready(fut):
        try:
            _dispatch(fut).then(_forward_result)
        except BaseException as err:
            outer.set_exception(err)

    type_fut.then(_on_type_ready)
    return outer
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
    """Proxy that turns attribute access into RPC calls on an RRef.

    Looking up any attribute that is not found normally returns a callable.
    Calling it forwards its arguments to ``_invoke_rpc``, using the stored
    RRef, RPC API and timeout.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        # These three attributes exist on the instance, so ``__getattr__``
        # is not consulted when they are read below.
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout
        self.rref = rref

    def __getattr__(self, func_name):
        # Only reached for names not found through normal lookup.
        bound_args = (self.rref, self.rpc_api, func_name, self.rpc_timeout)
        return partial(_invoke_rpc, *bound_args)