pytorch/torch/_awaits/__init__.py
Ivan Kobzarev 2fc73622f8 [jit] Support Awaitable type (#90863)
We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models use the generic types `Awaitable[W]` and `LazyAwaitable[W]` (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212).
In a sharded model these types are used in place of the contained type `W`. Each of them holds an initialization function that produces an object of type `W`.

When the first attribute of `W` is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack), caches the result, and from then on works transparently as an object of `W`. It can therefore be thought of as delayed object initialization.
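
To make the delayed-initialization idea concrete, here is a minimal pure-Python sketch. `LazyInit` and `make_list` are hypothetical names used only for illustration; this is not TorchRec's `LazyAwaitable` implementation, just a model of the behavior described above: nothing runs at construction, the first attribute access runs the initialization function on the caller's stack, and later accesses reuse the cached object.

```python
from typing import Any, Callable, Generic, List, Optional, TypeVar

W = TypeVar("W")


class LazyInit(Generic[W]):
    """Hypothetical model of delayed initialization (not TorchRec's LazyAwaitable)."""

    def __init__(self, init_fn: Callable[[], W]) -> None:
        self._init_fn = init_fn
        self._done = False
        self._result: Optional[W] = None

    def wait(self) -> W:
        # Run the initialization function once, on the caller's stack, then cache.
        if not self._done:
            self._result = self._init_fn()
            self._done = True
        return self._result  # type: ignore[return-value]

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup on LazyInit fails, so any attribute
        # of W triggers initialization and is forwarded to the cached object.
        if name.startswith("_"):
            raise AttributeError(name)  # avoid recursion on private/dunder names
        return getattr(self.wait(), name)


calls: List[str] = []


def make_list() -> List[int]:
    calls.append("init")
    return [1, 2, 3]


lazy = LazyInit(make_list)
assert calls == []          # nothing has run yet
assert lazy.count(2) == 1   # first attribute access runs make_list()
assert lazy.index(3) == 2   # served from the cached list
assert calls == ["init"]
```

One caveat of the `__getattr__` approach: Python looks up special methods such as `__len__` on the type rather than the instance, so `len(lazy)` is not forwarded unless those methods are defined explicitly.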

To support this behavior in TorchScript, we propose a new TorchScript type: `Await`.
In eager mode it works the same way as `LazyAwaitable[W]` in TorchRec: being dynamically typed, it acts as type `W` while actually being `Await[W]`.

Within TorchScript it is `Await[W]` and can only be converted to `W` explicitly, using the special function `torch.jit._awaitable_wait(aw)`.
An `Await[W]` is created with another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future` with fork/wait, and use the same JIT mechanics (inlined fork closures). The difference is that the function is not started in parallel at fork time: it is only stored as a lambda inside an IValue and is called on the same thread when `torch.jit._awaitable_wait` is called.
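
The difference from fork can be modeled in plain Python. This sketch is an analogy, not the JIT implementation: a `concurrent.futures.ThreadPoolExecutor` stands in for `torch.jit.fork`, and the hypothetical `LazyThunk` class stands in for `Await`. The forked call may start on a worker thread right away, while the stored call runs only when `wait()` is called, on the thread that waits.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Tuple

events: List[Tuple[str, str]] = []


def work(tag: str) -> str:
    # Record which thread actually executed the function.
    events.append((tag, threading.current_thread().name))
    return tag


class LazyThunk:
    """Hypothetical model of Await: store the call now, run it on wait()."""

    def __init__(self, fn: Callable[[str], str], arg: str) -> None:
        self._fn, self._arg = fn, arg
        self._done, self._value = False, ""

    def wait(self) -> str:
        if not self._done:
            self._value = self._fn(self._arg)  # runs on the calling thread
            self._done = True
        return self._value


with ThreadPoolExecutor(max_workers=1, thread_name_prefix="worker") as pool:
    fut = pool.submit(work, "fork")  # may start running immediately
    aw = LazyThunk(work, "await")    # stored, not started
    assert all(tag != "await" for tag, _ in events)
    assert fut.result() == "fork"
    assert aw.wait() == "await"

ran_on = dict(events)
assert ran_on["fork"].startswith("worker")
assert ran_on["await"] == threading.current_thread().name
```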

For example (more examples are in `test/jit/test_await.py` in this PR):
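```
import torch
from torch import Tensor
from torch._awaits import _Await as Await

def delayed(z: int) -> int:
    return z * 3

@torch.jit.script
def fn(x: Tensor):
    # delayed(99) is only stored here, not executed
    aw: Await[int] = torch.jit._awaitable(delayed, 99)
    a = torch.eye(2)
    # delayed(99) runs now, on the same thread
    b = torch.jit._awaitable_wait(aw)
    return a + b + x
```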

Function semantics:

`_awaitable(func: Callable[..., W], *args, **kwargs) -> Await[W]`

Creates an `Await` object that owns `args` and `kwargs`. On the first `_awaitable_wait` call it executes `func` and takes ownership of the result. Subsequent `_awaitable_wait` calls return the result of that first call.

`_awaitable_wait(Await[W]) -> W`

On the first `_awaitable_wait` call on this `Await` object, calls the specified function; on later calls, returns the cached result of type `W`.
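
The semantics of `_awaitable` and `_awaitable_wait` can be modeled in a few lines of plain Python. `Await`, `awaitable`, `awaitable_wait` and `scaled` below are hypothetical stand-ins that run without PyTorch; they mirror the described behavior (ownership of `args`/`kwargs`, a single execution, a cached result), plus `fn()`/`args()` accessors like the eager-mode ones in the docstring, and are not the real `torch.jit` functions.

```python
from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar

W = TypeVar("W")


class Await(Generic[W]):
    """Hypothetical pure-Python model of Await[W]; not the torch._C._Await binding."""

    def __init__(self, fn: Callable[..., W], args: Tuple[Any, ...],
                 kwargs: Dict[str, Any]) -> None:
        self._fn = fn
        self._args = args        # the Await owns the call arguments...
        self._kwargs = kwargs
        self._has_value = False
        self._value: Any = None  # ...and, after the first wait, the result

    def fn(self) -> Callable[..., W]:
        return self._fn

    def args(self) -> Tuple[Any, ...]:
        return self._args


def awaitable(fn: Callable[..., W], *args: Any, **kwargs: Any) -> Await[W]:
    # Like _awaitable(func, *args, **kwargs): records the call, executes nothing.
    return Await(fn, args, kwargs)


def awaitable_wait(aw: Await[W]) -> W:
    # Like _awaitable_wait: the first call executes fn, later calls hit the cache.
    if not aw._has_value:
        aw._value = aw._fn(*aw._args, **aw._kwargs)
        aw._has_value = True
    return aw._value


calls: List[int] = []


def scaled(z: int, factor: int = 3) -> int:
    calls.append(z)
    return z * factor


aw = awaitable(scaled, 99, factor=2)
assert calls == []                # nothing executed at creation
assert awaitable_wait(aw) == 198  # first wait runs scaled(99, factor=2)
assert awaitable_wait(aw) == 198  # cached result, no second call
assert calls == [99]
```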

`_awaitable_nowait(W) -> Await[W]`

Creates a trivial `Await[W]` wrapper around the specified object, to stay type-compliant in corner cases.
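
A typical corner case is a function where one branch already has the value and another defers the work; a statically typed language such as TorchScript needs both branches to return the same type. The sketch below models this in plain Python. `Await`, `awaitable_nowait`, `lookup` and `_cache` are hypothetical; `awaitable_nowait` only mirrors the described behavior of `_awaitable_nowait`, and `is_nowait()` mirrors the eager-mode accessor from the docstring.

```python
from typing import Callable, Dict, Generic, Optional, TypeVar

W = TypeVar("W")


class Await(Generic[W]):
    """Hypothetical pure-Python model; not the torch._C._Await binding."""

    def __init__(self, fn: Optional[Callable[[], W]], value: Optional[W] = None) -> None:
        self._fn = fn
        self._value = value
        self._nowait = fn is None  # created already holding its result

    def is_nowait(self) -> bool:
        return self._nowait

    def wait(self) -> W:
        if self._fn is not None:
            self._value = self._fn()
            self._fn = None  # run at most once, then keep the result
        return self._value  # type: ignore[return-value]


def awaitable_nowait(value: W) -> Await[W]:
    # Trivial wrapper: there is no function to run; wait() returns value as is.
    return Await(None, value)


_cache: Dict[str, int] = {"a": 1}  # hypothetical precomputed values


def lookup(key: str) -> Await[int]:
    # Both branches return Await[int], so the function has one static return type.
    if key in _cache:
        return awaitable_nowait(_cache[key])
    return Await(lambda: len(key))


hit, miss = lookup("a"), lookup("xyz")
assert hit.is_nowait() and hit.wait() == 1
assert not miss.is_nowait() and miss.wait() == 3
```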

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
2023-01-30 17:38:59 +00:00


from __future__ import annotations

from typing import cast, Callable, Generic, Type, TypeVar

import torch

__all__ = ['Await']

W = TypeVar("W")


class _PyAwaitMeta(type(torch._C._Await), type(Generic)):  # type: ignore[misc, no-redef]
    pass


class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
    r"""
    Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
    of a callable. All manipulations happen through the functions ``torch.jit._awaitable``,
    ``torch.jit._awaitable_wait`` and ``torch.jit._awaitable_nowait``.

    TorchScript-compatible manipulations:

    ``torch.jit._awaitable(func, *args)``
        Creates an ``Await[W]`` object, where ``W`` is the return type of ``func``.

    ``torch.jit._awaitable_wait(Await[W])``
        Returns the result of the function specified at ``_awaitable``, called with the specified arguments.

        Returns:
            The result of type ``W`` of the function call. The result is owned by ``Await[W]``
            and returned on all following ``_awaitable_wait`` calls.

    ``torch.jit._awaitable_nowait(W)``
        Returns:
            Trivial ``Await[W]`` with the specified result.

    Only in eager mode:

    ``fn() -> Callable[Tuple[Any], W]``
        Returns:
            The Python function ``func`` specified at ``_awaitable``.

    ``args() -> Tuple[Any]``
        Returns:
            The Python args specified at ``_awaitable``.

    ``is_nowait() -> _bool``
        Returns:
            ``True`` if this object was created via an ``_awaitable_nowait`` call (trivial ``Await[W]``).

    In eager mode ``Await[W]`` can be used as ``W``, i.e. attributes of ``W`` can be called on ``Await[W]``;
    the ``_awaitable_wait()`` call is added transparently.
    """
    pass