Commit Graph

17 Commits

Author SHA1 Message Date
Richard Zou
7aaad0b832 Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)
functorch used to have a switch that enables/disables autograd.Function.
That switch now enables/disables torch.autograd.function._SingleLevelFunction, so
I've renamed it accordingly.

We could just delete the switch because users should not be directly
working with torch.autograd.function._SingleLevelFunction. However,
it was useful for debugging when something went wrong when I was
implementing the autograd.Function <> functorch interaction, so I want
to keep it around as a debugging tool for a while since the code is
already there.

Test Plan:
- updated tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
2023-01-17 13:36:41 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
76a3869fc6 Support functionalization on torch.cond (#89966)
This PR adds functionalization path for torch.cond. As it is the first pass, we only functionalize for very restrictive use cases. We explicitly restrict following:

- Output of each branch aliasing input
- In-place mutation on inputs given to each branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89966
Approved by: https://github.com/zou3519
2022-12-22 22:01:47 +00:00
Richard Zou
ffa37c9fca Add VmapInterpreter.randomness (in pyfunctorch) provide it in info object (#90789)
This PR:
- adds VmapInterpreter.randomness. This returns the randomness option
the user provided in vmap(..., randomness=...)
- adds randomness in the info object passed to the vmap staticmethod of
autograd.Function. This is so that the user can handle random operations
on their own terms (if randomness="error", and if the autograd.Function
has random operations, then it is the user's responsiblity to raise an
error).

Test Plan:
- updated unittest
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90789
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-17 00:43:43 +00:00
Richard Zou
abc54f9314 Revert "Revert "[functorch] Refactor life handle storage (#90317)"" (#90856)
Adds the fix for -Wsign-compare.

See original PR (https://github.com/pytorch/pytorch/pull/90317) for
commit message
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90856
Approved by: https://github.com/samdow
2022-12-15 16:03:16 +00:00
PyTorch MergeBot
0cd69d7cda Revert "[functorch] Refactor life handle storage (#90317)"
This reverts commit 4d494986af.

Reverted https://github.com/pytorch/pytorch/pull/90317 on behalf of https://github.com/osalpekar due to Causing contbuilds to fail when pytorch is built with -Wsign-compare internally - details in [D42019543](https://www.internalfb.com/diff/D42019543)
2022-12-14 19:08:33 +00:00
Richard Zou
4809e838c1 functorch.jvp support for autograd.Function (#90077)
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.

For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin

The jvp rule for custom_function_call does just that. It constructs a
new autograd.Function (because the above logic already exists). Inside
the forward, it re-dispatches to custom_function_call. In the jvp rule,
it just calls whatever the jvp rule is supposed to be.

Since this logic is really close to the custom_function_call_grad, I
just put them together.

Test Plan:
- added jvp rules to the autograd.Function in autograd_function_db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077
Approved by: https://github.com/albanD, https://github.com/soulitzer
2022-12-14 16:20:53 +00:00
Richard Zou
4d494986af [functorch] Refactor life handle storage (#90317)
A "life handle" is a pointer-to-boolean that says whether or not a
TensorWrapper is alive. A TensorWrapper is alive if we are currently
inside of its corresponding transform. An Interpreter is alive if we are
currently inside of its corresponding transform. I.e., for vmap(f)(x),
the BatchedTensor(x, level=1) is alive inside of the execution of f; and
the corresponding VmapInterpreter is alive inside of f.

Previously, there was a global map of level to life handle. It is
possible to get into a state where we have multiple levels that refer to
different Interpreters (if the implementation of an operator calls into
functorch) and that messes up the global map.

This PR changes it so that
- every Interpreter holds a life handle that says if it is alive
- to construct a TensorWrapper, one must either (a) directly pass it a life
handle, or (b) one must create the TensorWrapper when the corresponding
Interpreter is on the stack (and we will automatically grab the life
handle by indexing into the DynamicLayerStack with the level)

(a) is more robust so I changed most of our C++ callsites to do that.
(b) feels a bit hacky to me, but it seems fine for now:
- It'll raise a nice error message if the interpreter isn't on the stack
- all of our Python callsites already follow this convention (we construct
TensorWrappers after pushing the Interpreter onto the stack).

The alternative to (b) is that we always do (a), which we can do in the
future if (b) runs us into any problems.

Test Plan:
- all functorch tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90317
Approved by: https://github.com/samdow
2022-12-13 14:45:18 +00:00
Richard Zou
24c3ad7851 Move private forward grad mode helpers to torch.autograd.forward_ad (#90240)
Motivation
- These were previously defined in functorch. They are not
functorch-specific, so I'm moving them to torch.autograd.forward_ad and
the autograd python bindings.
- I need this to avoid some of my cyclic import problems.

Should these be public APIs? Probably. Though this needs discussion, so
punting it to the future.

Test Plan:
- moved the tests of these from test/functorch/test_eager_transforms.py
to test/test_autograd.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90240
Approved by: https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
Richard Zou
3049d99027 autograd.Function supports vmap staticmethod (#90037)
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.

```py
class NumpyMul(torch.autograd.Function)
    staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors,
they've already been unpacked.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.

Semantics
- vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is
one and will never actually run NumpyMul.forward.
- In order for the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call
into operations that vmap understands (i.e. PyTorch operators or more
autograd.Function).

At a high level, this PR:
- adds a vmap rule for custom_function_call

Testing
- Added some tests for in_dims and info
- Added vmap staticmethod to most of the autograd.Function in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that will add
~200LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
Richard Zou
7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to eitehr call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.

Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00
Richard Zou
3bc327993f PyDispatcher integration with functorch (#88785)
This PR teaches PyDispatcher and PyOperator about functorch transforms.
It is important that PyDispatcher/PyOperator dispatch with functorch
transforms, because this is our plan for higher-order operators
(operators that accept functions as arguments). Examples of these
include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards),
- AOTDispatcher (should be a higher order operator)

Concretely, the problem with teaching PyDispatcher/PyOperator about
functorch is that the stack-based dispatching logic (DynamicLayerStack)
is hidden inside the fallbacks for two dispatch keys
(DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed
fallbacks, our plan on record for that is that we need to reimplement
all of them in Python (but can call helper functions in C++ to make our
lives easier).

Instead of exposing all of what DynamicLayer{Front, Back} do to python,
this PR takes the approach of re-implementing part of the stack-based
dispatching in Python. The motivation is that this is more sane and
follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery. functorch needs to do
this hackery today to re-use VariableType implementations.

This PR:
- exposes the DynamicLayerStack to Python
- The DynamicLayerStack is a stack of Interpreters.
These get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to
the next interpreter in the stack (Interpreter.lower)
- To use a PyOperator with functorch transforms, a developer needs to
register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function
support for functorch will end up going through the autograd.Function
API.

Question for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function"
work into logical pieces. Would it be better if I didn't? (the full
thing is a bit long - 1000-2000 LOC).

Test Plan:
- new tests that construct PyOperator and compose them with functorch
transforms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00
Richard Zou
2268a3215c [functorch] add switch to enable autograd.Function (#88784)
This is mostly a debug or "if you know what you're doing" switch for
now. It is not public API.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88784
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00
Edward Z. Yang
d07b85393a SymInt fixes from symbolic-shapes branch (#86242)
symintify a few inplace meta functions

symintify resize_(), nbytes(), functionalization input mutations

meta funcs for avg_pool2d_backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86242
Approved by: https://github.com/Chillee
2022-10-05 04:52:02 +00:00
Richard Zou
e0170c7cde Remove torch/extension.h dependency in torch/csrc/functorch/init.cpp (#85659)
This file doesn't depend on APIs there. Required adding some
namespacing to symbols.

Test Plan:
- build & test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85659
Approved by: https://github.com/Chillee
2022-09-29 15:40:45 +00:00
Horace He
0e256c2550 removed compile cache and static argnums (#85783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85783
Approved by: https://github.com/wconstab
2022-09-28 08:33:59 +00:00
Richard Zou
848437590f Delete functorch's monkeypatching (#85430)
By upstreaming functorch's tensor printing logic into PyTorch. There's
no way of creating a custom print function for a TensorImpl subclass (as
opposed to a torch_dispatch or torch_function tensor subclass, which can
just override repr()) right now, so we need to directly interpose inside
regular Tensor printing in PyTorch.

Monkey patching is bad; users do not expect `import blah` to change
something about another library.

Fixes https://github.com/pytorch/functorch/issues/900

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85430
Approved by: https://github.com/ezyang
2022-09-22 18:47:12 +00:00
Richard Zou
5e5c319549 Move functorch python bindings to torch/csrc (#85426)
This moves functorch's python bindings to torch/csrc/functorch/init.cpp.
Coming next is the torchdim move. I didn't do torchdim yet because
moving functorch's python bindings unblocks some other things that I
want to do first.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85426
Approved by: https://github.com/ezyang
2022-09-22 18:47:12 +00:00