mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This is a new version of #15648 based on the latest master branch. Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR. In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.) Fixes https://github.com/pytorch/pytorch/issues/71105 @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797 Approved by: https://github.com/ezyang
167 lines
7.1 KiB
Python
167 lines
7.1 KiB
Python
import functools
|
|
|
|
|
|
def async_execution(fn):
|
|
r"""
|
|
A decorator for a function indicating that the return value of the function
|
|
is guaranteed to be a :class:`~torch.futures.Future` object and this
|
|
function can run asynchronously on the RPC callee. More specifically, the
|
|
callee extracts the :class:`~torch.futures.Future` returned by the wrapped
|
|
function and installs subsequent processing steps as a callback to that
|
|
:class:`~torch.futures.Future`. The installed callback will read the value
|
|
from the :class:`~torch.futures.Future` when completed and send the
|
|
value back as the RPC response. That also means the returned
|
|
:class:`~torch.futures.Future` only exists on the callee side and is never
|
|
sent through RPC. This decorator is useful when the wrapped function's
|
|
(``fn``) execution needs to pause and resume due to, e.g., containing
|
|
:meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.
|
|
|
|
.. note:: To enable asynchronous execution, applications must pass the
|
|
function object returned by this decorator to RPC APIs. If RPC detected
|
|
attributes installed by this decorator, it knows that this function
|
|
returns a ``Future`` object and will handle that accordingly.
|
|
However, this does not mean this decorator has to be outmost one when
|
|
defining a function. For example, when combined with ``@staticmethod``
|
|
or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the
|
|
inner decorator to allow the target function be recognized as a static
|
|
or class function. This target function can still execute asynchronously
|
|
because, when accessed, the static or class method preserves attributes
|
|
installed by ``@rpc.functions.async_execution``.
|
|
|
|
|
|
Example::
|
|
The returned :class:`~torch.futures.Future` object can come from
|
|
:meth:`~torch.distributed.rpc.rpc_async`,
|
|
:meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
|
|
constructor. The example below shows directly using the
|
|
:class:`~torch.futures.Future` returned by
|
|
:meth:`~torch.futures.Future.then`.
|
|
|
|
>>> from torch.distributed import rpc
|
|
>>>
|
|
>>> # omitting setup and shutdown RPC
|
|
>>>
|
|
>>> # On all workers
|
|
>>> @rpc.functions.async_execution
|
|
>>> def async_add_chained(to, x, y, z):
|
|
>>> # This function runs on "worker1" and returns immediately when
|
|
>>> # the callback is installed through the `then(cb)` API. In the
|
|
>>> # mean time, the `rpc_async` to "worker2" can run concurrently.
|
|
>>> # When the return value of that `rpc_async` arrives at
|
|
>>> # "worker1", "worker1" will run the lambda function accordingly
|
|
>>> # and set the value for the previously returned `Future`, which
|
|
>>> # will then trigger RPC to send the result back to "worker0".
|
|
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
|
|
>>> lambda fut: fut.wait() + z
|
|
>>> )
|
|
>>>
|
|
>>> # On worker0
|
|
>>> # xdoctest: +SKIP
|
|
>>> ret = rpc.rpc_sync(
|
|
>>> "worker1",
|
|
>>> async_add_chained,
|
|
>>> args=("worker2", torch.ones(2), 1, 1)
|
|
>>> )
|
|
>>> print(ret) # prints tensor([3., 3.])
|
|
|
|
When combined with TorchScript decorators, this decorator must be the
|
|
outmost one.
|
|
|
|
>>> from torch import Tensor
|
|
>>> from torch.futures import Future
|
|
>>> from torch.distributed import rpc
|
|
>>>
|
|
>>> # omitting setup and shutdown RPC
|
|
>>>
|
|
>>> # On all workers
|
|
>>> @torch.jit.script
|
|
>>> def script_add(x: Tensor, y: Tensor) -> Tensor:
|
|
>>> return x + y
|
|
>>>
|
|
>>> @rpc.functions.async_execution
|
|
>>> @torch.jit.script
|
|
>>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
|
|
>>> return rpc.rpc_async(to, script_add, (x, y))
|
|
>>>
|
|
>>> # On worker0
|
|
>>> ret = rpc.rpc_sync(
|
|
>>> "worker1",
|
|
>>> async_add,
|
|
>>> args=("worker2", torch.ones(2), 1)
|
|
>>> )
|
|
>>> print(ret) # prints tensor([2., 2.])
|
|
|
|
When combined with static or class method, this decorator must be the
|
|
inner one.
|
|
|
|
>>> from torch.distributed import rpc
|
|
>>>
|
|
>>> # omitting setup and shutdown RPC
|
|
>>>
|
|
>>> # On all workers
|
|
>>> class AsyncExecutionClass:
|
|
>>>
|
|
>>> @staticmethod
|
|
>>> @rpc.functions.async_execution
|
|
>>> def static_async_add(to, x, y, z):
|
|
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
|
|
>>> lambda fut: fut.wait() + z
|
|
>>> )
|
|
>>>
|
|
>>> @classmethod
|
|
>>> @rpc.functions.async_execution
|
|
>>> def class_async_add(cls, to, x, y, z):
|
|
>>> ret_fut = torch.futures.Future()
|
|
>>> rpc.rpc_async(to, torch.add, args=(x, y)).then(
|
|
>>> lambda fut: ret_fut.set_result(fut.wait() + z)
|
|
>>> )
|
|
>>> return ret_fut
|
|
>>>
|
|
>>> @rpc.functions.async_execution
|
|
>>> def bound_async_add(self, to, x, y, z):
|
|
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
|
|
>>> lambda fut: fut.wait() + z
|
|
>>> )
|
|
>>>
|
|
>>> # On worker0
|
|
>>> ret = rpc.rpc_sync(
|
|
>>> "worker1",
|
|
>>> AsyncExecutionClass.static_async_add,
|
|
>>> args=("worker2", torch.ones(2), 1, 2)
|
|
>>> )
|
|
>>> print(ret) # prints tensor([4., 4.])
|
|
>>>
|
|
>>> ret = rpc.rpc_sync(
|
|
>>> "worker1",
|
|
>>> AsyncExecutionClass.class_async_add,
|
|
>>> args=("worker2", torch.ones(2), 1, 2)
|
|
>>> )
|
|
>>> print(ret) # prints tensor([4., 4.])
|
|
|
|
This decorator also works with RRef helpers, i.e., .
|
|
:meth:`torch.distributed.rpc.RRef.rpc_sync`,
|
|
:meth:`torch.distributed.rpc.RRef.rpc_async`, and
|
|
:meth:`torch.distributed.rpc.RRef.remote`.
|
|
|
|
>>> from torch.distributed import rpc
|
|
>>>
|
|
>>> # reuse the AsyncExecutionClass class above
|
|
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
|
|
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
|
|
>>> print(ret) # prints tensor([4., 4.])
|
|
>>>
|
|
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
|
|
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
|
|
>>> print(ret) # prints tensor([4., 4.])
|
|
>>>
|
|
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
|
|
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
|
|
>>> print(ret) # prints tensor([4., 4.])
|
|
"""
|
|
@functools.wraps(fn)
|
|
def wrapper(*args, **kwargs):
|
|
return fn(*args, **kwargs)
|
|
# Can't declare and use attributes of function objects (mypy#2087)
|
|
wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
|
|
return wrapper
|