# torch.func

```{eval-rst}
.. currentmodule:: torch.func
```

torch.func, previously known as "functorch", provides
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.

```{note}
This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
This means that the features generally work (unless otherwise documented)
and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
may change based on user feedback and we don't have full coverage of PyTorch operations.

If you have suggestions on the API or use cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
```

## What are composable function transforms?

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- {mod}`torch.func` has auto-differentiation transforms (`grad(f)` returns a function that
  computes the gradient of `f`), a vectorization/batching transform (`vmap(f)`
  returns a function that computes `f` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing `vmap(grad(f))` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today.

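
As a minimal sketch of the three points above, the following applies `grad` and `vmap` on their own and then composes them into per-sample gradients for a toy linear model. The functions `f` and `loss`, and the tensors passed to them, are hypothetical and chosen only for illustration:

```python
import torch
from torch.func import grad, vmap


def f(x):
    # Hypothetical scalar-valued function of a 1-D tensor: sum of sin(x_i).
    return torch.sin(x).sum()


x = torch.randn(3)

# grad(f) is a new function that computes the gradient of f at its input.
df = grad(f)
print(torch.allclose(df(x), torch.cos(x)))

# vmap(f) is a new function that applies f to each row of a batch.
batch = torch.randn(4, 3)
print(vmap(f)(batch).shape)  # torch.Size([4])


def loss(weight, sample, target):
    # Hypothetical squared-error loss of a linear model on ONE sample.
    prediction = sample @ weight
    return (prediction - target) ** 2


weight = torch.randn(3)
samples = torch.randn(4, 3)
targets = torch.randn(4)

# grad(loss) differentiates with respect to the first argument (weight).
# vmap maps that gradient function over the batch dimension of samples and
# targets while sharing the same weight for every sample.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, samples, targets)
print(per_sample_grads.shape)  # torch.Size([4, 3])
```

Here `in_dims=(None, 0, 0)` tells `vmap` not to map over `weight` and to map over the first dimension of `samples` and `targets`, so the result holds one gradient per sample.
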
## Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing {func}`vmap`, {func}`grad`, and {func}`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).

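
For instance, Jacobians, Hessians, and their batched variants might be computed as in the sketch below. It rests on some assumptions: `jacrev` and `hessian` are {mod}`torch.func` transforms that this page does not otherwise mention, and `f` and `g` are hypothetical toy functions picked so the results are easy to check by hand.

```python
import torch
from torch.func import hessian, jacrev, vmap


def f(x):
    # Hypothetical vector-valued function: elementwise sin, shape (3,) -> (3,).
    return torch.sin(x)


def g(x):
    # Hypothetical scalar-valued function: sum of cubes.
    return (x ** 3).sum()


x = torch.randn(3)

# Jacobian of f at x; for elementwise sin this is diag(cos(x)).
jacobian = jacrev(f)(x)
print(jacobian.shape)  # torch.Size([3, 3])

# Hessian of g at x; for a sum of cubes this is diag(6 * x).
hess = hessian(g)(x)
print(hess.shape)  # torch.Size([3, 3])

# Batched Jacobian: compose vmap with jacrev to get one Jacobian per row.
batch = torch.randn(5, 3)
batched_jacobian = vmap(jacrev(f))(batch)
print(batched_jacobian.shape)  # torch.Size([5, 3, 3])
```

The batched case needs no dedicated API: it is the single-input Jacobian transform wrapped in `vmap`.
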
## Read More

```{eval-rst}
.. toctree::
   :maxdepth: 2

   func.whirlwind_tour
   func.api
   func.ux_limitations
   func.migrating
```