Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models use the generic types Awaitable[W] and LazyAwaitable[W] (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212). In a sharded model these types are used in place of the contained type W, and they hold an initialization function that produces the object of type W. When the first attribute of W is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack), caches the result, and from then on works transparently as an object of type W. So it can be thought of as delayed object initialization.

To support this behavior in TorchScript, we propose a new TorchScript type: `Await`. In eager mode it works the same as `LazyAwaitable[W]` in TorchRec: it is dynamically typed, acting as type `W` while actually being `Await[W]`. Within TorchScript it is `Await[W]` and can only be converted to W explicitly, using the special function `torch.jit._awaitable_wait(aw)`. An `Await[W]` is created with another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future`, fork and wait, and use the same JIT mechanics (inlined fork closures), with one difference: the function is not started in parallel on fork. It is only stored as a lambda inside an IValue, and that lambda is called on the same thread when `torch.jit._awaitable_wait` is called.

For example (more examples in this PR, `test/jit/test_await.py`):

```
import torch
from torch import Tensor
from torch._awaits import _Await as Await


def delayed(z: int) -> int:
    return z * 3


@torch.jit.script
def fn(x: Tensor):
    aw: Await[int] = torch.jit._awaitable(delayed, 99)  # delayed is not called yet
    a = torch.eye(2)
    b = torch.jit._awaitable_wait(aw)  # delayed(99) runs here, on this thread
    return a + b + x
```

Function semantics:

`_awaitable(func -> Callable[Tuple[...], W], *args, **kwargs) -> Await[W]`
Creates an Await object that owns args and kwargs. On the first `_awaitable_wait` call it executes `func` and then owns the function's result. Subsequent `_awaitable_wait` calls return this result from the first function call.

`_awaitable_wait(Await[W]) -> W`
On the first `_awaitable_wait` call on this Await object, calls the specified function; on later calls, returns the cached result of type W.

`_awaitable_nowait(W) -> Await[W]`
Creates a trivial Await[W] wrapper around the specified object, to stay type-compliant in corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
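To make the create/wait/nowait semantics above concrete, here is a minimal plain-Python model of them. `LazyAwait`, `awaitable`, `awaitable_wait` and `awaitable_nowait` are hypothetical names chosen for this sketch; they are not the `torch.jit` functions and do not reproduce the TorchScript implementation, which stores the call as a closure inside an IValue. The sketch assumes a single thread and uses a sentinel object to mark a result that has not been computed yet.

```python
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar

W = TypeVar("W")

_UNSET: Any = object()  # sentinel meaning "func has not been called yet"


class LazyAwait(Generic[W]):
    """Hypothetical stand-in for Await[W]: stores a callable and its arguments."""

    def __init__(self, func: Optional[Callable[..., W]], args: Tuple[Any, ...],
                 kwargs: Dict[str, Any], result: Any = _UNSET) -> None:
        self.func = func
        self.args = args        # the Await owns its arguments...
        self.kwargs = kwargs
        self.result = result    # ...and, after the first wait, the result


def awaitable(func: Callable[..., W], *args: Any, **kwargs: Any) -> LazyAwait[W]:
    # Unlike fork, nothing is scheduled here: func is only stored for later.
    return LazyAwait(func, args, kwargs)


def awaitable_wait(aw: LazyAwait[W]) -> W:
    # First call: run func on the caller's thread and cache its result.
    if aw.result is _UNSET:
        assert aw.func is not None
        aw.result = aw.func(*aw.args, **aw.kwargs)
    # Every later call returns the cached result without calling func again.
    return aw.result


def awaitable_nowait(value: W) -> LazyAwait[W]:
    # Trivial Await: the result is already known, so there is nothing to call.
    return LazyAwait(None, (), {}, value)


def delayed(z: int) -> int:
    return z * 3


aw = awaitable(delayed, 99)   # delayed has not run yet
print(awaitable_wait(aw))     # 297: delayed runs now
print(awaitable_wait(aw))     # 297: cached, delayed is not called again
```

Because `awaitable` only stores `func`, a function with side effects runs exactly once, at the first `awaitable_wait`, on the thread that calls it. That is the difference the text above draws with fork, which starts the function in parallel.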
55 lines
1.6 KiB
Python
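The source file below declares `_PyAwaitMeta` with both `type(torch._C._Await)` and `type(Generic)` as bases, so that `_Await` can inherit from the C-extension class `torch._C._Await` and from `Generic[W]` at the same time. Python requires a class's metaclass to be a subclass of the metaclasses of all its bases and raises `TypeError` otherwise. On recent Python versions `type(Generic)` is simply `type`, so the combination mostly guards against unrelated metaclasses. The sketch below shows that situation with hypothetical stand-ins (`ExtMeta`, `ExtBase`, `OtherMeta`, `OtherBase`) rather than the real torch classes.

```python
from typing import Generic, TypeVar

W = TypeVar("W")


class ExtMeta(type):
    """Hypothetical stand-in for the metaclass of a C-extension type."""


class ExtBase(metaclass=ExtMeta):
    """Hypothetical stand-in for torch._C._Await."""


class OtherMeta(type):
    """A second metaclass, unrelated to ExtMeta."""


class OtherBase(metaclass=OtherMeta):
    pass


def has_conflict() -> bool:
    # Without an explicit metaclass, Python needs one metaclass that is a
    # subclass of both ExtMeta and OtherMeta; none exists, so TypeError.
    try:
        class Broken(ExtBase, OtherBase):
            pass
    except TypeError:
        return True
    return False


# The pattern used by _PyAwaitMeta: derive from both bases' metaclasses.
class CombinedMeta(type(ExtBase), type(OtherBase)):
    pass


class Works(ExtBase, OtherBase, metaclass=CombinedMeta):
    pass


# Same shape as _Await: extension base + Generic[W] + combined metaclass.
class GenericExtMeta(type(ExtBase), type(Generic)):
    pass


class GenericExt(ExtBase, Generic[W], metaclass=GenericExtMeta):
    pass


print(has_conflict())  # True
```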
from __future__ import annotations

from typing import cast, Callable, Generic, Type, TypeVar

import torch

__all__ = ['Await']

W = TypeVar("W")

class _PyAwaitMeta(type(torch._C._Await), type(Generic)):  # type: ignore[misc, no-redef]
    pass

class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
    r"""
    Wrapper around a ``torch._C._Await`` which encapsulates delayed execution
    of a callable. All manipulations happen through the functions ``torch.jit._awaitable``,
    ``torch.jit._awaitable_wait`` and ``torch.jit._awaitable_nowait``.

    TorchScript-compatible functions:

    ``torch.jit._awaitable(func, *args)``
    Creates an ``Await[W]`` object, where ``W`` is the return type of ``func``.

    Returns:
        An ``Await[W]`` that defers the call ``func(*args)`` until the first ``_awaitable_wait``.

    ``torch.jit._awaitable_wait(Await[W])``
    Returns the result of the function specified at ``_awaitable``, called with the specified arguments.

    Returns:
        The result of type ``W`` of the function call. The result is owned by ``Await[W]``
        and returned on all subsequent ``_awaitable_wait`` calls.

    ``torch.jit._awaitable_nowait(W)``
    Creates a trivial ``Await[W]`` around an already available value.

    Returns:
        A trivial ``Await[W]`` holding the specified result.

    Methods available only in eager mode:

    ``fn() -> Callable[..., W]``

    Returns:
        The Python function ``func`` specified at ``_awaitable``.

    ``args() -> Tuple[Any, ...]``

    Returns:
        The Python args specified at ``_awaitable``.

    ``is_nowait() -> bool``

    Returns:
        ``True`` if this object was created via a ``_awaitable_nowait`` call (a trivial ``Await[W]``).

    In eager mode, ``Await[W]`` can be used as ``W``, i.e. attributes of ``W`` can be accessed on
    ``Await[W]``; the ``_awaitable_wait()`` call is added transparently.
    """
    pass
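The last paragraph of the docstring says that in eager mode an `Await[W]` can stand in for `W`, with `_awaitable_wait()` inserted transparently. One plausible way to get that behavior in plain Python is `__getattr__` forwarding, sketched below together with the eager-only `fn()`, `args()` and `is_nowait()` accessors. `EagerAwait` and its `nowait` constructor are hypothetical names, and this is not how `torch._C._Await` is implemented.

```python
from typing import Any, Callable, Generic, Tuple, TypeVar

W = TypeVar("W")

_UNSET: Any = object()  # sentinel: the wrapped function has not run yet


class EagerAwait(Generic[W]):
    """Hypothetical eager-mode stand-in for Await[W] that forwards attribute access."""

    def __init__(self, func: Callable[..., W], *args: Any) -> None:
        # Only assignments here: reading an unset attribute would hit __getattr__.
        self._func = func
        self._args = args
        self._result = _UNSET
        self._nowait = False

    @classmethod
    def nowait(cls, value: W) -> "EagerAwait[W]":
        # Trivial wrapper: the result is known up front; the identity function
        # keeps fn()(*args()) consistent with wait().
        aw = cls(lambda v: v, value)
        aw._result = value
        aw._nowait = True
        return aw

    def wait(self) -> W:
        # First call runs the function on the current thread; later calls hit the cache.
        if self._result is _UNSET:
            self._result = self._func(*self._args)
        return self._result

    def fn(self) -> Callable[..., W]:
        return self._func

    def args(self) -> Tuple[Any, ...]:
        return self._args

    def is_nowait(self) -> bool:
        return self._nowait

    def __getattr__(self, name: str) -> Any:
        # Invoked only when normal lookup on the wrapper fails, so any attribute
        # of W triggers the (cached) wait transparently.
        return getattr(self.wait(), name)


def make_list(n: int) -> list:
    print("make_list called")
    return list(range(n))


aw = EagerAwait(make_list, 3)  # nothing printed yet
print(aw.count(1))             # prints "make_list called", then 1
print(aw.index(2))             # cached: prints only 2
```

One limitation of this design: Python looks up special methods such as `__add__` or `__len__` on the type, not the instance, so `len(aw)` or `aw + 1` are not forwarded by `__getattr__` and would need explicit dunder definitions. Attributes defined on the wrapper itself (`fn`, `args`, `wait`, and so on) also shadow same-named attributes of `W`.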