Frequently Asked Questions
==========================

**Author**: `Mark Saroufim <https://github.com/msaroufim>`_

Does ``torch.compile`` support training?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.compile`` supports training, using AOTAutograd to capture backwards:

1. The ``.forward()`` graph and ``optimizer.step()`` are captured by
   TorchDynamo's python ``evalframe`` frontend.
2. For each segment of ``.forward()`` that TorchDynamo captures, it uses
   AOTAutograd to generate a backward graph segment.
3. Each pair of forward and backward graphs is (optionally) min-cut
   partitioned to save the minimal state between forward and backward.
4. The forward and backward pairs are wrapped in ``autograd.Function`` modules.
5. User code calling ``.backward()`` still triggers eager's autograd engine,
   which runs each *compiled backward* graph as if it were one op, also running
   any non-compiled eager ops' ``.backward()`` functions.

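As a minimal sketch (the toy model, optimizer, and data below are illustrative), training only requires compiling the model and keeping the usual loop:

.. code-block:: python

    import torch

    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Backwards for the compiled regions are captured via AOTAutograd.
    compiled_model = torch.compile(model)

    for _ in range(3):
        optimizer.zero_grad()
        loss = compiled_model(torch.randn(4, 10)).sum()
        loss.backward()   # eager autograd runs the compiled backward graphs
        optimizer.step()
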
Do you support Distributed code?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.compile`` supports ``DistributedDataParallel`` (DDP).
Support for other distributed training libraries is being considered.

The main reason distributed code is challenging with dynamo is
that AOTAutograd unrolls both the forward and backward pass and
provides 2 graphs for backends to optimize. This is a problem for
distributed code because we would ideally like to overlap communication
operations with computations. Eager PyTorch accomplishes this in
different ways for DDP/FSDP, using autograd hooks, module hooks, and
modifications/mutations of module states. In a naive application of
dynamo, hooks that should run directly after an operation during
backwards may be delayed until after the entire compiled region of
backwards ops, due to how AOTAutograd compiled functions interact with
dispatcher hooks.

The basic strategy for optimizing DDP with Dynamo is outlined in
`distributed.py <https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/optimizations/distributed.py>`__
where the main idea is to graph break on `DDP bucket
boundaries <https://pytorch.org/docs/stable/notes/ddp.html#internal-design>`__.

When each node in DDP needs to synchronize its weights with the other
nodes, it organizes its gradients and parameters into buckets, which
reduces communication times and allows a node to broadcast a fraction of
its gradients to other waiting nodes.

Graph breaks in distributed code mean you can expect dynamo and its
backends to optimize the compute overhead of a distributed program but
not its communication overhead. Graph breaks may interfere with
compilation speedups if the reduced graph size robs the compiler of
fusion opportunities. However, there are diminishing returns with
increasing graph size since most of the current compute optimizations
are local fusions, so in practice this approach may be sufficient.

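As a minimal sketch of the usual composition (the toy model is illustrative, and the snippet assumes the default process group has already been initialized, for example via ``torchrun``):

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes torch.distributed is already initialized and a GPU is available.
    model = torch.nn.Linear(10, 10).cuda()
    ddp_model = DDP(model)

    # Dynamo introduces graph breaks at DDP bucket boundaries so gradient
    # all-reduce can still overlap with the compiled backward.
    compiled_ddp_model = torch.compile(ddp_model)
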
Do I still need to export whole graphs?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For the vast majority of models you probably don't, and you can use
``torch.compile()`` as is, but there are a few situations where
full graphs are necessary, and you can ensure a full graph by simply
running ``torch.compile(..., fullgraph=True)``. These situations include:

* Large scale training runs, such as $250K+ that require pipeline parallelism
  and other advanced sharding strategies.

* Inference optimizers like `TensorRT <https://github.com/pytorch/TensorRT>`__
  or `AITemplate <https://github.com/facebookincubator/AITemplate>`__ that
  rely on fusing much more aggressively than training optimizers.

* Mobile training or inference.

Future work will include tracing communication operations into graphs,
coordinating these operations with compute optimizations, and optimizing
the communication operations.

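A minimal sketch with a toy function; with ``fullgraph=True`` any graph break raises an error instead of silently falling back to Python:

.. code-block:: python

    import torch

    def fn(x):
        return torch.sin(x) + torch.cos(x)

    # Raises an error if the whole function cannot be captured as one graph.
    compiled_fn = torch.compile(fn, fullgraph=True)
    compiled_fn(torch.randn(8))
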
Why is my code crashing?
~~~~~~~~~~~~~~~~~~~~~~~~

If your code ran just fine without ``torch.compile`` and started to
crash once it was enabled, then the most important first step is figuring
out which part of the stack your failure occurred in. To troubleshoot that,
follow the steps below and only try the next step if the previous one
succeeded.

1. ``torch.compile(..., backend="eager")`` which only runs TorchDynamo
   forward graph capture and then runs the captured graph with PyTorch.
   If this fails then there's an issue with TorchDynamo.

2. ``torch.compile(..., backend="aot_eager")``
   which runs TorchDynamo to capture a forward graph, and then AOTAutograd
   to trace the backward graph without any additional backend compiler
   steps. PyTorch eager will then be used to run the forward and backward
   graphs. If this fails then there's an issue with AOTAutograd.

3. ``torch.compile(..., backend="inductor")`` which runs TorchDynamo to capture a
   forward graph, and then AOTAutograd to trace the backward graph with the
   TorchInductor compiler. If this fails then there's an issue with TorchInductor.

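A minimal sketch of that bisection process (``my_model`` and the input are illustrative stand-ins for your own code):

.. code-block:: python

    import torch
    import torch._dynamo

    my_model = torch.nn.Linear(10, 10)
    x = torch.randn(4, 10)

    # Try the backends in order of increasing stack depth; the first one that
    # fails points at the layer (Dynamo, AOTAutograd, or Inductor) to blame.
    for backend in ("eager", "aot_eager", "inductor"):
        torch._dynamo.reset()  # clear previously compiled graphs between runs
        torch.compile(my_model, backend=backend)(x).sum().backward()
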
Why is compilation slow?
~~~~~~~~~~~~~~~~~~~~~~~~

* **Dynamo Compilation**: TorchDynamo has a builtin stats function for
  collecting and displaying the time spent in each compilation phase.
  These stats can be accessed by calling ``torch._dynamo.utils.compile_times()``
  after executing the code compiled with ``torch._dynamo``. By default, this
  returns a string representation of the compile times spent in each
  TorchDynamo function by name (a short sketch appears at the end of this section).

* **Inductor Compilation**: TorchInductor has a builtin stats and trace function
  for displaying time spent in each compilation phase, output code, output
  graph visualization and IR dump. Run your script with
  ``env TORCH_COMPILE_DEBUG=1 python repro.py``.
  This is a debugging tool designed to make it easier to debug and understand the
  internals of TorchInductor, with an output that will look something like
  `this <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__.
  Each file in that debug trace can be enabled/disabled via
  ``torch._inductor.config.trace.*``. The profile and the diagram are both
  disabled by default since they are expensive to generate. See the
  `example debug directory
  output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
  for more examples.

* **Excessive Recompilation**:
  When TorchDynamo compiles a function (or part of one), it makes certain
  assumptions about locals and globals in order to allow compiler
  optimizations, and expresses these assumptions as guards that check
  particular values at runtime. If any of these guards fail, Dynamo will
  recompile that function (or part) up to
  ``torch._dynamo.config.cache_size_limit`` times. If your program is
  hitting the cache limit, you will first need to determine which guard is
  failing and what part of your program is triggering it. The
  `recompilation profiler <#recompilation-profiler>`__ automates the
  process of setting TorchDynamo's cache limit to 1 and running your
  program under an observation-only 'compiler' that records the causes of
  any guard failures. You should be sure to run your program for at least
  as long (as many iterations) as you were running when you ran into
  trouble, and the profiler will accumulate statistics over this duration.

.. code-block:: python

    import torch
    from torch._dynamo.utils import CompileProfiler

    def my_model():
        ...

    with CompileProfiler() as prof:
        profiler_model = torch.compile(my_model, backend=prof)
        profiler_model()
        print(prof.report())

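A minimal sketch of the Dynamo timing stats mentioned in the first bullet above (the toy function is illustrative):

.. code-block:: python

    import torch
    import torch._dynamo.utils

    def fn(x):
        return torch.sin(x) + torch.cos(x)

    torch.compile(fn)(torch.randn(8))

    # Print per-function compile times collected during the run above.
    print(torch._dynamo.utils.compile_times())
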
Why are you recompiling in production?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you may not want unexpected compiles after a program has
warmed up, for example, if you are serving production traffic in a
latency critical application. For this, TorchDynamo provides an
alternate mode where prior compiled graphs are used, but no new ones are
generated:

.. code-block:: python

    import torch._dynamo as dynamo

    frozen_toy_example = dynamo.run(toy_example)
    frozen_toy_example(torch.randn(10), torch.randn(10))

How are you speeding up my code?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are 3 major ways to accelerate PyTorch code:

1. Kernel fusion, via vertical fusions which fuse sequential operations to avoid
   excessive reads/writes (for example, fusing 2 subsequent cosines means you can
   do 1 read and 1 write instead of 2 reads and 2 writes; see the short sketch at
   the end of this section), and horizontal fusions (the simplest example being
   batching, where a single matrix is multiplied with a batch of examples, but the
   more general scenario is a grouped GEMM, where a group of matrix
   multiplications are scheduled together).

2. Out of order execution: a general optimization for compilers; by looking ahead
   at the exact data dependencies within a graph we can decide on the most
   opportune time to execute a node and which buffers can be reused.

3. Automatic work placement: similar to the out of order execution point,
   but by matching nodes of a graph to resources like physical hardware or
   memory we can design an appropriate schedule.

The above are general principles for accelerating PyTorch code, but
different backends will each make different tradeoffs on what to
optimize. For example, Inductor first takes care of fusing whatever it
can and only then generates `Triton <https://openai.com/blog/triton/>`__
kernels.

Triton in addition offers speedups because of automatic memory
coalescing, memory management and scheduling within each Streaming
Multiprocessor, and has been designed to handle tiled computations.

However, regardless of the backend you use, it's best to take a
benchmark-and-see approach, so try out the PyTorch profiler, visually inspect the
generated kernels, and try to see what's going on for yourself.

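As an illustration of the vertical fusion point above, here is a minimal sketch (the toy function is illustrative; with the Inductor backend the two pointwise ops are expected to end up in a single generated kernel):

.. code-block:: python

    import torch

    def two_cosines(x):
        # Two sequential pointwise ops: a fusing compiler can read x once,
        # apply both cosines, and write the result once.
        return torch.cos(torch.cos(x))

    compiled = torch.compile(two_cosines)
    compiled(torch.randn(1024))
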
Why am I not seeing speedups?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Graph Breaks
------------

The main reason you won't see the speedups you'd like from dynamo
is excessive graph breaks. So what's a graph break?

Given a program like:

.. code-block:: python

    def some_fun(x):
        ...

    torch.compile(some_fun)(x)
    ...

TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun()`` into a single FX graph, but it may fail to capture
everything into one graph.

Some graph break causes are insurmountable to TorchDynamo: for example, a call
into a C extension other than PyTorch is invisible to TorchDynamo, and
could do arbitrary things without TorchDynamo being able to introduce the
guards necessary to ensure that the compiled program would be safe to reuse.

To maximize performance, it's important to have as few graph breaks
as possible.

Identifying the cause of a graph break
--------------------------------------

To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        print("woo")
        if b.sum() < 0:
            b = b * -1
        return x * b

    explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
    print(explanation)
    """
    Dynamo produced 3 graphs, with 2 graph break and 6 ops.
    Break reasons:

    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
       File "t2.py", line 16, in toy_example
          print("woo")

    2. generic_jump
       File "t2.py", line 17, in toy_example
          if b.sum() < 0:
    """

To throw an error on the first graph break encountered, you can disable
the python fallback by using ``fullgraph=True``; this should be
familiar if you've worked with export based compilers.

.. code-block:: python

    def toy_example(a, b):
        ...

    torch.compile(toy_example, fullgraph=True, backend=<compiler>)

Why didn't my code recompile when I changed it?
-----------------------------------------------

If you enabled dynamic shapes by setting
``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py``, then your code
won't recompile on shape changes. We've added support for dynamic shapes,
which avoids recompilations in the case when shapes vary by less than a
factor of 2. This is especially useful in scenarios like varying image
sizes in CV or variable sequence length in NLP. In inference scenarios
it's often not possible to know what a batch size will be beforehand,
because you take what you can get from different client apps.

In general, TorchDynamo tries very hard not to recompile things
unnecessarily, so if, for example, TorchDynamo finds 3 graphs and your
change only modified one graph, then only that graph will recompile. So
another tip to avoid potentially slow compilation times is to warm up a
model by compiling it once, after which subsequent compilations will be
much faster. Cold start compile time is still a metric we track visibly.

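Dynamic shapes can also be requested per call site with the ``dynamic=True`` flag of ``torch.compile``; a minimal sketch (the toy function is illustrative):

.. code-block:: python

    import torch

    def fn(x):
        return x * 2

    # Ask the compiler to trace with dynamic shapes so that varying input
    # sizes reuse the same compiled graph instead of recompiling.
    compiled_fn = torch.compile(fn, dynamic=True)
    compiled_fn(torch.randn(8))
    compiled_fn(torch.randn(16))  # no recompilation expected for the new shape
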
Why am I getting incorrect results?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``. It operates with a similar git bisect
model, and a full repro might be something like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers will codegen code, whether it's
Triton code or the C++ backend; the numerics from those downstream
compilers can be different in subtle ways, yet have a dramatic impact on
your training stability. So the accuracy debugger is very useful for us
to detect bugs in our codegen or with a backend compiler.

If you'd like to ensure that random number generation is the same across both torch
and triton, then you can enable ``torch._inductor.config.fallback_random = True``.

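A minimal sketch of comparing eager and compiled numerics with matched random number generation (the toy function is illustrative, and exact agreement is the expectation rather than a guarantee):

.. code-block:: python

    import torch
    import torch._inductor.config

    # Fall back to eager random number generation inside compiled code so the
    # eager and compiled runs draw the same random numbers.
    torch._inductor.config.fallback_random = True

    def fn(x):
        return torch.nn.functional.dropout(x, p=0.5, training=True)

    x = torch.randn(8)
    torch.manual_seed(0)
    eager_out = fn(x)
    torch.manual_seed(0)
    compiled_out = torch.compile(fn)(x)
    print(torch.allclose(eager_out, compiled_out))
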
Why am I getting OOMs?
~~~~~~~~~~~~~~~~~~~~~~

Dynamo is still an alpha product, so there are a few sources of OOMs. If
you're seeing an OOM, try disabling the following configurations in this
order and then open an issue on GitHub so we can solve the root problem.

1. If you're using dynamic shapes, try disabling them; we've disabled
   them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``.
2. CUDA graphs with Triton are enabled by default in inductor but removing
   them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

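A minimal sketch of the second knob (the toy model is illustrative; the config is set before compiling):

.. code-block:: python

    import torch
    import torch._inductor.config

    # Disable CUDA graphs in Inductor's Triton backend; this can reduce
    # memory pressure at the cost of some launch overhead.
    torch._inductor.config.triton.cudagraphs = False

    model = torch.nn.Linear(10, 10)
    compiled_model = torch.compile(model)
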
``torch.func`` does not work with ``torch.compile``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Composing ``torch.func`` transforms with ``torch.compile`` does not work yet.
The two common cases, and a workaround for each, are described in the
subsections below.

Applying a ``torch.func`` transform to a function handled with ``torch.compile``
---------------------------------------------------------------------------------

For example, you have the following code:

.. code-block:: python

    import torch

    @torch.compile
    def f(x):
        return torch.sin(x).sum()

    def g(x):
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

This code will not work. There is an `issue <https://github.com/pytorch/pytorch/issues/100320>`__
that you can track for this.
As a workaround, please put the ``torch.compile`` outside of the ``torch.func`` transform:

.. code-block:: python

    import torch

    def f(x):
        return torch.sin(x)

    @torch.compile
    def g(x):
        return torch.vmap(f)(x)

    x = torch.randn(2, 3)
    g(x)

Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
--------------------------------------------------------------------------------------

.. code-block:: python

    import torch

    @torch.compile
    def f(x):
        return torch.vmap(torch.sum)(x)

    x = torch.randn(2, 3)
    f(x)

This doesn't work yet. As a workaround, use ``torch._dynamo.allow_in_graph``.

``allow_in_graph`` is an escape hatch. If your code does not work with
``torch.compile``, which introspects Python bytecode, but you believe it
will work via a symbolic tracing approach (like ``jax.jit``), then use
``allow_in_graph``.

When using ``allow_in_graph`` to annotate a function, you must make sure
your code meets the following requirements:

- All outputs in your function only depend on the inputs and
  do not depend on any captured Tensors.
- Your function is functional. That is, it does not mutate any state. This may
  be relaxed; we actually support functions that appear to be functional from
  the outside: they may have in-place PyTorch operations, but may not mutate
  global state or inputs to the function.
- Your function does not raise data-dependent errors.

.. code-block:: python

    import torch

    @torch.compile
    def f(x):
        return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

    x = torch.randn(2, 3)
    f(x)

A common pitfall is using ``allow_in_graph`` to annotate a function that
invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might need to exclude small parts of your code from the
``torch.compile`` compilations. This section provides some of the answers and
you can find more information in :ref:`torchdynamo_fine_grain_tracing`.

How do I graph break on a function?
-----------------------------------

A graph break on a function is usually not specific enough to express what you want
PyTorch to do. You need to be more specific about your use case. Some of the
most common use cases you might want to consider are:

* If you want to disable compilation on this function frame and the recursively
  invoked frames, use ``torch._dynamo.disable``.

* If you want a particular operator, such as ``fbgemm``, to use the eager mode,
  use ``torch._dynamo.disallow_in_graph``.

Some of the uncommon use cases include:

* If you want to disable TorchDynamo on the function frame but enable it back
  on the recursively invoked frames, use ``torch._dynamo.disable(recursive=False)``.

* If you want to prevent inlining of a function frame, use ``torch._dynamo.graph_break``
  at the beginning of the function you want to prevent inlining.

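A minimal sketch of the two common cases above (the helper function and the choice of ``torch.sub`` are illustrative):

.. code-block:: python

    import torch
    import torch._dynamo

    @torch._dynamo.disable
    def debug_helper(x):
        # This frame and anything it recursively calls run eagerly.
        print("intermediate shape:", x.shape)
        return x

    # Force the chosen operator out of extracted graphs so it runs in eager mode.
    torch._dynamo.disallow_in_graph(torch.sub)

    @torch.compile
    def fn(x):
        x = torch.sin(x)
        x = debug_helper(x)
        return torch.sub(x, 1)

    fn(torch.randn(4))
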
What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo.disallow_in_graph``?
--------------------------------------------------------------------------------------------------

``disallow_in_graph`` works at the level of operators, or more specifically,
the operators that you see in the TorchDynamo extracted graphs.

``disable`` works at the function frame level and decides if TorchDynamo
should look into the function frame or not.

What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo.skip``?
-------------------------------------------------------------------------------------

.. note::
   ``torch._dynamo.skip`` is deprecated.

You most likely need ``torch._dynamo.disable``. But in an unlikely scenario, you
might need even finer control. Suppose you want to disable the tracing on just
the ``a_fn`` function, but want to continue the tracing back in ``aa_fn`` and
``ab_fn``. The image below demonstrates this use case:

.. figure:: _static/img/fine_grained_apis/call_stack_diagram.png
   :alt: diagram of torch.compile + disable(a_fn, recursive=False)

In this case, you can use ``torch._dynamo.disable(recursive=False)``.
In previous versions, this functionality was provided by ``torch._dynamo.skip``.
This is now supported by the ``recursive`` flag inside ``torch._dynamo.disable``.