# Migrating from functorch to torch.func

torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the [pytorch/functorch](https://github.com/pytorch/functorch) repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstream, we've decided to migrate from being a top level package
(`functorch`) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
`import functorch` and ask that users migrate to the newest APIs, which we
will maintain going forward. `import functorch` will be kept around to maintain
backwards compatibility for a couple of releases.

## function transforms

The following APIs are a drop-in replacement for the following
[functorch APIs](https://pytorch.org/functorch/1.13/functorch.html).
They are fully backwards compatible.

| functorch API               | PyTorch API (as of PyTorch 2.0)                |
| --------------------------- | ---------------------------------------------- |
| functorch.vmap              | {func}`torch.vmap` or {func}`torch.func.vmap`  |
| functorch.grad              | {func}`torch.func.grad`                        |
| functorch.vjp               | {func}`torch.func.vjp`                         |
| functorch.jvp               | {func}`torch.func.jvp`                         |
| functorch.jacrev            | {func}`torch.func.jacrev`                      |
| functorch.jacfwd            | {func}`torch.func.jacfwd`                      |
| functorch.hessian           | {func}`torch.func.hessian`                     |
| functorch.functionalize     | {func}`torch.func.functionalize`               |

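For example, most call sites only need the module prefix changed. A minimal sketch,
assuming an arbitrary batched function:

```python
import torch

# A per-example dot product; vmap maps it over the batch dimension.
def dot(x, y):
    return (x * y).sum()

x, y = torch.randn(8, 3), torch.randn(8, 3)

# Previously: functorch.vmap(dot)(x, y)
# As of PyTorch 2.0:
batched = torch.vmap(dot)(x, y)
assert batched.shape == (8,)
```
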

Furthermore, if you are using torch.autograd.functional APIs, please try out
the {mod}`torch.func` equivalents instead. {mod}`torch.func` function
transforms are more composable and more performant in many cases.

| torch.autograd.functional API               | torch.func API (as of PyTorch 2.0)                      |
| ------------------------------------------- | ------------------------------------------------------- |
| {func}`torch.autograd.functional.vjp`       | {func}`torch.func.grad` or {func}`torch.func.vjp`       |
| {func}`torch.autograd.functional.jvp`       | {func}`torch.func.jvp`                                  |
| {func}`torch.autograd.functional.jacobian`  | {func}`torch.func.jacrev` or {func}`torch.func.jacfwd`  |
| {func}`torch.autograd.functional.hessian`   | {func}`torch.func.hessian`                              |

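As a concrete illustration of the Jacobian row above, here is a minimal sketch of the
migration; the function `f` is just an arbitrary example:

```python
import torch

def f(x):
    return x.sin()

x = torch.randn(5)

# torch.autograd.functional API
jac_old = torch.autograd.functional.jacobian(f, x)

# torch.func API (as of PyTorch 2.0)
jac_new = torch.func.jacrev(f)(x)

assert torch.allclose(jac_old, jac_new)
```
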
## NN module utilities

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

### functorch.make_functional

{func}`torch.func.functional_call` is the replacement for
[functorch.make_functional](https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional)
and
[functorch.make_functional_with_buffers](https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers).
However, it is not a drop-in replacement.

If you're in a hurry, you can use
[helper functions in this gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf)
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using {func}`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts parameters and inputs to the model as arguments.
{func}`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of parameters of a model using functorch
vs {mod}`torch.func`:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch

inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch

inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)
```

And here's an example of how to compute jacobians of model parameters:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch

inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call

inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params.
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
```

Note that, for the sake of memory consumption, it is important that you only carry
around a single copy of your parameters. `model.named_parameters()` does not copy
the parameters. If, during training, you update the model's parameters in-place,
then the `nn.Module` that is your model holds the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the `model`. In this case, you should change
`model` to not hold memory by converting it to the meta device via
`model.to('meta')`.

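For example, here is a minimal sketch of such an out-of-place training loop, assuming a
plain SGD update with an illustrative learning rate:

```python
import torch

model = torch.nn.Linear(3, 3)
inputs, targets = torch.randn(64, 3), torch.randn(64, 3)

# Keep the single copy of the parameters in this dictionary...
params = {k: v.detach().clone() for k, v in model.named_parameters()}
# ...and free the module's own copy by moving it to the meta device.
model.to('meta')

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

for _ in range(10):
    grads = torch.func.grad(compute_loss)(params, inputs, targets)
    # Out-of-place SGD step; lr=0.1 is an arbitrary choice for illustration.
    params = {k: params[k] - 0.1 * grads[k] for k in params}
```
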
### functorch.combine_state_for_ensemble

Please use {func}`torch.func.stack_module_state` instead of
[functorch.combine_state_for_ensemble](https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html).
{func}`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with {func}`torch.vmap` and {func}`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model:

```python
import torch

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch

fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```

## functorch.compile

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
{func}`torch.compile` instead.
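
For instance, compiling an existing model is typically a one-line change. A minimal
sketch (the model and input shapes here are just placeholders):

```python
import torch

model = torch.nn.Linear(3, 3)
compiled_model = torch.compile(model)

x = torch.randn(64, 3)
out = compiled_model(x)  # compilation happens on the first call
```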