Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36619

With this PR, applications no longer need to create dedicated helpers to run functions on the object referenced by an RRef. Instead, `rref.rpc_sync().some_func()` uses `rpc_sync` to run `some_func` on the owner of the RRef, using the object referenced by the RRef. Similar helpers, `rref.rpc_async().some_func()` and `rref.remote().some_func()`, are also added.

An alternative design is to expose PyRRef as RRefBase and then implement everything in a new Python RRef class. However, the RRef class cannot directly inherit from PyRRef/RRefBase; otherwise, we would need to let the pyRemote* C++ functions load RRef from Python and return an RRef instance. It is possible to let RRef hold an instance of PyRRef instead of inheriting from it, but this is not an elegant design, as we would then have RRef holding PyRRef and PyRRef holding the C++ RRef. Another alternative is dynamic method loading, i.e., installing member methods on PyRRef instances. However, this would require different solutions to handle RRef(data) and rpc.remote(...). Based on the above, we decided to go with the current implementation for simplicity; it also keeps all RRef-related APIs in one place.

Test Plan: Imported from OSS

Differential Revision: D21028333

Pulled By: mrshenli

fbshipit-source-id: fe90f56ef7183d18874e357900093755e1601eb4
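For contrast, the kind of dedicated helper this PR makes unnecessary might look like the sketch below. It is a single-process illustration under stated assumptions: `LocalRRef`, `fake_rpc_sync`, `Counter`, `_call_method`, and `remote_method` are hypothetical names, and `fake_rpc_sync` simply runs the function in-process where a real blocking RPC would send it to the worker that owns the RRef.

```python
class LocalRRef:
    """Hypothetical stand-in for an RRef whose value lives in this process."""

    def __init__(self, value, owner="worker0"):
        self._value = value
        self._owner = owner

    def owner(self):
        return self._owner

    def local_value(self):
        return self._value


def fake_rpc_sync(to, func, args=(), kwargs=None):
    # Hypothetical stand-in for a blocking RPC: it ignores `to` and runs
    # `func` in this process, returning its result.
    return func(*args, **(kwargs or {}))


class Counter:
    def __init__(self):
        self.total = 0

    def add(self, n):
        self.total += n
        return self.total


def _call_method(method, rref, *args, **kwargs):
    # Executed on the owner: unwrap the RRef and call the method on its value.
    return method(rref.local_value(), *args, **kwargs)


def remote_method(method, rref, *args, **kwargs):
    # Executed on the caller: ship `_call_method` to the RRef's owner.
    return fake_rpc_sync(rref.owner(), _call_method,
                         args=(method, rref) + args, kwargs=kwargs)


rref = LocalRRef(Counter())
remote_method(Counter.add, rref, 5)
print(remote_method(Counter.add, rref, 2))  # 7
```

With the helpers described above, each `remote_method(Counter.add, rref, n)` call could instead be written as `rref.rpc_sync().add(n)`, with no application-side helper.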
23 lines
552 B
Python
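```python
from functools import partial


def _local_invoke(rref, func_name, args, kwargs):
    # Runs on the RRef's owner: look up `func_name` on the referenced object
    # and call it with the forwarded positional and keyword arguments.
    return getattr(rref.local_value(), func_name)(*args, **kwargs)


def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
    # Runs on the caller: send `_local_invoke` to the owner using the chosen
    # RPC function (rpc_sync, rpc_async, or remote). The call's positional and
    # keyword arguments are packed into `args` for `_local_invoke` to unpack.
    return rpc_api(
        rref.owner(),
        _local_invoke,
        args=(rref, func_name, args, kwargs)
    )


class RRefProxy:
    # Backs rref.rpc_sync(), rref.rpc_async(), and rref.remote(): stores the
    # RRef and the RPC function used to reach its owner.
    def __init__(self, rref, rpc_api):
        self.rref = rref
        self.rpc_api = rpc_api

    def __getattr__(self, func_name):
        # Invoked for any attribute not found on the proxy itself, so
        # `proxy.some_func` returns a callable that forwards `some_func`
        # to the RRef's owner through `_invoke_rpc`.
        return partial(_invoke_rpc, self.rref, self.rpc_api, func_name)
```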
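To see how one proxy class serves all three entry points, the sketch below binds `RRefProxy` to three different RPC functions and runs everything in a single process. The proxy logic is copied from the file above so the block runs on its own; `LocalRRef`, `fake_rpc_sync`, `fake_rpc_async`, and `fake_remote` are hypothetical stand-ins for the real RRef and for `rpc_sync` (returns the result), `rpc_async` (returns a future), and `remote` (returns a new RRef), and they execute locally instead of on a remote owner.

```python
from concurrent.futures import Future
from functools import partial


# Proxy logic copied from the file above.
def _local_invoke(rref, func_name, args, kwargs):
    return getattr(rref.local_value(), func_name)(*args, **kwargs)


def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
    return rpc_api(rref.owner(), _local_invoke, args=(rref, func_name, args, kwargs))


class RRefProxy:
    def __init__(self, rref, rpc_api):
        self.rref = rref
        self.rpc_api = rpc_api

    def __getattr__(self, func_name):
        return partial(_invoke_rpc, self.rref, self.rpc_api, func_name)


class LocalRRef:
    """Hypothetical in-process RRef exposing the three proxy entry points."""

    def __init__(self, value, owner="worker0"):
        self._value = value
        self._owner = owner

    def owner(self):
        return self._owner

    def local_value(self):
        return self._value

    def rpc_sync(self):
        return RRefProxy(self, fake_rpc_sync)

    def rpc_async(self):
        return RRefProxy(self, fake_rpc_async)

    def remote(self):
        return RRefProxy(self, fake_remote)


def fake_rpc_sync(to, func, args=()):
    # Blocking flavor: return the result directly.
    return func(*args)


def fake_rpc_async(to, func, args=()):
    # Non-blocking flavor: wrap the result in an already-completed Future.
    fut = Future()
    fut.set_result(func(*args))
    return fut


def fake_remote(to, func, args=()):
    # Remote flavor: keep the result "on the owner" and return a reference.
    return LocalRRef(func(*args), owner=to)


rref = LocalRRef([3, 1, 2])
rref.rpc_sync().append(4)                  # mutates the owner's list in place
print(rref.local_value())                  # [3, 1, 2, 4]
print(rref.rpc_async().index(2).result())  # 2
copy_ref = rref.remote().copy()
print(copy_ref.local_value())              # [3, 1, 2, 4]
```

Because `__getattr__` only fires for names the proxy does not already define, any method name works except the proxy's own attributes (`rref`, `rpc_api`, and the standard object attributes).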