torch.func interaction with torch.compile
==============================================

So you want to use a `torch.func` ("functorch") transform (like `vmap`, `grad`, `jacrev`, etc.) with `torch.compile`. Here's a guide to what works today, what doesn't, and how to work around it.

Applying a `torch.func` transform to a `torch.compile`'d function
-----------------------------------------------------------------

This doesn't work yet and is being tracked in `https://github.com/pytorch/pytorch/issues/100320`.

.. code:: python

    import torch

    @torch.compile
    def f(x):
        # Return a scalar: torch.func.grad requires a scalar-valued function.
        return torch.sin(x).sum()

    def g(x):
        # The transform is applied to an already-compiled function.
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

As a workaround, please put the `torch.compile` outside of the `torch.func` transform:

.. code:: python

    import torch

    def f(x):
        return torch.sin(x)

    @torch.compile
    def g(x):
        return torch.vmap(f)(x)

    x = torch.randn(2, 3)
    g(x)

Doesn't work (PT 2.0): calling a `torch.func` transform inside of a `torch.compile`'d function
------------------------------------------------------------------------------------------------

.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch.vmap(torch.sum)(x)

    x = torch.randn(2, 3)
    f(x)

This doesn't work yet. Please see the workaround in the next section.

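For reference, it can help to see what the transform in the example above computes when run eagerly, outside of `torch.compile`. The following is a minimal sketch; the variable names `out` and `expected` are illustrative and not part of the original example.

```python
import torch

x = torch.randn(2, 3)

# vmap maps torch.sum over the leading (batch) dimension of x,
# so each row of shape (3,) is reduced to a scalar.
out = torch.vmap(torch.sum)(x)

# The same result written as a single batched reduction.
expected = x.sum(dim=1)
```

Any compiled version of the function should agree with this eager result, which makes it a convenient check when trying the workaround.
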
Workaround: use `torch._dynamo.allow_in_graph`
----------------------------------------------

`allow_in_graph` is an escape hatch. If your code does not work with `torch.compile`, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like `jax.jit`), then use `allow_in_graph`.

By using `allow_in_graph` to annotate a function, you promise PyTorch a few things that we are unable to completely verify:

- Your function is pure. That is, all outputs only depend on the inputs and do not depend on any captured Tensors.
- Your function is functional. That is, it does not mutate any state. This may be relaxed; we actually support functions that appear to be functional from the outside: they may have in-place PyTorch operations, but may not mutate global state or inputs to the function.
- Your function does not raise data-dependent errors.

.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

    x = torch.randn(2, 3)
    f(x)

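The promises listed above can be illustrated with a few small functions. This is only a sketch: the function names (`looks_functional`, `not_pure`, `mutates_input`) are hypothetical, and the code runs eagerly without `torch.compile` or `allow_in_graph`.

```python
import torch

def looks_functional(x):
    # In-place ops on a tensor created inside the function are fine:
    # nothing visible to the caller is mutated.
    y = x.clone()
    y.add_(1)
    return y.sin()

captured = torch.ones(3)

def not_pure(x):
    # The output depends on a captured tensor that is not an input.
    return x + captured

def mutates_input(x):
    # Mutates its input in place, so it is not functional.
    x.add_(1)
    return x

x = torch.zeros(3)
before = x.clone()
out = looks_functional(x)
# The caller's tensor is unchanged after the call.
input_unchanged = torch.equal(x, before)
```

Of the three, only `looks_functional` matches the description of a function that appears functional from the outside; the other two break the purity and functional promises respectively.
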
A common pitfall is using `allow_in_graph` to annotate a function that invokes an `nn.Module`. This breaks the purity promise, because the outputs now depend on the parameters of the `nn.Module`, which are captured rather than passed as inputs. To actually get this to work, use `torch.func.functional_call` to extract the module state and pass it to the function explicitly.
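
As a rough sketch of that last step, the snippet below pulls the parameters and buffers out of a module and passes them as explicit arguments through `torch.func.functional_call`. The module (`torch.nn.Linear(3, 1)`) and the helper name `call_model` are assumptions made for illustration, and the example runs eagerly; it does not show `allow_in_graph` itself.

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 1)  # hypothetical module, for illustration only

# Extract the module state so it can be passed in as explicit inputs.
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

def call_model(params, buffers, x):
    # Runs the module's forward pass with the given parameters and buffers
    # in place of the state stored on the module.
    return functional_call(model, (params, buffers), (x,))

x = torch.randn(4, 3)
out = call_model(params, buffers, x)
```

Because `call_model` receives the parameters and buffers as arguments, its tensor outputs depend only on its inputs, which is the shape of function the purity promise asks for. Whether annotating such a function with `allow_in_graph` works in a given PyTorch version is not verified here.
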