---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code

torch._logging.set_logs(recompiles=True)
```

# Dealing with Recompilations

Recompilations are necessary for `torch.compile` soundness, but can result in significantly increased compile time.
Thus, minimizing recompilations while preserving soundness is essential for reducing compile time.

You can view recompilations and their reasons using `tlparse` or `TORCH_LOGS=recompiles`.
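
For example, recompile logging can be enabled programmatically rather than through the environment variable; the call below is equivalent to running your script with `TORCH_LOGS="recompiles"` set:

```{code-cell}
# Enable recompilation logging; equivalent to TORCH_LOGS="recompiles".
torch._logging.set_logs(recompiles=True)
```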

## Is Dynamic Shapes Enabled?

In the example below, we recompile due to mismatched shapes:

```{code-cell}
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
```

Make sure that the `dynamic` option of `torch.compile` is not set to `False`.
The default, `dynamic=None`, only attempts dynamic shapes after the first compilation, once a size change triggers a recompilation.
You can set `dynamic=True` to compile as dynamically as possible up front:

```{code-cell}
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
```

For more information on dynamic shapes, including how to deal with errors and recompilations caused by
dynamic shapes, see [the dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).
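
One knob from the dynamic shapes toolbox, sketched here only as an illustration, is `torch._dynamo.mark_dynamic`, which marks a specific dimension of an input as dynamic so that the very first compilation already treats that dimension dynamically (the function name `hn` is arbitrary):

```{code-cell}
# Sketch: mark dim 0 of the input as dynamic before the first compilation.
@torch.compile
def hn(x):
    return x + 1

x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)  # compile with a symbolic size for dim 0
hn(x)
hn(torch.ones(4))  # should not recompile for the size change in dim 0
```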

## Wrapping Constants with Tensors

By default, `int` / `float` variables are treated as constants and are guarded on their exact value.
In the example below, we get a recompilation for each function call because the value of `c` changes each time:

```{code-cell}
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
```

In particular, for LR schedulers, initializing with a constant can lead to recompilations:

```{code-cell}
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
```

In both examples, we can wrap `float` variables in tensors in order to prevent recompilations.

```{code-cell}
:tags: [remove-cell]
torch._dynamo.reset()
```

```{code-cell}
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
```

(programming_model.recompilation.changing_cache_size_limit)=
## Changing the Cache Size Limit

There is a limit to how many times a function can be recompiled,
determined by `torch._dynamo.config.cache_size_limit` and `torch._dynamo.config.accumulated_cache_size_limit`
(the exact difference between these two values is detailed in [`torch/_dynamo/cache_size.py`](https://github.com/pytorch/pytorch/blob/4ce6e6ec8890a3f6ee604c9efb3ff153825ce575/torch/_dynamo/cache_size.py#L14)).
If the Dynamo cache limit is hit, then all future compilation attempts **will result in the function being skipped (run eagerly)**.
Dynamo will still attempt to use previously compiled bytecode for future function calls, if the guards pass.
Note that in the case of a recompilation limit hit, **all nested function calls WILL be skipped**
(Dynamo will try to use previously compiled bytecode for the nested functions).
Dynamo will also issue a warning containing the affected function and which limit was hit.
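
You can read the current values of both limits directly from the config (the defaults may differ between PyTorch versions):

```{code-cell}
# Current recompilation limits; defaults may vary across PyTorch versions.
print(torch._dynamo.config.cache_size_limit)
print(torch._dynamo.config.accumulated_cache_size_limit)
```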

In the example below, each function call results in a recompile attempt.
When we hit the cache size limit (8 by default), we stop attempting to recompile and the function falls back to running eagerly.
(Note that we set `dynamic=False` for demonstration purposes in order to force a recompilation on every call.)

```{code-cell}
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    # recompile every time due to dynamic=False
    fn(torch.ones(i))
```

If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit.
If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit so that the function falls back to eager execution sooner.

```{code-cell}
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
```
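
Conversely, here is a minimal sketch of lowering the limit so that a frequently recompiling function is skipped (run eagerly) sooner; the function name `kn` and the chosen limit are illustrative only:

```{code-cell}
# Lower the limit so recompile-heavy functions hit the cap and run eagerly sooner.
torch._dynamo.config.cache_size_limit = 2

@torch.compile(dynamic=False)
def kn(x):
    return x + 1

for i in range(1, 10):
    # only the first 2 shapes should be compiled; later calls run eagerly
    kn(torch.ones(i))

# restore the limit set above so later examples behave as described
torch._dynamo.config.cache_size_limit = 16
```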

## Graph Breaking to Reduce Recompilation Costs

If a large graph is recompiling and causing high compile time, you can intentionally introduce
a graph break in order to reduce recompilation costs, at the expense of introducing a performance hit:

```{code-cell}
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
```