---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(recompiles=True)
```
# Dealing with Recompilations
Recompilations are necessary for `torch.compile` soundness, but can result in significantly increased compile time.
Thus, minimizing recompilations while preserving soundness is essential for reducing compile time.
You can view recompilations and their reasons using `tlparse` or `TORCH_LOGS=recompiles`.
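For example, the same logging can be enabled programmatically from Python (equivalent to setting the `TORCH_LOGS=recompiles` environment variable):
```python
import torch

# Log recompilations and their reasons; equivalent to TORCH_LOGS=recompiles
torch._logging.set_logs(recompiles=True)
```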
## Is Dynamic Shapes Enabled?
In the example below, we recompile due to mismatched shapes:
```{code-cell}
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
```
Make sure that the `dynamic` option of `torch.compile` is not set to `False`.
The default, `dynamic=None`, only attempts dynamic shapes after the first compilation.
You can set `dynamic=True` to compile as dynamically as possible up front:
```{code-cell}
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
```
For more information on dynamic shapes, including dealing with errors/recompilations due to
dynamic shapes, see [the dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).
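If you know in advance which dimension will vary, you can also mark it as dynamic explicitly before the first call, avoiding the initial static compilation. Below is a minimal sketch using `torch._dynamo.mark_dynamic` (the function name `hn` is ours):
```python
import torch

@torch.compile
def hn(x):
    return x + 1

x = torch.ones(3)
# Mark dim 0 as dynamic so the first compilation already uses a
# symbolic size instead of specializing on 3.
torch._dynamo.mark_dynamic(x, 0)
hn(x)
hn(torch.ones(4))  # should reuse the dynamic compilation; no recompile expected
```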
## Wrapping Constants with Tensors
By default, `int` / `float` variables are treated as constants and are guarded on their exact value.
In the example below, each function call results in a recompilation.
```{code-cell}
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
```
In particular, for LR schedulers, initializing with a constant can lead to recompilations:
```{code-cell}
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
```
In both examples, we can wrap `float` variables in tensors to prevent recompilations.
```{code-cell}
:tags: [remove-cell]
torch._dynamo.reset()
```
```{code-cell}
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
```
(programming_model.recompilation.changing_cache_size_limit)=
## Changing the Cache Size Limit
There is a limit to how many times a function can be recompiled,
determined by `torch._dynamo.config.cache_size_limit` and `torch._dynamo.config.accumulated_cache_size_limit`
(the exact difference between these two values is detailed in [`torch/_dynamo/cache_size.py`](https://github.com/pytorch/pytorch/blob/4ce6e6ec8890a3f6ee604c9efb3ff153825ce575/torch/_dynamo/cache_size.py#L14)).
If the Dynamo cache limit is hit, then all future compilation attempts **result in the function being skipped and run eagerly**.
Dynamo will still attempt to use previously compiled bytecode for future function calls, if the guards pass.
Note that when the recompilation limit is hit, **all nested function calls are also skipped**
(though Dynamo will try to use previously compiled bytecode for the nested functions).
Dynamo will also issue a warning naming the affected function and which limit was hit.
In the example below, each function call results in a recompile attempt.
When we hit the cache size limit (by default, 8), we stop attempting to recompile.
(Note that we set `dynamic=False` for demonstration purposes to force recompilation every time).
```{code-cell}
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    # recompile every time due to dynamic=False
    fn(torch.ones(i))
```
If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit.
If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit.
```{code-cell}
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
```
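Conversely, here is a minimal sketch of lowering the limit so that Dynamo gives up and falls back to eager execution sooner (the function name `hn` is ours):
```python
torch._dynamo.config.cache_size_limit = 2

@torch.compile(dynamic=False)
def hn(x):
    return x + 1

for i in range(1, 10):
    # Only the first 2 shapes are compiled; once the limit is hit,
    # new shapes run eagerly and a warning is issued.
    hn(torch.ones(i))
```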
## Graph Breaking to Reduce Recompilation Costs
If a large graph is recompiling and causing high compile time, you can intentionally introduce
a graph break to reduce recompilation costs, at the expense of a performance hit.
```{code-cell}
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
```
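A related option, sketched below under the assumption that running the large function eagerly is acceptable, is to exclude it from compilation entirely with `torch.compiler.disable`; the call then graph-breaks and runs eagerly, so the large function is never recompiled (names here are ours):
```python
# Sketch: hypothetical eager variant of very_large_function,
# excluded from compilation entirely.
@torch.compiler.disable
def very_large_function_eager(x):
    return x + 1

@torch.compile(dynamic=False)
def hn(x, c):
    y = very_large_function_eager(x)  # graph break: runs eagerly, never recompiled
    return y + c
```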