Support for jvp is very similar to support for backward() (a usage sketch follows the list below):
- We need to vmap over a version of the original autograd.Function's jvp
method that does not take ctx as input.
- On the output, we need to reductify to ensure the output tangent has
the same shape as the output. This reductify does not have the
extra reduction semantics, because PyTorch forward-mode AD requires the
output tangent to have the same exact shape as the output.
- setup_context needs to tell us the bdims of the saved_tensors
(necessary for vmap over jvp_no_context), as well
as the output shapes (necessary for reductify).
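For illustration, a minimal usage sketch of what this enables (MyScale is a made-up example; it uses the current `torch.func` import path and the `setup_context(ctx, inputs, output)` signature, which may differ from the snapshot shown elsewhere in these PRs):
```py
import torch
from torch.func import jvp, vmap  # functorch.{jvp,vmap} at the time of this PR

class MyScale(torch.autograd.Function):
    generate_vmap_rule = True  # promise: the staticmethods below use only torch ops

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for this toy example

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

    @staticmethod
    def jvp(ctx, x_tangent):
        return x_tangent * 2

x, t = torch.randn(3, 5), torch.randn(3, 5)
# Batched forward-mode AD through the autograd.Function, with the vmap rule
# generated automatically.
out, out_tangent = vmap(lambda x, t: jvp(MyScale.apply, (x,), (t,)))(x, t)
```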
Test Plan:
- Added jvp support to the *GenVmapAutogradFunction
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91211
Approved by: https://github.com/soulitzer
Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit
This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises that their autograd.Function's
{forward, backward, jvp} staticmethods, if defined, use only PyTorch operations,
and that they respect the other limitations of autograd.Function+functorch
(such as not capturing any Tensors being transformed over from outside of the
autograd.Function).
Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` when `generate_vmap_rule=True`
is: construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of as
follows: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims)(...)
- VmappedFoo's forward, setup_context, and backward staticmethods are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation; a usage sketch
follows this list.
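For illustration, a minimal usage sketch (MyMul is a made-up example; it uses the current `torch.func` import path and the `setup_context(ctx, inputs, output)` signature, which may differ from the snapshot shown elsewhere in these PRs):
```py
import torch
from torch.func import vmap  # functorch.vmap at the time of this PR

class MyMul(torch.autograd.Function):
    generate_vmap_rule = True  # opt in: no hand-written vmap staticmethod needed

    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x

x = torch.randn(4, 3)
y = torch.randn(4, 3)
# Behaves like applying MyMul to each slice along dim 0.
out = vmap(MyMul.apply)(x, y)
```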
Test Plan:
- This PR introduces additional autograd.Functions with the suffix "GenVmap" in
autograd_function_db.
- There are also some minor UX tests
Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90966
Approved by: https://github.com/soulitzer
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit
`reductify_leaf(grad_input, ...)` is a helper function that processes a
single grad_input Tensor. The reason why we need it is:
- the grad_input has some optional bdim
- the input has some optional bdim
- if these are different, we need to coerce the grad_input into having
the same shape as the input, either by reducing or expanding the
grad_input.
Note that there is a special case in autograd where the user is allowed
to return a grad_input Tensor that is an expanded version of the
original input tensor. In this case, autograd automatically reduces the
grad_input to the same shape as the input. Unfortunately, this logic
doesn't work when bdims are involved, so we handle it manually in
`reductify_leaf`; a sketch of the coercion follows.
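A simplified sketch of that coercion (this is only an illustration of the idea; the real `reductify_leaf` has a different signature and also handles the expanded-grad shape reduction mentioned above):
```py
import torch

def reductify_leaf_sketch(grad_input, grad_input_bdim, input_bdim, batch_size):
    # Sketch only: assumes non-negative bdims and ignores the per-example
    # shape reduction for expanded grads.
    if grad_input is None:
        return None
    if grad_input_bdim is None and input_bdim is None:
        return grad_input
    if grad_input_bdim is None and input_bdim is not None:
        # grad has no batch dim but the input does: give every batch element
        # the same gradient.
        return grad_input.unsqueeze(input_bdim).expand(
            *grad_input.shape[:input_bdim], batch_size,
            *grad_input.shape[input_bdim:])
    if grad_input_bdim is not None and input_bdim is None:
        # grad has a batch dim but the input does not: the input was broadcast
        # across the batch, so sum the contributions from each batch element.
        return grad_input.sum(grad_input_bdim)
    # Both have a batch dim: just line them up.
    return grad_input.movedim(grad_input_bdim, input_bdim)
```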
Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90965
Approved by: https://github.com/soulitzer
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit
`restore_vmap` is a private helper function. It is vmap, but with the
following differences:
- instead of returning outputs, it returns an (outputs, out_dims) tuple.
out_dims is a pytree with the same structure as outputs and contains
Optional[int]s specifying where the vmapped dimension, if it exists, is
in the corresponding output.
- it does no validation on in_dims or inputs (vmap expects at least one
Tensor to be vmapped over); restore_vmap allows none of the inputs to
have a vmap dimension.
- it does no validation on outputs (vmap expects only Tensor outputs);
restore_vmap allows arbitrary outputs, not just Tensors. A small
illustration follows this list.
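A small illustration of what the (outputs, out_dims) contract means, using public vmap for contrast (restore_vmap itself is private, so it is only referenced in comments here):
```py
import torch
from torch.func import vmap  # public vmap, for contrast

def f(x):
    return x.sin()

x = torch.randn(3, 5)
# Public vmap returns only the outputs; the caller pins out_dims up front.
out = vmap(f, in_dims=0, out_dims=0)(x)
# restore_vmap(f, ...) as described above would instead return something like
# (out, 0): the outputs plus where the vmapped dim (if any) lives in each
# output, and it also tolerates no vmapped inputs and non-Tensor outputs.
```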
Test Plan:
- added some simple tests for restore_vmap
- I am OK with restore_vmap not being a part of vmap right now -- the
implementation of vmap rarely changes and it is a bit difficult to
refactor vmap in a way that makes restore_vmap a subroutine.
Other questions:
- Bikeshedding the `restore_vmap` name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90963
Approved by: https://github.com/samdow, https://github.com/soulitzer
This PR:
- adds VmapInterpreter.randomness. This returns the randomness option
the user provided in vmap(..., randomness=...)
- adds randomness to the info object passed to the vmap staticmethod of
autograd.Function. This is so that the user can handle random operations
on their own terms (if randomness="error" and the autograd.Function
performs random operations, then it is the user's responsibility to
raise an error); a sketch follows this list.
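For example, a vmap staticmethod could consume the new field like this (AddNoise is a made-up autograd.Function; only the randomness handling is the point):
```py
import torch

class AddNoise(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x + torch.rand_like(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # d(x + noise)/dx = 1

    @staticmethod
    def vmap(info, in_dims, x):
        if info.randomness == "error":
            # Honor vmap(..., randomness="error"): this Function is random.
            raise RuntimeError(
                "AddNoise uses random operations; "
                "call vmap with randomness='same' or 'different'")
        x_bdim, = in_dims
        x = (x.movedim(x_bdim, 0) if x_bdim is not None
             else x.expand(info.batch_size, *x.shape))
        if info.randomness == "same":
            noise = torch.rand_like(x[0]).expand_as(x)  # one sample for all
        else:  # "different"
            noise = torch.rand_like(x)                  # a sample per element
        return x + noise, 0
```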
Test Plan:
- updated unittest
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90789
Approved by: https://github.com/samdow, https://github.com/soulitzer
It turns out it is possible to break cycles by not directly importing a
module:
- there's a problem that torch.jit imports torch._ops and torch._ops
imports torch.jit
- there's another problem that torch.autograd.function imports
custom_function_call but torch._functorch.autograd_function imports
torch.autograd.function
The "better" way to handle all of this is to do some large refactoring so
that torch._functorch.autograd_function imports some file that has
_SingleLevelAutogradFunction and then have torch.autograd.function
depend on torch.functorch.autograd_function... (and ditto for torch.jit
vs torch._ops), but I'm scared to move code around too much for BC
reasons and the fix in this PR works well.
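The gist of the workaround, as a sketch (illustrative; not the exact code in this PR):
```py
def _custom_function_call():
    # Imported lazily so that importing torch.autograd.function does not
    # immediately pull in torch._functorch.autograd_function (which itself
    # imports torch.autograd.function), breaking the cycle at import time.
    from torch._functorch.autograd_function import custom_function_call
    return custom_function_call
```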
Test Plan:
- import torch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90415
Approved by: https://github.com/albanD, https://github.com/soulitzer
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.
For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin
The jvp rule for custom_function_call follows the same pattern. It
constructs a new autograd.Function (because the above logic already
exists). Inside the forward, it re-dispatches to custom_function_call;
inside the jvp, it just calls whatever the user's jvp is supposed to be.
Since this logic is really close to custom_function_call_grad, I just
put them together; a rough sketch follows.
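A rough sketch of the shape of that rule (names, signatures, and helpers here are illustrative, not the exact ones in this PR; `custom_function_call` and `interpreter` are assumed to be provided by the surrounding functorch machinery):
```py
import torch

def custom_function_call_jvp_rule_sketch(interpreter, autograd_function, *operands):
    class Generated(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            # Re-dispatch to custom_function_call at the next functorch level
            # down; the surrounding machinery then drives forward-mode AD for
            # this level, just like VariableType does for at::sin.
            with interpreter.lower():
                return custom_function_call(autograd_function, *args)

        @staticmethod
        def jvp(ctx, *tangents):
            # Defer to the user's jvp staticmethod.
            return autograd_function.jvp(ctx, *tangents)

    return Generated.apply(*operands)
```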
Test Plan:
- added jvp rules to the autograd.Function in autograd_function_db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077
Approved by: https://github.com/albanD, https://github.com/soulitzer
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.
```py
import torch

def to_numpy(t):
    # assumed helper: convert a Tensor to a NumPy array
    return t.detach().cpu().numpy()

class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # Move the batch dim (if any) to the end, or add a singleton dim for
        # arguments that are not being vmapped over.
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        # Report that the batch dim of the output lives at dim 0.
        result = result.movedim(-1, 0)
        return result, 0
```
API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors;
they've already been unpacked.
- We expect the user to return an `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.
Semantics
- vmap(NumpyMul.apply)(x, y) will apply the vmap staticmethod if there
is one and will never actually run NumpyMul.forward.
- For the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x, y)`), the vmap staticmethod must call
into operations that vmap understands (i.e., PyTorch operators or other
autograd.Functions). A usage sketch follows this list.
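For illustration, reusing the NumpyMul example above (vmap here is functorch's vmap, now also available as torch.func.vmap):
```py
import torch
from torch.func import vmap

x = torch.randn(4, 3)
y = torch.randn(4, 3)
# Single vmap: the vmap staticmethod runs once with in_dims=(0, 0);
# NumpyMul.forward is never called at the batched level.
out = vmap(NumpyMul.apply)(x, y)

# Nested vmap: works because the vmap staticmethod only calls operations that
# vmap understands (here, NumpyMul.apply itself).
xx = torch.randn(2, 4, 3)
yy = torch.randn(2, 4, 3)
out2 = vmap(vmap(NumpyMul.apply))(xx, yy)
```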
At a high level, this PR:
- adds a vmap rule for custom_function_call
Testing
- Added some tests for in_dims and info
- Added vmap staticmethods to most of the autograd.Functions in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests
Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that would add
~200 LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
Happy to split this PR more if it helps.
This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.
Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform; a
usage sketch of what this enables follows this list.
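For illustration, a minimal sketch of what this enables (NumpySin is made up; the sketch uses today's torch.func.grad, which was functorch.grad at the time, and the current setup_context(ctx, inputs, output) signature; at the time of this PR the feature was still behind a flag):
```py
import numpy as np
import torch
from torch.func import grad  # functorch.grad at the time of this PR

class NumpySin(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # A "black box" op: drops out to numpy, so regular autograd can't see it.
        return torch.from_numpy(np.sin(x.detach().cpu().numpy())).to(x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * x.cos()

x = torch.randn(3)
# Gradient of a scalar reduction, routed through custom_function_call.
gx = grad(lambda x: NumpySin.apply(x).sum())(x)  # equals x.cos()
```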
autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to either call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active); a simplified sketch follows this list.
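Roughly the shape of that dispatch, as a simplified sketch (import paths here reflect current PyTorch and are assumptions; the real apply does more bookkeeping):
```py
import torch
from torch._functorch.autograd_function import custom_function_call
from torch.autograd.function import _SingleLevelFunction

class Function(_SingleLevelFunction):
    @classmethod
    def apply(cls, *args, **kwargs):
        if not torch._C._are_functorch_transforms_active():
            # No functorch transforms are active: take the plain autograd path.
            return super().apply(*args, **kwargs)
        # Otherwise, route through the custom_function_call PyOperator so each
        # active functorch level can apply its own rule.
        return custom_function_call(cls, *args, **kwargs)
```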
Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.
Future
- jvp and vmap support coming next
- better error messages (functorch only supports autograd.Functions that
have the otherwise-optional setup_context staticmethod)
- documentation to come when we remove the feature flag
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer