Commit Graph

7 Commits

Author SHA1 Message Date
PyTorch MergeBot
ad70a70171 Revert "[functorch] test: try using reference_inputs in vmap tests (#91355)"
This reverts commit a51090d4b1.

Reverted https://github.com/pytorch/pytorch/pull/91355 on behalf of https://github.com/kshitij12345 due to Broke trunk
2023-01-06 09:57:21 +00:00
kshitij12345
a51090d4b1 [functorch] test: try using reference_inputs in vmap tests (#91355)
Ref https://github.com/pytorch/functorch/issues/1090

Timings:

`test_vmap_exhaustive`

After PR
```
== 1168 passed, 55 skipped, 2353 deselected, 153 xfailed in 195.07s (0:03:15) ==
```

Before PR
```
== 1134 passed, 55 skipped, 2316 deselected, 150 xfailed in 77.18s (0:01:17) ==
```

`test_op_has_batch_rule`

After PR
```
== 988 passed, 57 skipped, 2353 deselected, 331 xfailed in 144.70s (0:02:24) ==
```

Before PR
```
== 969 passed, 57 skipped, 2316 deselected, 313 xfailed in 65.86s (0:01:05) ==
```
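
For context, here is a minimal sketch (not the actual `test_vmap_exhaustive` harness) of what the switch amounts to: iterating an OpInfo's `reference_inputs`, which is a superset of `sample_inputs` and falls back to it when an op defines no extra reference inputs. The op choice and the loop-based reference are illustrative only.
```
import torch
from functorch import vmap
from torch.testing._internal.common_methods_invocations import op_db

op = next(o for o in op_db if o.name == "sin")  # illustrative op choice
for sample in op.reference_inputs("cpu", torch.float32, requires_grad=False):
    if not isinstance(sample.input, torch.Tensor) or sample.input.ndim == 0:
        continue  # skip inputs without a batch dim to vmap over
    # loop-based reference: apply the op slice by slice along dim 0
    expected = torch.stack([op(t, *sample.args, **sample.kwargs)
                            for t in sample.input])
    result = vmap(lambda t: op(t, *sample.args, **sample.kwargs))(sample.input)
    torch.testing.assert_close(result, expected)
```
The larger input set is what accounts for the longer runtimes (and higher pass counts) above.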

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91355
Approved by: https://github.com/zou3519
2023-01-06 08:16:11 +00:00
Richard Zou
ed589dd8e4 [functorch] add composition-of-3-transform tests for autograd_function (#90962)
This PR adds the following OpInfo tests:
- vmap x vjp x vmap
- vjp x vmap x vmap
- vjp x vjp x vmap

These OpInfo tests only run for the autograd_function_db. In general,
testing composition of two transforms is sufficient to convince
ourselves that functorch works on a given operator.

The autograd.Function testing (especially the upcoming
generate_vmap_rule) didn't feel rigorous enough to me, so I added these
additional tests to convince myself.
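
For illustration, a minimal sketch of two of these compositions using the public functorch API; `f` is a stand-in function, not an entry from autograd_function_db:
```
import torch
from functorch import vmap, vjp

def f(x):
    return x.sin()  # stand-in; the real tests draw ops from autograd_function_db

x = torch.randn(3, 4, 5)

# vjp x vmap x vmap: take the vjp of a doubly-vmapped function
out, vjp_fn = vjp(vmap(vmap(f)), x)
(grad_x,) = vjp_fn(torch.ones_like(out))

# vmap x vjp x vmap: vmap over a function that computes a vjp internally
def inner(t):
    out_t, fn = vjp(vmap(f), t)
    return fn(torch.ones_like(out_t))[0]

grad_batched = vmap(inner)(x)
torch.testing.assert_close(grad_x, grad_batched)  # both equal cos(x)
```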

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90962
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-17 00:43:44 +00:00
Joel Schlosser
ee2475869c ModuleInfo-based tests for AOTAutograd (#90980)
Adds a set of generated tests for `AOTAutograd` using the `ModuleInfo` db, analogous to the `OpInfo`-based tests. Includes the following changes:

* Adds a `TestEagerFusionModuleInfo` test class, with both symbolic and non-symbolic tests, just like the OpInfo tests.
    * The test logic "functionalizes" the module under test and calls into the now-factored-out verification logic that the OpInfo tests use to compare compiled vs. non-compiled outputs and grads (a minimal sketch of that comparison follows this list).
* Adds a `decorateForModules(decorator, module_set)` utility to `test/functorch/common_utils.py` to handle xfails, skips, etc. The pre-existing logic is specific to ops, and I didn't want to duplicate all that, so I kept additions minimal with this function.
    * Adds a bunch of xfails to get everything passing; I haven't looked deeply into all of these yet. #90500 is relevant for the RNN failures.
* Fixes a bug in the `ModuleInfo` entry for `NLLLoss` to ensure sample input has the requested `requires_grad` setting (was causing spurious test failures).
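
As a rough sketch of that compiled-vs-eager comparison (assuming `functorch.compile.aot_module` with the `nop` compiler, which keeps eager semantics; the module choice is arbitrary):
```
import torch
from functorch.compile import aot_module, nop

mod = torch.nn.Linear(3, 4)
compiled = aot_module(mod, fw_compiler=nop)  # parameters are shared with `mod`

x_ref = torch.randn(2, 3, requires_grad=True)
x = x_ref.detach().clone().requires_grad_(True)

# compare outputs of the compiled module against eager
ref = mod(x_ref)
out = compiled(x)
torch.testing.assert_close(out, ref)

# compare input grads; parameter grads would need care since params are shared
ref.sum().backward()
out.sum().backward()
torch.testing.assert_close(x.grad, x_ref.grad)
```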
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90980
Approved by: https://github.com/ezyang
2022-12-16 21:43:34 +00:00
Richard Zou
7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high-level picture, with more details as comments in
the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to either call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).
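
For illustration, a minimal sketch of an autograd.Function written in the setup_context style that this support targets (at this commit the support may still sit behind the feature flag mentioned under Future; MySin is a toy example):
```
import torch
from functorch import grad

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(x):  # setup_context style: forward takes no ctx
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)

x = torch.randn([])
# with a grad transform active, apply routes through custom_function_call
print(grad(MySin.apply)(x), torch.cos(x))  # should match
```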

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.

Future
- jvp and vmap support coming next
- better error messages (functorch only supports autograd.Function subclasses
that have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00
kshitij12345
96aac51717 [functorch] dont compute expected output multiple times (#86202)
Fixes https://github.com/pytorch/functorch/issues/1028

Description: We update `get_fallback_and_vmap_exhaustive` to compute the expected output only once, as described in the issue.

NOTE: This doesn't take care of the repeated computation in `test_vmap_exhaustive`; that will be addressed in a follow-up.
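
A minimal illustration of the hoisting (the op and the single `in_dims` configuration are stand-ins for the helper's many configurations):
```
import torch
from functorch import vmap

op = torch.sin                       # stand-in for the op under test
x = torch.randn(3, 5)

# before: this loop-based reference was recomputed per vmap configuration;
# after this PR it is computed once and reused across configurations
expected = torch.stack([op(xi) for xi in x])
for in_dims in (0,):                 # illustrative; the helper tries many configs
    torch.testing.assert_close(vmap(op, in_dims=in_dims)(x), expected)
```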

TODO:
* [x] Benchmark and see how much difference this makes. (Comparison table below: [Link](https://github.com/pytorch/pytorch/pull/86202#issuecomment-1285477653))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86202
Approved by: https://github.com/zou3519
2022-10-24 22:43:11 +00:00
Richard Zou
109f4d4453 Move functorch tests from functorch/test/* to test/functorch/* (#86623)
This is the first step described in
https://github.com/pytorch/pytorch/issues/86618. test/functorch/* is
the final location for these tests.

Test Plan:
- Check that the functorch shards in CI are still running tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86623
Approved by: https://github.com/huydhn
2022-10-11 17:20:45 +00:00