[dynamo, docs] programming model dynamo core concepts (#157985)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157985 Approved by: https://github.com/svekars, https://github.com/anijain2305
This commit is contained in:
parent e469414b59
commit 433e43cbec
@@ -50,8 +50,8 @@ IPython==8.12.0
 #Pinned versions: 8.12.0

 myst-nb==0.17.2
-#Description: This is used to generate PyTorch functorch docs
-#Pinned versions: 0.13.2
+#Description: This is used to generate PyTorch functorch and torch.compile docs
+#Pinned versions: 0.17.2

 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5
@@ -59,4 +59,3 @@ sphinx-copybutton==0.5.0
 sphinx-design==0.4.0
 sphinxcontrib-mermaid==1.0.0
 myst-parser==0.18.1
-myst-nb
BIN docs/source/compile/_static/dynamo_summary_diagram.png (new file, 424 KiB)
Binary file not shown.
docs/source/compile/header_code.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import functools
import os

import torch


# use backend="eager" to lower notebook execution time (this cell is hidden from the rendered docs)
torch.compile = functools.partial(torch.compile, backend="eager")

# clear the torch logs format
os.environ["TORCH_LOGS_FORMAT"] = ""
torch._logging._internal.DEFAULT_FORMATTER = (
    torch._logging._internal._default_formatter()
)
torch._logging._internal._init_logs()
docs/source/compile/programming_model.dynamo_core_concepts.md (new file, 164 lines)
@@ -0,0 +1,164 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code
```

# Dynamo Core Concepts

**Summary:**

- Dynamo, `torch.compile`'s frontend, performs **tracing** to capture the semantics of a Python function
  (and its nested function calls) into a linear sequence of operations (the "FX graph"),
  residual bytecode, and "guards" (a list of conditions under which the graph and bytecode are valid).
- Unsupported Python features lead to **graph breaks**, where Dynamo compiles a partial graph acquired from tracing,
  then runs the unsupported code, then resumes tracing.
- Graph breaks may lead to slowness in `torch.compile` and prevent backend optimization opportunities.
  If you're not seeing the performance you expect, check for graph breaks.

## Dynamo Tracing
`torch.compile`'s frontend (Dynamo) is a custom Python bytecode interpreter designed to allow graph compilation
in PyTorch programs while retaining the full flexibility of Python. Given a function to be compiled, Dynamo
interprets Python bytecode to extract sequences of PyTorch operations into one or more FX graphs that may be further optimized by a backend.

![Dynamo summary diagram](_static/dynamo_summary_diagram.png)

For example, for the function `f` in the above diagram, Dynamo produces:
- a single **FX graph** that takes in the original input plus some additional inputs required by the function.
- **Python bytecode** that can be used as a drop-in replacement for `f`. In our example, the bytecode retrieves
  the additional inputs and passes them to the graph; it also contains unoptimizable Python side effects (the list append).
- **guards** that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified,
  the graph produced by Dynamo specializes on the shapes of input Tensors.

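To see what Dynamo captured, one option (a minimal sketch, relying on `torch.compile`'s documented support for custom backend callables) is a backend that prints the FX graph and then runs it without optimization. The `inspect_backend` and `show` names below are ours, for illustration:

```python
# a minimal custom backend: print the captured FX graph, then run it
# unoptimized by returning the GraphModule's forward
def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)  # the linear sequence of captured operations
    return gm.forward

@torch.compile(backend=inspect_backend)
def show(x):
    return torch.relu(x) + 1

show(torch.randn(3))
```
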
## Graph Breaks
Dynamo traces your code and attempts to capture it into a single computation graph of PyTorch
operators (an FX graph). However, this is not always possible. When Dynamo encounters code that can't be traced, a "**graph break**" occurs.
In the default `torch.compile` settings, a graph break involves compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.

Graph breaks are a feature that allows Dynamo to run over arbitrary Python code and carve out functional subgraphs that can each be individually optimized.

However, graph breaks can lead to unexpected slowness in `torch.compile`.
If you're not getting the speedups you expect, we recommend checking for graph breaks and removing them.

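One way to check (a sketch assuming `torch._dynamo.explain`, available in recent PyTorch releases) is to ask Dynamo to report the graphs it captured and the reason for each break. The `loud_fn` helper is ours, for illustration:

```python
# print is run outside the graph, so it forces a graph break here
def loud_fn(x):
    print("side effect")
    return x + 1

# explain reports graph counts and the reason for each graph break
print(torch._dynamo.explain(loud_fn)(torch.randn(3)))
```
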
Graph breaks may occur on things like:

- Data-dependent if-statements (illustrated in the sketch below)
- Many Python built-in functions
- C functions

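Here is a minimal sketch of the first case (assuming default `torch.compile` settings; `g` is our illustrative name): the branch condition depends on a tensor's runtime value, so Dynamo cannot trace past the `if` and breaks the graph there.

```python
@torch.compile
def g(x):
    if x.sum() > 0:  # data-dependent condition: graph break here
        return x + 1
    return x - 1

g(torch.randn(3))
```
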
```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(graph_breaks=True)
```

Below is an example of a graph break due to calling an unsupported operation `torch.save`:

```{code-cell}
@torch.compile
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")  # torch.save is an unsupported operation
    z = y ** 3 / 6
    return z

x = torch.randn(3)
print(f(x))
```

```{code-cell}
:tags: [remove-cell]
import os
os.remove("foo.pt")
```

The semantics of `torch.compile(f)(x)` are roughly this:

```python
def compiled_f_semantics(x):
    y = torch.compile(g, fullgraph=True)(x)
    torch.save(y, "foo.pt")
    z = torch.compile(h, fullgraph=True)(y)
    return z

def g(x):
    return x ** 2 / 2

def h(y):
    return y ** 3 / 6
```

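Note the `fullgraph=True` in the sketch above: each subgraph is compiled with graph breaks disallowed. You can also pass `fullgraph=True` yourself to turn any graph break into a hard error instead of a split. A small sketch (`f2` and `"bar.pt"` are our illustrative names):

```python
@torch.compile(fullgraph=True)
def f2(x):
    torch.save(x, "bar.pt")  # unsupported: raises instead of graph-breaking
    return x + 1

try:
    f2(torch.randn(3))
except Exception as e:
    print(type(e).__name__)
```
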
## Guards

`torch.compile` makes some assumptions about runtime values as we trace through code. During tracing, we generate "guards",
which are runtime checks for these assumptions. Guards are run in future calls to the compiled function to determine if we
can reuse previously compiled code. Examples of runtime checks include constant values, types, and object IDs.

Below is an example of generated guards. The `TENSOR_MATCH` guard checks for the input's type, device, dtype, shape, etc.

```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(guards=True)
```

```{code-cell}
@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
```

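Conceptually, the `TENSOR_MATCH` guard printed above amounts to a predicate over the new input. A rough sketch of the idea (illustrative only, not Dynamo's actual guard implementation):

```python
# roughly what a TENSOR_MATCH-style guard checks for the call above
def tensor_match(x):
    return (
        isinstance(x, torch.Tensor)
        and x.dtype == torch.float32
        and x.device.type == "cpu"
        and tuple(x.shape) == (3, 3)
    )
```
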
## Recompilations
If the guards fail for every instance of previously compiled code, then `torch.compile` must "recompile" the function,
requiring the original code to be traced again. In the example below, recompilation is necessary because the guard checking the tensor argument's shape failed.

```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(recompiles=True)
```

```{code-cell}
@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```

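Both compiled versions remain cached after this, so calling `fn` again with either shape reuses existing code rather than recompiling. As a quick check (no new recompile logs are expected):

```python
# both shapes hit the compiled-code cache; no recompilation expected
print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```
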
## Dynamic Shapes

`torch.compile` initially assumes tensor shapes are static/constant and guards based on these assumptions. By using "dynamic shapes,"
we can get `torch.compile` to produce compiled code that can accept tensor inputs with different shapes - we avoid recompiling every time shapes differ.
By default, automatic dynamic shapes are enabled via `torch.compile(dynamic=None)`: if a shape guard fails,
recompilation is attempted with dynamic shapes. Dynamic shapes can also be fully enabled (`dynamic=True`) or disabled (`dynamic=False`).

Below, we enable dynamic shapes and note that we no longer need to recompile.

```{code-cell}
:tags: [remove-cell]
import logging
torch._logging.set_logs(dynamic=logging.DEBUG, recompiles=True)
```

```{code-cell}
@torch.compile(dynamic=True)
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```

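If you know ahead of time which dimension will vary, another option (a sketch using `torch._dynamo.mark_dynamic`; `fn2` is our illustrative name, and the API may evolve across releases) is to mark that dimension dynamic before the first call, so even the initial compilation is dynamic:

```python
x = torch.ones(3, 3)
torch._dynamo.mark_dynamic(x, 0)  # dim 0 is compiled as dynamic

@torch.compile
def fn2(x):
    return x + 1

print(fn2(x))
print(fn2(torch.ones(5, 3)))  # different dim 0: no recompile expected
```
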
For more information on dynamic shapes, see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).
docs/source/compile/programming_model.md (new file, 11 lines)

@@ -0,0 +1,11 @@
# torch.compile Programming Model

The `torch.compile` programming model:
1. Clarifies some internal behaviors of `torch.compile` so that one can better predict compiler behavior on user code, and
2. Provides ways for one to take more fine-grained control over `torch.compile`.

By understanding the `torch.compile` programming model, one can systematically unblock themselves when encountering issues with `torch.compile`.

```{toctree}
programming_model.dynamo_core_concepts
```

@@ -26,6 +26,9 @@ written in Python and it marks the transition of PyTorch from C++ to Python.
 which results in capturing the backwards pass "ahead-of-time". This enables
 acceleration of both forwards and backwards pass using TorchInductor.

+To better understand `torch.compile`'s tracing behavior on your code, or to
+learn more about the internals of `torch.compile`, please refer to the [`torch.compile` programming model](compile/programming_model.md).
+
 :::{note}
 In some cases, the terms `torch.compile`, TorchDynamo, `torch.compiler`
 might be used interchangeably in this documentation.

@@ -98,6 +101,13 @@ Some of the most commonly used backends include:
 torch.compiler_inductor_provenance
 ```

+```{eval-rst}
+.. toctree::
+   :caption: `torch.compile` Programming Model
+
+   compile/programming_model
+```
+
 % _If you want to contribute a developer-level topic
 % that provides in-depth overview of a torch._dynamo feature,
 % add in the below toc.