mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR sets up torch.func and populates it with the following APIs:

- grad
- grad_and_value
- vjp
- jvp
- jacrev
- jacfwd
- hessian
- functionalize
- vmap

It also renames all instances of `functorch` in the APIs for those docs to `torch.func`.

We rewrite the `__module__` fields on some of the above APIs so that the APIs fit PyTorch's public API definition.

- For an API to be public, it must have a `__module__` that points to a public PyTorch submodule. However, `torch._functorch.eager_transforms` is not public due to the leading underscore.
- The solution is to rewrite `__module__` to point to where the API is exposed (torch.func). This is what both NumPy and JAX do for their APIs.
- h/t pmeier in https://github.com/pytorch/pytorch/issues/90284#issuecomment-1348595246 for the idea and code.
- The helper function, `exposed_in`, is confined to torch._functorch/utils for now because we're not completely sure whether this should be the long-term solution.

Implication for functorch.* APIs:

- functorch.grad is the same object as torch.func.grad.
- This means that the functorch.grad docstring is actually the torch.func.grad docstring and will refer to torch.func instead of functorch.
- This isn't really a problem, since the plan on record is to deprecate functorch in favor of torch.func. We can fix these if we really want, but I'm not sure a solution is worth maintaining.

Test Plan:

- View the docs preview.

Future:

- vmap should actually just be torch.vmap. This requires an extra step where I need to test internal call sites, so I'm separating it into a different PR.
- make_fx should be in torch.func to be consistent with `import functorch`. This one is more of a headache to deal with w.r.t. the public API, so I'm going to deal with it separately.
- Beef up func.rst with everything else currently on the functorch documentation website. func.rst is currently just an empty shell.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91016
Approved by: https://github.com/samdow
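The `__module__` rewrite described in the commit message can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the PyTorch source: `exposed_in` here is a hypothetical re-implementation of the helper the PR names, `toy_grad` is a placeholder standing in for a transform defined in a private module, and `is_public_module` is a simplified stand-in for PyTorch's public-API check that only looks for a leading underscore in each module path component.

.. code-block:: python

    # Hypothetical sketch of the idea behind `exposed_in`: a decorator that
    # rewrites fn.__module__ to the public module the function is exposed in.
    def exposed_in(module):
        def wrapper(fn):
            fn.__module__ = module
            return fn  # same object, only the attribute changes
        return wrapper

    # Simplified stand-in for a "public API" rule (assumption): a module path
    # is public if none of its dotted components starts with an underscore.
    def is_public_module(name):
        return not any(part.startswith("_") for part in name.split("."))

    # Placeholder for a transform that would live in a private module such as
    # torch._functorch.eager_transforms; it is NOT the real torch.func.grad.
    @exposed_in("torch.func")
    def toy_grad(func):
        """Toy transform used only to demonstrate the __module__ rewrite."""
        return func

    # Re-exporting under another name binds the same object, so the docstring
    # and __module__ are shared (as with functorch.grad and torch.func.grad).
    legacy_alias = toy_grad

    print(is_public_module("torch._functorch.eager_transforms"))  # False
    print(toy_grad.__module__)                                   # torch.func
    print(is_public_module(toy_grad.__module__))                 # True
    print(legacy_alias is toy_grad)                              # True

Because the decorator mutates and returns the original function object, any alias of it sees the rewritten `__module__` and the same docstring, which is exactly the functorch.grad / torch.func.grad situation described above.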
23 lines
325 B
ReStructuredText
torch.func API Reference
========================

.. currentmodule:: torch.func

.. automodule:: torch.func

Function Transforms
-------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    jacrev
    jacfwd
    hessian
    functionalize

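
A short usage sketch of several of the transforms listed above, assuming PyTorch 2.0 or later (where `torch.func` is available). The function `f`, the seed, and the tensor sizes are arbitrary choices for illustration; `f(x) = sum(sin(x))` is used because its derivatives have closed forms (`cos(x)` for the gradient, `diag(-sin(x))` for the Hessian) that the results can be checked against.

.. code-block:: python

    import torch
    from torch.func import grad, grad_and_value, hessian, jacfwd, jacrev, jvp, vjp, vmap

    torch.manual_seed(0)

    # Scalar-valued example function: f(x) = sum(sin(x)), so df/dx = cos(x).
    def f(x):
        return torch.sin(x).sum()

    x = torch.randn(5)

    # grad(f) is a new function that returns the gradient of f at its input.
    g = grad(f)(x)

    # grad_and_value returns (gradient, f(x)) from a single call.
    g2, value = grad_and_value(f)(x)

    # vmap vectorizes a per-example function over a leading batch dimension,
    # giving one gradient per row of `batch` (result shape: (3, 5)).
    batch = torch.randn(3, 5)
    per_example_grads = vmap(grad(f))(batch)

    # vjp returns the output and a function computing vector-Jacobian products
    # (one result per primal, as a tuple); jvp takes tuples of primals and
    # tangents and returns (output, Jacobian-vector product).
    ones = torch.ones(5)
    out, vjp_fn = vjp(torch.sin, x)
    (vjp_out,) = vjp_fn(ones)
    _, jvp_out = jvp(torch.sin, (x,), (ones,))

    # jacrev (reverse mode) and jacfwd (forward mode) compute the same Jacobian;
    # for elementwise sin it is diag(cos(x)).
    J_rev = jacrev(torch.sin)(x)
    J_fwd = jacfwd(torch.sin)(x)

    # hessian(f) is the matrix of second derivatives; here diag(-sin(x)).
    H = hessian(f)(x)

As a rule of thumb, `jacrev` tends to be cheaper when a function has fewer outputs than inputs, and `jacfwd` when it has fewer inputs than outputs; for the square elementwise example above, both give the same result.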