# AOT Autograd - Introduction to an experimental compilation feature in Functorch

The primary compilation API we provide is called AOT Autograd. AOT Autograd is an experimental feature that captures forward and backward graphs ahead of time and allows easy integration with compilers. This creates an easy-to-hack, Python-based development environment for speeding up the training of PyTorch models. AOT Autograd currently lives in the `functorch.compile` namespace.

AOT Autograd is experimental and the APIs are likely to change. We are looking for feedback. If you are interested in using AOT Autograd and need help or have suggestions, please feel free to open an issue; we will be happy to help.

Here are some examples of how to use it:

```python
from functorch.compile import aot_function, aot_module, draw_graph
import torch.fx as fx
import torch

# This simply prints out the FX graph of the forwards and the backwards
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
nf(torch.randn(3, requires_grad=True))

# You can do whatever you want before and after, and you can still backprop through the function.
inp = torch.randn(3, requires_grad=True)
inp = inp.cos()
out = nf(inp)
out.sin().sum().backward()

def f(x):
    return x.cos().cos()

# This draws out the forwards and the backwards graphs as svg files
def graph_drawer(name):
    def f(fx_g: fx.GraphModule, inps):
        draw_graph(fx_g, name)
        return fx_g
    return f

aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))

# We also have a convenience API for applying AOT Autograd to modules
from torchvision.models import resnet18
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1, 3, 200, 200))
# output elided since it's very long

# In practice, you might want to speed it up by lowering it to TorchScript.
# You might also lower it to TorchScript before passing it to another compiler.

def f(x):
    return x.cos().cos()

def ts_compiler(fx_g: fx.GraphModule, inps):
    f = torch.jit.script(fx_g)
    print(f.graph)
    f = torch.jit.freeze(f.eval())  # Note: this eval() works fine *even* though we're using this for training
    return f

aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
```
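
Since `aot_module` returns a regular `nn.Module`, the compiled model drops into an ordinary training loop. The sketch below is not part of the original example: the no-op `nop` compiler, the SGD settings, and the dummy batch are illustrative assumptions, as is the detail that gradients accumulate on the original module's parameters.

```python
# Sketch: one training step through an aot_module-compiled model.
# The nop compiler, optimizer settings, and dummy data are assumptions
# for illustration, not part of the original README.
import torch
import torch.nn.functional as F
from functorch.compile import aot_module
from torchvision.models import resnet18

def nop(fx_g, inps):
    # A compiler callback only needs to return something callable;
    # returning the FX graph unchanged keeps eager semantics.
    return fx_g

mod = resnet18()
compiled_mod = aot_module(mod, nop, nop)
opt = torch.optim.SGD(mod.parameters(), lr=0.01)

x = torch.randn(1, 3, 200, 200)
target = torch.randint(0, 1000, (1,))

out = compiled_mod(x)                # runs the captured forward graph
loss = F.cross_entropy(out, target)
loss.backward()                      # runs the captured backward graph
opt.step()
opt.zero_grad()
```

Swapping `nop` for a real backend (such as the `ts_compiler` above) changes how the captured graphs execute without touching the training loop.
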

## Documentation

* AOT Autograd [documentation](https://pytorch.org/functorch/nightly/)
* Min-cut [recomputation](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) with AOT Autograd (see the sketch below).
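
As a rough illustration of how the min-cut partitioner plugs in, here is a minimal sketch. It assumes `min_cut_rematerialization_partition` is importable from `functorch.compile` and that `aot_function` accepts a `partition_fn` argument; the no-op `nop` compiler and the toy function are illustrative only.

```python
# Sketch: enable min-cut recomputation via the partitioner (assumed API).
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def nop(fx_g, inps):
    return fx_g  # run the partitioned graphs as-is

def f(x):
    return x.cos().cos()

recomp_f = aot_function(
    f,
    fw_compiler=nop,
    bw_compiler=nop,
    partition_fn=min_cut_rematerialization_partition,  # chooses what to save vs. recompute
)
recomp_f(torch.randn(3, requires_grad=True)).sum().backward()
```
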

## Tutorials

You can use this [tutorial](https://pytorch.org/functorch/nightly/notebooks/aot_autograd_optimizations.html) to play with AOT Autograd.