mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[dynamo, docs] programming model dynamo core concepts (#157985)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157985 Approved by: https://github.com/svekars, https://github.com/anijain2305
This commit is contained in:
parent
e469414b59
commit
433e43cbec
@@ -50,8 +50,8 @@ IPython==8.12.0
 #Pinned versions: 8.12.0

 myst-nb==0.17.2
-#Description: This is used to generate PyTorch functorch docs
-#Pinned versions: 0.13.2
+#Description: This is used to generate PyTorch functorch and torch.compile docs
+#Pinned versions: 0.17.2

 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5

@@ -59,4 +59,3 @@ sphinx-copybutton==0.5.0
 sphinx-design==0.4.0
 sphinxcontrib-mermaid==1.0.0
 myst-parser==0.18.1
-myst-nb
BIN  docs/source/compile/_static/dynamo_summary_diagram.png  (new file)
Binary file not shown. (After: 424 KiB)
docs/source/compile/header_code.py  (new file, 15 lines)
@@ -0,0 +1,15 @@
import functools
import os

import torch


# to lower notebook execution time while hiding backend="eager"
torch.compile = functools.partial(torch.compile, backend="eager")

# to clear torch logs format
os.environ["TORCH_LOGS_FORMAT"] = ""
torch._logging._internal.DEFAULT_FORMATTER = (
    torch._logging._internal._default_formatter()
)
torch._logging._internal._init_logs()
docs/source/compile/programming_model.dynamo_core_concepts.md  (new file, 164 lines)
@@ -0,0 +1,164 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code
```

# Dynamo Core Concepts

**Summary:**

- Dynamo, `torch.compile`'s frontend, performs **tracing** to capture the semantics of a Python function
  (and its nested function calls) into a linear sequence of operations (the "(FX) graph"),
  residual bytecode, and "guards" (a list of conditions under which the graph and bytecode are valid).
- Unsupported Python features lead to **graph breaks**, where Dynamo compiles the partial graph acquired from tracing,
  runs the unsupported code, then resumes tracing.
- Graph breaks may lead to slowness in `torch.compile` and prevent backend optimization opportunities.
  If you're not seeing the performance you expect, check for graph breaks.

## Dynamo Tracing

`torch.compile`'s frontend (Dynamo) is a custom Python bytecode interpreter designed to allow graph compilation
in PyTorch programs while retaining the full flexibility of Python. Given a function to compile, Dynamo
interprets the Python bytecode to extract sequences of PyTorch operations into one or more FX graphs that can be further optimized by a backend.



For example, for the function `f` in the diagram above, Dynamo produces:

- a single **FX graph** that takes in the original input plus some additional inputs required by the function;
- **Python bytecode** that can be used as a drop-in replacement for `f`. In our example, the bytecode retrieves
  the additional inputs and passes them to the graph; it also contains Python side effects that cannot be optimized (the list append);
- **guards** that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified,
  the graph produced by Dynamo specializes on the shapes of the input Tensors.

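To see the FX graph that Dynamo captures, one option is to pass a custom backend that simply prints the graph and returns it unchanged. A minimal sketch (the backend name `print_graph_backend` and the example function are illustrative, and the printed format varies across PyTorch versions):

```python
import torch

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo hands the captured operations to the backend as an FX GraphModule;
    # print it, then return the unmodified forward as the "compiled" callable.
    gm.print_readable()
    return gm.forward

@torch.compile(backend=print_graph_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))
```
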
## Graph Breaks

Dynamo traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch
operators (an FX graph). However, this is not always possible: when Dynamo encounters code that it cannot trace, a "**graph break**" occurs.
In the default `torch.compile` settings, a graph break involves compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.

Graph breaks are a feature that allows Dynamo to run over arbitrary Python code and to carve out functional subgraphs that can each be individually optimized.

However, graph breaks can lead to unexpected slowness in `torch.compile`.
If you're not getting the speedups you expect, we recommend checking for graph breaks and removing them.

Graph breaks may occur on things like:

- Data-dependent if-statements (see the sketch after this list)
- Many Python built-in functions
- C functions

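For instance, here is a minimal sketch of the data-dependent if-statement case (the function name `branchy` is illustrative): Dynamo cannot evaluate a condition that depends on a runtime tensor value at trace time, so under default settings it splits the graph at the branch.

```python
import torch

@torch.compile
def branchy(x):
    # The condition depends on a runtime tensor value, which Dynamo cannot
    # resolve at trace time, so under default settings it graph-breaks here.
    if x.sum() > 0:
        return x + 1
    return x - 1

branchy(torch.randn(3))
```
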
```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(graph_breaks=True)
```

Below is an example of a graph break due to calling an unsupported operation `torch.save`:

```{code-cell}
@torch.compile
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")  # torch.save is an unsupported operation
    z = y ** 3 / 6
    return z

x = torch.randn(3)
print(f(x))
```

```{code-cell}
:tags: [remove-cell]
import os
os.remove("foo.pt")
```

The semantics of `torch.compile(f)(x)` are roughly this:

```python
def compiled_f_semantics(x):
    y = torch.compile(g, fullgraph=True)(x)
    torch.save(y, "foo.pt")
    z = torch.compile(h, fullgraph=True)(y)
    return z

def g(x):
    return x ** 2 / 2

def h(y):
    return y ** 3 / 6
```

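If you would rather be alerted to graph breaks than silently fall back to splitting the graph, passing `fullgraph=True` makes `torch.compile` raise an error at the first graph break. A minimal sketch, reusing the `torch.save` example above (the wrapper name `f_fullgraph` is illustrative):

```python
import torch

@torch.compile(fullgraph=True)
def f_fullgraph(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")  # unsupported -> error instead of a graph break
    return y ** 3 / 6

try:
    f_fullgraph(torch.randn(3))
except Exception as e:
    # Expect a Dynamo error describing the graph break; the exact exception
    # type and message vary across PyTorch versions.
    print(type(e).__name__)
```
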
## Guards

`torch.compile` makes some assumptions about runtime values as we trace through code. During tracing, we generate "guards",
which are runtime checks for these assumptions. Guards are run in future calls to the compiled function to determine if we
can reuse previously compiled code. Examples of runtime checks are constant values, types, and object IDs.

Below is an example of generated guards. The `TENSOR_MATCH` guard checks the input's type, device, dtype, shape, etc.

```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(guards=True)
```

```{code-cell}
@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
```

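Guards are not limited to tensor properties. As a rough illustration (assuming default settings; the exact guard names vary across versions), Dynamo also guards on other Python values the compiled function reads, such as a module-level constant, so rebinding that value typically invalidates the guard:

```python
import torch

# Hypothetical illustration: Dynamo typically installs a guard on this global,
# so rebinding it causes the guard to fail on the next call (see the next
# section on recompilations).
SCALE = 2

@torch.compile
def scaled(x):
    return x * SCALE

print(scaled(torch.ones(3)))  # compiles, guarding on SCALE == 2

SCALE = 3
print(scaled(torch.ones(3)))  # guard fails -> recompile with SCALE == 3
```
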
## Recompilations

If the guards fail for every instance of previously compiled code, then `torch.compile` must "recompile" the function,
requiring the original code to be traced again. In the example below, recompilation is necessary because the guard checking the tensor argument's shape failed.

```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(recompiles=True)
```

```{code-cell}
@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```

## Dynamic Shapes

`torch.compile` initially assumes tensor shapes are static/constant and guards based on these assumptions. By using "dynamic shapes,"
we can get `torch.compile` to produce compiled code that accepts tensor inputs with different shapes, avoiding recompilation every time shapes differ.
By default, automatic dynamic shapes are enabled (`torch.compile(dynamic=None)`): if compilation fails due to a shape mismatch,
recompilation is attempted with dynamic shapes. Dynamic shapes can also be fully enabled (`dynamic=True`) or disabled (`dynamic=False`).

Below, we enable dynamic shapes and note that we no longer need to recompile.

```{code-cell}
:tags: [remove-cell]
import logging
torch._logging.set_logs(dynamic=logging.DEBUG, recompiles=True)
```

```{code-cell}
@torch.compile(dynamic=True)
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```

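Instead of enabling dynamic shapes for the whole function, a specific input dimension can be marked dynamic ahead of time with `torch._dynamo.mark_dynamic`, which skips the initial static specialization for that dimension. A minimal sketch (assuming default settings otherwise):

```python
import torch
import torch._dynamo

@torch.compile
def fn(x):
    return x + 1

x = torch.ones(3, 3)
torch._dynamo.mark_dynamic(x, 0)  # treat dim 0 as dynamic from the first compile
print(fn(x))
print(fn(torch.ones(5, 3)))  # a different dim-0 size should not recompile
```
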
For more information on dynamic shapes, see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).

docs/source/compile/programming_model.md  (new file, 11 lines)
@@ -0,0 +1,11 @@
# torch.compile Programming Model

The `torch.compile` programming model:

1. Clarifies some internal behaviors of `torch.compile` so that one can better predict compiler behavior on user code, and
2. Provides ways for one to take more fine-grained control over `torch.compile`.

By understanding the `torch.compile` programming model, one can systematically unblock themselves when encountering issues with `torch.compile`.

```{toctree}
programming_model.dynamo_core_concepts
```

@@ -26,6 +26,9 @@ written in Python and it marks the transition of PyTorch from C++ to Python.
 which results in capturing the backwards pass "ahead-of-time". This enables
 acceleration of both forwards and backwards pass using TorchInductor.

+To better understand `torch.compile`'s tracing behavior on your code, or to
+learn more about the internals of `torch.compile`, please refer to the [`torch.compile` programming model](compile/programming_model.md).
+
 :::{note}
 In some cases, the terms `torch.compile`, TorchDynamo, `torch.compiler`
 might be used interchangeably in this documentation.
@@ -98,6 +101,13 @@ Some of the most commonly used backends include:
    torch.compiler_inductor_provenance
 ```

+```{eval-rst}
+.. toctree::
+   :caption: `torch.compile` Programming Model
+
+   compile/programming_model
+```
+
 % _If you want to contribute a developer-level topic
 % that provides in-depth overview of a torch._dynamo feature,
 % add in the below toc.