Commit Graph

32 Commits

Author SHA1 Message Date
Richard Zou
7aaad0b832 Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)
functorch used to have a switch that enables/disables autograd.Function.
That switch now enables/disables torch.autograd.function._SingleLevelFunction, so
I've renamed it accordingly.

We could just delete the switch because users should not be directly
working with torch.autograd.function._SingleLevelFunction. However,
it was useful for debugging when something went wrong when I was
implementing the autograd.Function <> functorch interaction, so I want
to keep it around as a debugging tool for a while since the code is
already there.

Test Plan:
- updated tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
2023-01-17 13:36:41 +00:00
Richard Zou
14ff58d4fa [generate_vmap_rule] Delete unused output_shapes (#92024)
We don't actually need `output_shapes` to implement
`generate_vmap_rule=True` support for autograd.Function.
- We need this in the vjp (backward) case because autograd automatically
  reduces grad_inputs to inputs and we need to replicate that behavior.
  In order to replicate that behavior, we recorded the original input
  shapes so we know how to reduce the grad_input.
- There is no such behavior for forward-mode AD, so we don't need to
  pass an `output_shapes` to reductify.

This PR simplifies the API of `reductify` and `reductify_leaf`. Instead
of accepting `input_shape_without_bdim` and `allow_expanded_grad`, we
now combine these into a single argument,
`reduce_to_input_shape_without_bdim`.
- if it is None, then we don't do anything
- if it is not-None and a shape, then we will reduce the grad to the
  provided shape.

Test Plan:
- updated original unittests
- wait for test suite
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92024
Approved by: https://github.com/soulitzer
2023-01-17 13:36:39 +00:00
Richard Zou
f5af97ef06 [autograd.Function] add nice error message for incorrect usage of vmap (#92023)
This PR:
- adds a nice error message if the user doesn't follow the API of the
  vmap staticmethod correctly. That is, the user must return two
  arguments from the vmap staticmethod API: (outputs, out_dims), and
  out_dims must be a PyTree with either the same structure as `outputs`
  our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
  but wrap_outputs_maintaining_identity was treating "None" as "This is
  not the vmap case"

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92023
Approved by: https://github.com/soulitzer
2023-01-17 13:36:37 +00:00
Richard Zou
2f9166ef89 [autograd.Function] Cleanup asymmetry in generate_vmap_rule and vmap (#91787)
This PR:
- changes generate_vmap_rule to either be True or False. Previously it
  could be True, False, or not set. This simplifies the implementation a
  bit.
- changes the vmap staticmethod to always be on the autograd.Function
  rather than sometimes defined.
  This is how the other staticmethod (forward, backward, jvp) are
  implemented and allows us to document it.

There are 4 possible states for the autograd.Function w.r.t. to the
above:
- generate_vmap_rule is True, vmap staticmethod overriden. This raises
  an error when used with vmap.
- generate_vmap_rule is False, vmap staticmethod overriden. This is
  valid.
- generate_vmap_rule is True, vmap staticmethod not overriden. This is
  valid.
- generate_vmap_rule is False, vmap staticmethod not overriden. This
  raises an error when used with vmap.

Future:
- setup_context needs the same treatment, but that's a bit tricker to
  implement.

Test Plan:
- new unittest
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91787
Approved by: https://github.com/soulitzer
2023-01-17 13:36:34 +00:00
samdow
162474d7fd [functorch] add new ensembling api, demonstrate in example (#88850)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88850
Approved by: https://github.com/zou3519
2023-01-04 00:33:14 +00:00
samdow
c5e5916fff [functorch] add functorch functional_call, update tests to test this (#89213)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89213
Approved by: https://github.com/zou3519
2023-01-04 00:33:14 +00:00
Kshiteej K
3fdbf824ae [functorch] jacrev: chunk_size=1 without vmap (#91326)
As discussed at https://github.com/pytorch/pytorch/pull/91157#discussion_r1053679272

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91326
Approved by: https://github.com/zou3519
2022-12-28 04:56:25 +00:00
soulitzer
1b2ee4d0e1 Update functorch supported autograd.Function to allow mark_dirty (#91222)
Fixes https://github.com/pytorch/pytorch/issues/90225
Uses what was originally in 32a57bcdb6

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91222
Approved by: https://github.com/zou3519
2022-12-28 03:53:47 +00:00
Brian Hirsh
c47bdd7522 *_scatter ops should preserve input stride/storage_offset (#91029)
It turns out that we *do* need to update *_scatter ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride debugging check. It only actually ends up tripping silent correctness if you try to .backward() on that function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91029
Approved by: https://github.com/ezyang
2022-12-22 19:41:53 +00:00
Richard Zou
2f37804cae [generate_vmap_rule] Add generate_vmap_rule to autograd.Function (#90966)
Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90966
Approved by: https://github.com/soulitzer
2022-12-21 00:34:44 +00:00
Richard Zou
2a55984139 [generate_vmap_rule] reductify_leaf helper function (#90965)
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

`reductify_leaf(grad_input, ...)` is a helper function that processes a
single grad_input Tensor. The reason why we need it is:
- the grad_input has some optional bdim
- the input has some optional bdim
- if these are different, we need to coerce the grad_input into having
the same shape as the input, either by reducing or expanding the
grad_input.

Note that there is a special case in autograd that the user is allowed
to return a grad_input Tensor that is an expanded version of the
original input tensor. In this case, autograd automatically reduces
grad_input to the same shape as the input. Unfortunately this logic
doesn't work when bdims are involved, so we manually handle it in
`reductify_leaf`.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90965
Approved by: https://github.com/soulitzer
2022-12-21 00:34:44 +00:00
Richard Zou
53c94ef1bb [generate_vmap_rule] Add mechanism to override ctx.saved_tensors (CtxWithSavedTensors) (#90964)
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit#heading=h.r3ckcnsh1cxt

This PR creates CtxWithSavedTensors. You can wrap a ctx object in the
backward pass of autograd.Function in CtxWithSavedTensors and specify
the saved_tensors attribute. CtxWithSavedTensor acts like the original
ctx object (all other attribute accesses are forwarded to the original ctx
object) but it has a custom saved_tensors field.

Test Plan:
- tests that you can use CtxWithSavedTensors to get a new object with
your own saved_tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90964
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-21 00:34:43 +00:00
Kshiteej K
f02e93b584 jacrev : Support chunked computation (#89376)
Ref: https://github.com/pytorch/functorch/issues/680

We introduce a kwarg `chunk_size` in `jacrev` to control whether the Jacobian computation should be chunked and if so then `chunk_size` will dictate the maximum size of the chunks used.

We try two approaches,
* Stacked Approach: Append the intermediate computation to a list and then stack those results.
* Pre-allocation Approach: Pre-allocate a zeros tensor and copy chunked computation into it.

For Memory Benchmark, see https://github.com/pytorch/pytorch/pull/89376#issuecomment-1348479098

Benchmark CPU : Performs better with more chunks/ smaller chunk_size.

NOTE: There seems to be a lot of noise for shape `(64, 64)`.

<details>

```
[----------------------------------------------- jacrev : device cpu : chunks 2 -----------------------------------------------]
                                     |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080     |               76.2            |          50.9        |                  80.1
      (128, 128) : chunk_size 8256   |             1172.8            |         783.3        |                1225.5
      (128, 144) : chunk_size 9288   |             1475.1            |         990.4        |                1548.3
      (144, 144) : chunk_size 10440  |             1871.3            |        1254.4        |                1971.2

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 3 ----------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386    |               39.9            |          25.8        |                  58.8
      (128, 128) : chunk_size 5504  |             1182.6            |         782.2        |                1229.7
      (128, 144) : chunk_size 6192  |             1483.6            |         995.4        |                1550.6
      (144, 144) : chunk_size 6960  |             1879.1            |        1257.7        |                1960.5

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 4 ----------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040    |               41.7            |          50.6        |                  29.1
      (128, 128) : chunk_size 4128  |             1171.6            |         782.3        |                1226.7
      (128, 144) : chunk_size 4644  |             1482.2            |         994.6        |                1550.9
      (144, 144) : chunk_size 5220  |             1870.2            |        1254.5        |                1961.4

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 100 ---------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41     |               46.8            |          50.5        |                  46.4
      (128, 128) : chunk_size 165  |              622.2            |         775.2        |                 656.0
      (128, 144) : chunk_size 185  |              803.9            |         987.3        |                 866.9
      (144, 144) : chunk_size 208  |             1021.1            |        1251.2        |                1088.2

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 200 ---------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 20     |               60.9            |          50.2        |                  62.3
      (128, 128) : chunk_size 82   |              583.1            |         779.4        |                 634.3
      (128, 144) : chunk_size 92   |              834.1            |        1005.8        |                 472.3
      (144, 144) : chunk_size 104  |             1053.6            |        1277.0        |                1033.9

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 300 --------------------------------------------]
                                  |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 13    |              77.7             |          50.4        |                  79.6
      (128, 128) : chunk_size 55  |             578.9             |         782.3        |                 626.9
      (128, 144) : chunk_size 61  |             718.2             |        1024.9        |                 800.4
      (144, 144) : chunk_size 69  |             919.7             |        1313.7        |                1023.0

Times are in milliseconds (ms).
```

</details>

Benchmark CUDA: Performs better with less chunks/bigger chunk_size.

<details>

```
[--------------------------------------------- jacrev : device cuda:1 : chunks 2 ----------------------------------------------]
                                     |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080     |             1485.7            |         923.8        |                1632.3
      (128, 128) : chunk_size 8256   |            25390.2            |       14103.2        |               33557.4
      (128, 144) : chunk_size 9288   |              801.7            |       16854.1        |               42894.6
      (144, 144) : chunk_size 10440  |             1003.5            |       21386.5        |               59648.5

Times are in microseconds (us).

3 / 3 : Shape (144, 144) : Device cuda:1 : chunks: 3
[--------------------------------------------- jacrev : device cuda:1 : chunks 3 ---------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386    |             1474.5            |         924.5        |                1655.5
      (128, 128) : chunk_size 5504  |            25368.9            |       10156.0        |               34022.1
      (128, 144) : chunk_size 6192  |            25223.0            |       12933.7        |               56418.5
      (144, 144) : chunk_size 6960  |            24729.3            |       16367.4        |               68744.7

Times are in microseconds (us).

3 / 3 : Shape (144, 144) : Device cuda:1 : chunks: 4
[--------------------------------------------- jacrev : device cuda:1 : chunks 4 ---------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040    |             1489.2            |         924.4        |                 1679.6
      (128, 128) : chunk_size 4128  |            25370.4            |        8987.4        |                57201.3
      (128, 144) : chunk_size 4644  |            32239.1            |       10136.2        |                72406.5
      (144, 144) : chunk_size 5220  |            40994.3            |       12867.8        |               108653.4

Times are in microseconds (us).

3 / 3 : Shape (144, 144) : Device cuda:1 : chunks: 100
[------------------------------------------- jacrev : device cuda:1 : chunks 100 --------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41     |            21121.8            |         924.2        |               22753.5
      (128, 128) : chunk_size 165  |            23679.7            |       14284.4        |               26758.2
      (128, 144) : chunk_size 185  |            30082.3            |       18063.3        |               33553.5
      (144, 144) : chunk_size 208  |            38175.6            |       22839.5        |               42030.0

Times are in microseconds (us).
```

</details>

Benchmark Script

<details>

```python
import functorch
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle
from torch import profiler

import math

def prod(l):
    prod = 1
    for el in l:
        prod *= el

    return prod

def fn(x, y):
    return x + y, x.sum(0)

shapes = ((64, 64), (128, 128), (128, 144), (144, 144))

for device in ('cpu', 'cuda:1'):
    if device == 'cuda:1':
        chunks = (2, 3, 4, 100,)
    else:
        chunks = (2, 3, 4, 100, 200, 300)
    for chunk in chunks:
        results = []
        for shape in shapes:
            x = torch.zeros(*shape, dtype=torch.float, device=device)
            y = x.sum()
            chunk_size = (prod(shape) + prod(shape[1:])) // chunk
            jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
            jacrev_fn_chunked_pre = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size, _preallocate_and_copy=True)
            jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)

            tasks = [("jacrev_fn_chunked(x, y)", "with chunk_size and stacked"),
                     ("jacrev_fn(x, y)", "without chunk_size"),
                     ("jacrev_fn_chunked_pre(x, y)", "with chunk_size and pre-allocated"),]
            timers = [Timer(stmt=stmt, label=f"jacrev : device {device} : chunks {chunk}", sub_label=f"{(shape)} : chunk_size {chunk_size}", description=desc, globals=globals()) for stmt, desc in tasks]

            for i, timer in enumerate(timers):
                results.append(
                    timer.blocked_autorange(min_run_time=2.)
                )
                print(f"\r{i + 1} / {len(timers)} : Shape {shape} : Device {device} : chunks: {chunk}", end="")
                sys.stdout.flush()

        print()
        comparison = Compare(results)
        comparison.print()
```

</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89376
Approved by: https://github.com/zou3519
2022-12-19 20:04:21 +00:00
Richard Zou
ffa37c9fca Add VmapInterpreter.randomness (in pyfunctorch) provide it in info object (#90789)
This PR:
- adds VmapInterpreter.randomness. This returns the randomness option
the user provided in vmap(..., randomness=...)
- adds randomness in the info object passed to the vmap staticmethod of
autograd.Function. This is so that the user can handle random operations
on their own terms (if randomness="error", and if the autograd.Function
has random operations, then it is the user's responsiblity to raise an
error).

Test Plan:
- updated unittest
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90789
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-17 00:43:43 +00:00
Richard Zou
f21cb7d77e [pyfunctorch] Generate a more meaningful name for _SingleLevelAutogradFunction (#90418)
The API to do this is not pretty, but at least it works.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90418
Approved by: https://github.com/soulitzer
2022-12-14 16:20:57 +00:00
Richard Zou
24c3ad7851 Move private forward grad mode helpers to torch.autograd.forward_ad (#90240)
Motivation
- These were previously defined in functorch. They are not
functorch-specific, so I'm moving them to torch.autograd.forward_ad and
the autograd python bindings.
- I need this to avoid some of my cyclic import problems.

Should these be public APIs? Probably. Though this needs discussion, so
punting it to the future.

Test Plan:
- moved the tests of these from test/functorch/test_eager_transforms.py
to test/test_autograd.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90240
Approved by: https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
Richard Zou
3049d99027 autograd.Function supports vmap staticmethod (#90037)
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.

```py
class NumpyMul(torch.autograd.Function)
    staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors,
they've already been unpacked.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.

Semantics
- vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is
one and will never actually run NumpyMul.forward.
- In order for the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call
into operations that vmap understands (i.e. PyTorch operators or more
autograd.Function).

At a high level, this PR:
- adds a vmap rule for custom_function_call

Testing
- Added some tests for in_dims and info
- Added vmap staticmethod to most of the autograd.Function in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that will add
~200LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
Richard Zou
7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to eitehr call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.

Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00
Richard Zou
4068c5467d [Reland] Move functorch/_src to torch/_functorch (#88756) (#90091)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
2022-12-03 14:17:15 +00:00
PyTorch MergeBot
218d9c6e09 Revert "Move functorch/_src to torch/_functorch (#88756)"
This reverts commit 52bc5c1cfe.

Reverted https://github.com/pytorch/pytorch/pull/88756 on behalf of https://github.com/clee2000 due to broke imports in tests 52bc5c1cfe https://github.com/pytorch/pytorch/actions/runs/3574742513/jobs/6010814968 probably a landrace
2022-11-29 17:17:11 +00:00
Richard Zou
52bc5c1cfe Move functorch/_src to torch/_functorch (#88756)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang
2022-11-29 13:55:42 +00:00
Richard Zou
3bc327993f PyDispatcher integration with functorch (#88785)
This PR teaches PyDispatcher and PyOperator about functorch transforms.
It is important that PyDispatcher/PyOperator dispatch with functorch
transforms, because this is our plan for higher-order operators
(operators that accept functions as arguments). Examples of these
include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards),
- AOTDispatcher (should be a higher order operator)

Concretely, the problem with teaching PyDispatcher/PyOperator about
functorch is that the stack-based dispatching logic (DynamicLayerStack)
is hidden inside the fallbacks for two dispatch keys
(DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed
fallbacks, our plan on record for that is that we need to reimplement
all of them in Python (but can call helper functions in C++ to make our
lives easier).

Instead of exposing all of what DynamicLayer{Front, Back} do to python,
this PR takes the approach of re-implementing part of the stack-based
dispatching in Python. The motivation is that this is more sane and
follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery. functorch needs to do
this hackery today to re-use VariableType implementations.

This PR:
- exposes the DynamicLayerStack to Python
- The DynamicLayerStack is a stack of Interpreters.
These get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to
the next interpreter in the stack (Interpreter.lower)
- To use a PyOperator with functorch transforms, a developer needs to
register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function
support for functorch will end up going through the autograd.Function
API.

Question for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function"
work into logical pieces. Would it be better if I didn't? (the full
thing is a bit long - 1000-2000 LOC).

Test Plan:
- new tests that construct PyOperator and compose them with functorch
transforms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00
Richard Zou
2268a3215c [functorch] add switch to enable autograd.Function (#88784)
This is mostly a debug or "if you know what you're doing" switch for
now. It is not public API.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88784
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00
Sean Ross-Ross
cdb798faef _get_nested_attr should return a value in the general case (#88822)
Fixes https://github.com/pytorch/functorch/issues/1053

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88822
Approved by: https://github.com/zou3519
2022-11-14 18:39:45 +00:00
Brian Hirsh
ec4eadac5b reland "Do not use unsafe restriding for subclasses (#87610)" (#88343)
This reverts commit 5b75b19f51.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88343
Approved by: https://github.com/ezyang
2022-11-14 13:42:51 +00:00
PyTorch MergeBot
5b75b19f51 Revert "Do not use unsafe restriding for subclasses (#87610)"
This reverts commit 73379acaf3.

Reverted https://github.com/pytorch/pytorch/pull/87610 on behalf of https://github.com/mehtanirav due to [Internal breakages](https://www.internalfb.com/intern/sandcastle/job/36028797828925790/insights)
2022-11-02 16:59:02 +00:00
Brian Hirsh
73379acaf3 Do not use unsafe restriding for subclasses (#87610)
This helps convert some accuracy errors into runtime errors,
which makes it easier to debug.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87610
Approved by: https://github.com/albanD
2022-10-31 20:49:15 +00:00
Richard Zou
18f3db2963 Fix functorch tests (#87914)
Test Plan: - Run tests

Differential Revision: D40777145

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87914
Approved by: https://github.com/Chillee, https://github.com/osalpekar
2022-10-29 01:21:55 +00:00
samdow
d2d0be9a76 fix typo in per sample grad test (#87790)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87790
Approved by: https://github.com/zou3519
2022-10-27 19:43:44 +00:00
Richard Zou
ac80da2293 [functorch] add test for torch.manual_seed inside grad transform (#87233)
I can see this behavior regressing really easily, so adding a test for
it.

Test Plan:
- run test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87233
Approved by: https://github.com/Chillee
2022-10-19 18:01:55 +00:00
Richard Zou
7da018b2f8 [functorch] fix fbcode tests (#86936)
Differential Revision: [D40358418](https://our.internmc.facebook.com/intern/diff/D40358418)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86936
Approved by: https://github.com/samdow
2022-10-14 18:42:38 +00:00
Richard Zou
109f4d4453 Move functorch tests from functorch/test/* to test/functorch/* (#86623)
This is the first step described in
https://github.com/pytorch/pytorch/issues/86618 . test/functorch/* is
the final location for these tests.

Test Plan:
- Check that the functorch shards in CI are still running tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86623
Approved by: https://github.com/huydhn
2022-10-11 17:20:45 +00:00