Commit Graph

4 Commits

Author SHA1 Message Date
Richard Zou
da42eab48b Fix circular import in torch/autograd/function.py (#90415)
It turns out it is possible to break cycles by not directly importing a
module:
- there's a problem that torch.jit imports torch._ops and torch._ops
import torch.jit
- there's another problem that torch.autograd.function imports
custom_function_call but torch._functorch.autograd_function imports
torch.autograd.function

The "better" way to handle all of this is to do some large refactoring so
that torch._functorch.autograd_function imports some file that has
_SingleLevelAutogradFunction and then have torch.autograd.function
depend on torch.functorch.autograd_function... (and ditto for torch.jit
vs torch._ops), but I'm scared to move code around too much for BC
reasons and the fix in this PR works well.

Test Plan:
- import torch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90415
Approved by: https://github.com/albanD, https://github.com/soulitzer
2022-12-14 16:20:57 +00:00
Richard Zou
4809e838c1 functorch.jvp support for autograd.Function (#90077)
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.

For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin

The jvp rule for custom_function_call does just that. It constructs a
new autograd.Function (because the above logic already exists). Inside
the forward, it re-dispatches to custom_function_call. In the jvp rule,
it just calls whatever the jvp rule is supposed to be.

Since this logic is really close to the custom_function_call_grad, I
just put them together.

Test Plan:
- added jvp rules to the autograd.Function in autograd_function_db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077
Approved by: https://github.com/albanD, https://github.com/soulitzer
2022-12-14 16:20:53 +00:00
Richard Zou
3049d99027 autograd.Function supports vmap staticmethod (#90037)
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.

```py
class NumpyMul(torch.autograd.Function)
    staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors,
they've already been unpacked.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.

Semantics
- vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is
one and will never actually run NumpyMul.forward.
- In order for the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call
into operations that vmap understands (i.e. PyTorch operators or more
autograd.Function).

At a high level, this PR:
- adds a vmap rule for custom_function_call

Testing
- Added some tests for in_dims and info
- Added vmap staticmethod to most of the autograd.Function in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that will add
~200LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
Richard Zou
7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to eitehr call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.

Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00