We want to make TorchRec sharded models TorchScriptable. TorchRec sharded models use the generic types `Awaitable[W]` and `LazyAwaitable[W]` (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212). In sharded models these types stand in for the contained type `W`, and each holds an initialization function that produces an object of type `W`. When the first attribute of `W` is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack) and caches the result. From then on it behaves transparently as an object of type `W`. In effect, this is delayed object initialization.

To support this behavior in TorchScript, we propose adding a new type to TorchScript: `Await`. In eager mode it works the same as `LazyAwaitable[W]` in TorchRec. It is dynamically typed and acts as type `W` while actually being `Await[W]`. Within TorchScript it is `Await[W]`, and it can only be converted to `W` explicitly, using the special function `torch.jit._awaitable_wait(aw)`. An `Await[W]` is created with another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future` with fork/wait, and `Await` uses the same JIT machinery (inlined fork closures). The difference is that creating an `Await` does not start the function in parallel, as fork does. The function is only stored as a lambda inside an IValue, and it is called on the same thread when `torch.jit._awaitable_wait` is called.

For example (more examples are in this PR's `test/jit/test_await.py`):

```python
import torch
from torch import Tensor
from torch._awaits import _Await as Await

def delayed(z: int) -> int:
    # Not executed until the Await is waited on.
    return z * 3

@torch.jit.script
def fn(x: Tensor):
    # Stores delayed and its argument; nothing runs yet.
    aw: Await[int] = torch.jit._awaitable(delayed, 99)
    a = torch.eye(2)
    # Runs delayed(99) on this thread and returns its (cached) result.
    b = torch.jit._awaitable_wait(aw)
    return a + b + x
```

Function semantics:

- `_awaitable(func: Callable[..., W], *args, **kwargs) -> Await[W]`: Creates an `Await` object that owns `args` and `kwargs`. On the first `_awaitable_wait` call, it executes `func` and owns the result. Subsequent `_awaitable_wait` calls return the result of that first call.
- `_awaitable_wait(Await[W]) -> W`: On the first `_awaitable_wait` call on this `Await` object, calls the stored function. On later calls, returns the cached result of type `W`.
- `_awaitable_nowait(W) -> Await[W]`: Creates a trivial `Await[W]` wrapper around the specified object, so that code stays type-compliant in corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
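The delayed-initialization contract described in the commit message can be modeled in plain Python, without TorchScript. The sketch below is hypothetical. `SketchAwait`, `sketch_awaitable`, `sketch_awaitable_wait`, and `sketch_awaitable_nowait` are illustrative names, not the `torch.jit` API. Its attribute forwarding only approximates how `LazyAwaitable[W]` acts as `W` in eager mode. It shows three properties the semantics rely on:

- Creating the Await runs nothing.
- The first wait runs the function on the caller's thread.
- Later waits return the cached result.

```python
import threading
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar

W = TypeVar("W")
_UNSET: Any = object()  # sentinel: the delayed function has not run yet


class SketchAwait(Generic[W]):
    """Hypothetical pure-Python model of Await[W] (not the real torch type)."""

    def __init__(self, fn: Optional[Callable[..., W]], args: Tuple[Any, ...],
                 kwargs: Dict[str, Any], value: Any = _UNSET) -> None:
        # The Await owns the function and its arguments until the first wait.
        self._fn = fn
        self._args = args
        self._kwargs = kwargs
        self._value = value

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes SketchAwait itself lacks. Forwarding them
        # to the waited value models eager-mode transparency (Await[W] acting
        # as W). Implicit special-method lookups such as len(aw) bypass
        # __getattr__, so this sketch does NOT forward them.
        return getattr(sketch_awaitable_wait(self), name)


def sketch_awaitable(fn: Callable[..., W], *args: Any, **kwargs: Any) -> SketchAwait[W]:
    # Models _awaitable: store fn and its arguments; nothing runs yet.
    return SketchAwait(fn, args, kwargs)


def sketch_awaitable_wait(aw: SketchAwait[W]) -> W:
    # Models _awaitable_wait: the first call runs fn on the caller's thread
    # and caches the result; later calls return the cached result.
    if aw._value is _UNSET:
        aw._value = aw._fn(*aw._args, **aw._kwargs)
    return aw._value


def sketch_awaitable_nowait(value: W) -> SketchAwait[W]:
    # Models _awaitable_nowait: a trivial wrapper around an existing object.
    return SketchAwait(None, (), {}, value)


if __name__ == "__main__":
    calls = []

    def delayed(z: int) -> int:
        calls.append(threading.get_ident())  # record which thread runs it
        return z * 3

    aw = sketch_awaitable(delayed, 99)
    print(calls)                      # [] : creating the Await runs nothing
    print(sketch_awaitable_wait(aw))  # 297 : first wait runs delayed(99)
    print(sketch_awaitable_wait(aw))  # 297 : cached, delayed is not rerun
    print(len(calls))                 # 1
    lazy = sketch_awaitable(list, (1, 2, 3))
    print(lazy.count(2))              # 1 : attribute access forces the wait
```

Unlike `torch.jit.fork`, nothing in this sketch starts a background task. The cost of the delayed function is paid only if, and when, some caller waits on the Await.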
PyTorch JIT
This folder contains (most of) the C++ code for the PyTorch JIT, a language and compiler stack for executing PyTorch models portably and efficiently. To learn more about the JIT from a user perspective, please consult our reference documentation and tutorials.
A brief summary of the source tree:
- OVERVIEW.md: High-level technical overview of the JIT.
- frontend/: Taking PyTorch modules in Python and translating them into the JIT IR.
- ir/: Core IR abstractions.
- runtime/: Interpreter, graph execution, and JIT operators.
- codegen/: Generating efficient, hardware-specific code for JIT subgraphs.
- serialization/: Saving and loading modules.
- api/: Any user-facing C++ or Python interfaces.
- python/: Bindings into Python, and code that accesses information from the Python environment.
- testing/: Utilities and helpers for testing.
- mobile/: Mobile-specific implementations of runtime components.
- passes/: IR-to-IR passes, generally for optimization and lowering.
- generated/: This folder is generated by the PyTorch build, and contains bindings for native PyTorch operators into the JIT.
Refer to each folder for more in-depth documentation.
Other relevant parts of the codebase not contained here:
- aten/src/ATen/core: contains JIT code re-used by other elements of the runtime system (eager, mobile, etc.)