pytorch/torch/_C/_functorch.pyi
Richard Zou 7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high-level picture, with more details as comments
in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform (see
the sketch after this list).
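
For orientation, here is a hedged end-to-end sketch of what this mechanism
enables. It is an illustration, not code from this PR; it assumes the
setup_context-style autograd.Function mentioned under "Future" and a build
where the feature flag mentioned there is enabled.

import torch
from functorch import grad

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Because every grad layer sees the autograd.Function (via
        # custom_function_call), this backward is used at each level.
        return 3 * x ** 2 * grad_out

x = torch.tensor(3.0)
print(grad(Cube.apply)(x))        # d/dx x**3  = 3*x**2 -> 27.
print(grad(grad(Cube.apply))(x))  # d2/dx2 x**3 = 6*x   -> 18.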

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to either call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active), as sketched below.
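
A rough sketch of that branching, written on a subclass purely for
illustration: the real logic lives on autograd.Function.apply itself, the
forward/setup_context/backward staticmethods are omitted, and both the
`_are_functorch_transforms_active` binding name and the
`custom_function_call` import path are assumptions based on the
description above.

import torch

class SketchFunction(torch.autograd.Function):
    @classmethod
    def apply(cls, *args, **kwargs):
        if not torch._C._are_functorch_transforms_active():  # binding name is an assumption
            # No functorch layers on the interpreter stack: plain autograd path.
            return super().apply(*args, **kwargs)
        # Functorch layers present: route through the custom_function_call
        # PyOperator so every transform on the stack can interpose.
        from torch._functorch.autograd_function import custom_function_call  # import path is an assumption
        return custom_function_call(cls, *args, **kwargs)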

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Functions with
functorch; a hand-written version of one such check is sketched below.
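
Hand-written, each generated check boils down to something like the
following hedged sketch (Square is a made-up stand-in for a database
entry; the real tests go through the OpInfo machinery):

import torch
from functorch import vjp

class Square(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(3)
out, vjp_fn = vjp(Square.apply, x)
cotangent = torch.randn_like(out)
(got,) = vjp_fn(cotangent)
assert torch.allclose(got, 2 * x * cotangent)  # matches the hand-derived VJP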

Future
- jvp and vmap support coming next
- better error messages (functorch only supports autograd.Functions that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00


from torch import Tensor
from enum import Enum

# Defined in torch/csrc/functorch/init.cpp
def _set_dynamic_layer_keys_included(included: bool) -> None: ...
def get_unwrapped(tensor: Tensor) -> Tensor: ...
def is_batchedtensor(tensor: Tensor) -> bool: ...
def is_functionaltensor(tensor: Tensor) -> bool: ...
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
def maybe_get_bdim(tensor: Tensor) -> int: ...
def maybe_get_level(tensor: Tensor) -> int: ...
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def set_autograd_function_allowed(allowed: bool) -> None: ...
def get_autograd_function_allowed() -> bool: ...

# Defined in aten/src/ATen/functorch/Interpreter.h
class TransformType(Enum):
    Torch: TransformType = ...
    Vmap: TransformType = ...
    Grad: TransformType = ...
    Jvp: TransformType = ...
    Functionalize: TransformType = ...

class CInterpreter:
    def key(self) -> TransformType: ...
    def level(self) -> int: ...

class CGradInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevGradMode(self) -> bool: ...

class CVmapInterpreterPtr:
    def __init__(self, interpreter: CInterpreter): ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def batchSize(self) -> int: ...

class DynamicLayer:
    pass

def peek_interpreter_stack() -> CInterpreter: ...
def pop_dynamic_layer_stack() -> DynamicLayer: ...
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
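
For reference, a small hedged usage sketch of a few of the bindings
declared above. These are private APIs and may change; the behavior
inside a transform is an assumption based on the mechanism described in
the commit message.

import torch
from functorch import grad
from torch._C._functorch import (
    get_unwrapped,
    is_gradtrackingtensor,
    maybe_get_level,
)

def f(x):
    # Inside functorch.grad, the input is expected to arrive wrapped as a
    # grad-tracking tensor at some level (assumption based on the
    # mechanism described above).
    if is_gradtrackingtensor(x):
        print("level:", maybe_get_level(x), "inner:", get_unwrapped(x))
    return (x * x).sum()

grad(f)(torch.randn(3))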