pytorch/torch/csrc/jit
Ivan Kobzarev 2fc73622f8 [jit] Support Awaitable type (#90863)
We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models use the generic types `Awaitable[W]` and `LazyAwaitable[W]` (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212).
In a sharded model these types are used in place of the contained type `W`, carrying an initialization function that produces an object of type `W`.

The moment the first attribute of `W` is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack), caches the result inside, and from then on works transparently as an object of `W`. We can think of it as delayed object initialization.
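The delayed-initialization behavior described above can be sketched in plain Python (an illustrative model only, not TorchRec's actual implementation; the `LazyValue` class and `build` function are hypothetical names):

```python
class LazyValue:
    """Defers construction of the wrapped object until first attribute access."""

    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._obj = None

    def __getattr__(self, name):
        # __getattr__ is only reached for attributes NOT found on LazyValue
        # itself, so lookups of _init_fn/_obj never recurse into here.
        if self._obj is None:
            self._obj = self._init_fn()  # initialize on the same stack
        return getattr(self._obj, name)


calls = []

def build():
    calls.append("init")
    class W:
        value = 42
    return W()

lazy = LazyValue(build)
assert calls == []         # nothing constructed yet
assert lazy.value == 42    # first attribute access triggers initialization
assert lazy.value == 42    # cached: works transparently as a W afterwards
assert calls == ["init"]   # the initialization function ran exactly once
```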

To support this behavior in TorchScript, we propose a new TorchScript type: `Await`.
In eager mode it works the same way as `LazyAwaitable[W]` in TorchRec: it is dynamically typed, acting as type `W` while it is an `Await[W]`.

Within TorchScript it stays an `Await[W]` and can only be explicitly converted to `W` using the special function `torch.jit._awaitable_wait(aw)`.
An `Await[W]` is created via another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future` with fork/wait, and the feature reuses the same JIT mechanics (inlined fork closures), with the difference that the function is not started in parallel on fork. It is only stored as a lambda inside the IValue and called on the same thread when `torch.jit._awaitable_wait` is called.

For example (more examples in `test/jit/test_await.py` in this PR):
```
      def delayed(z: int) -> int:
          return z * 3

      @torch.jit.script
      def fn(x: Tensor):
          aw: Await[int] = torch.jit._awaitable(delayed, 99)
          a = torch.eye(2)
          b = torch.jit._awaitable_wait(aw)
          return a + b + x
```

Functions semantics:

`_awaitable(func: Callable[..., W], *args, **kwargs) -> Await[W]`

Creates an `Await` object that owns `args` and `kwargs`. The function `func` is executed once `_awaitable_wait` is called, and the `Await` object then owns the result. Subsequent `_awaitable_wait` calls return this cached result from the first call.

`_awaitable_wait(Await[W]) -> W`

Returns the cached result of type `W` if this is not the first `_awaitable_wait` call on this `Await` object; otherwise calls the stored function and returns its result.

`_awaitable_nowait(W) -> Await[W]`

Creates a trivial `Await[W]` wrapper around the specified object, to be type-compliant in corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
2023-01-30 17:38:59 +00:00
api [BE] Use nested namespaces in .cpp/.cu files (#92100) 2023-01-13 16:32:34 +00:00
backends Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
codegen Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
cuda [BC-Breaking] Separate stream_id, device_index, and device_type in pack and unpack for Streams (#81596) 2023-01-12 14:16:49 +00:00
docs Fix typo under torch directory (#87274) 2022-10-21 14:22:20 +00:00
frontend [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
ir [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
mobile Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
operator_upgraders Fix typos in .md and .rst files (#88962) 2022-11-17 03:37:02 +00:00
passes [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
python [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
runtime [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
serialization [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
tensorexpr Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
testing Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
jit_log.cpp Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
jit_log.h
jit_opt_limit.cpp Apply Clang-Tidy readability-container-size-empty (#93236) 2023-01-29 23:28:19 +00:00
jit_opt_limit.h
JIT-AUTOCAST.md [JIT][Autocast] document that scripted autocast context cannot disable eager-enabled autocast (#81747) 2022-07-21 17:44:44 +00:00
OVERVIEW.md Fix typos under torch directory (#88172) 2022-11-01 22:58:22 +00:00
README.md
resource_guard.h

PyTorch JIT

This folder contains (most of) the C++ code for the PyTorch JIT, a language and compiler stack for executing PyTorch models portably and efficiently. To learn more about the JIT from a user perspective, please consult our reference documentation and tutorials.

A brief summary of the source tree:

  • OVERVIEW.md: High-level technical overview of the JIT.
  • frontend/: Taking PyTorch modules in Python and translating them into the JIT IR.
  • ir/: Core IR abstractions.
  • runtime/: Interpreter, graph execution, and JIT operators.
  • codegen/: Generating efficient, hardware-specific code for JIT subgraphs.
  • serialization/: Saving and loading modules.
  • api/: Any user-facing C++ or Python interfaces.
  • python/: Binding stuff into Python or accessing information from the Python environment.
  • testing/: Utilities and helpers for testing.
  • mobile/: Mobile-specific implementations of runtime components.
  • passes/: IR-to-IR passes, generally for optimization and lowering.
  • generated/: This folder is generated by the PyTorch build, and contains bindings for native PyTorch operators into the JIT.

Refer to each folder for more in-depth documentation.

Other relevant parts of the codebase not contained here:

  • aten/src/ATen/core: contains JIT code re-used by other elements of the runtime system (eager, mobile, etc.)