Commit Graph

185 Commits

Author SHA1 Message Date
yhl48
07c02b9e92 Add vmap support for smooth_l1_loss_backward (#99429)
Follow-up of #98357
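As a hedged sketch of what this enables (using the torch.func API; not code from the PR), the backward of smooth_l1_loss can now run under vmap, e.g. for per-sample gradients:

```python
import torch
import torch.nn.functional as F
from torch.func import grad, vmap

def loss(pred, target):
    return F.smooth_l1_loss(pred, target)  # scalar mean loss

preds = torch.randn(8, 10)
targets = torch.randn(8, 10)
# Per-sample gradients: smooth_l1_loss_backward is invoked under vmap.
per_sample_grads = vmap(grad(loss))(preds, targets)  # shape (8, 10)
```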
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99429
Approved by: https://github.com/kshitij12345, https://github.com/zou3519
2023-04-28 10:58:07 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.jit allow simple generator expressions, which lets us enable rules that replace unnecessary list comprehensions with generators in any/all calls. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
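For illustration, the kind of rewrite this rule performs (a generic Python sketch, not code from the PR):

```python
values = [0.5, -1.0, 2.0]

# Flagged by C419: the list comprehension materializes every element first
any([v > 0 for v in values])
all([v > 0 for v in values])

# Preferred: generator expressions let any()/all() short-circuit
any(v > 0 for v in values)
all(v > 0 for v in values)
```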

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Li-Huai (Allan) Lin
c0674c439c [vmap] Add max_pool3d batch rule (#99522)
Also add a helper to integrate `max_pool2d_with_indices` and `max_pool3d_with_indices`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99522
Approved by: https://github.com/zou3519
2023-04-20 05:08:19 +00:00
Li-Huai (Allan) Lin
d31a00e713 [vmap] Add max_pool1d batch_rule (#99517)
Fixes #97558

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99517
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2023-04-20 05:08:17 +00:00
Li-Huai (Allan) Lin
e549ad0046 Add log_sigmoid_backward forward-AD (#99288)
Fixes #95057
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99288
Approved by: https://github.com/kshitij12345, https://github.com/albanD
2023-04-17 15:45:20 +00:00
Li-Huai (Allan) Lin
6f181aae7c [vmap] Register decomposition for huber_loss_backward (#99236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99236
Approved by: https://github.com/kshitij12345
2023-04-16 18:50:45 +00:00
kshitij12345
2c337dd934 [fix] update the condition for aliveness of TensorWrapper (#98748)
Fixes https://github.com/pytorch/pytorch/issues/95561
Fixes https://github.com/pytorch/pytorch/issues/98021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98748
Approved by: https://github.com/zou3519
2023-04-13 08:17:20 +00:00
kshitij12345
ffd76d11c9 [fix] take : backward batching rule (#95772)
Fixes https://github.com/pytorch/pytorch/issues/95738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95772
Approved by: https://github.com/zou3519
2023-03-30 17:18:17 +00:00
kshitij12345
19dcf55a6f [functorch] .data should not work for grad, jvp, vjp (#94817)
Improve error message

Fixes https://github.com/pytorch/pytorch/issues/94514

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94817
Approved by: https://github.com/zou3519
2023-03-30 16:46:57 +00:00
Thomas Li
e1f153f3b1 Add support for copysign operator in functorch (#96018)
Fixes #91176
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96018
Approved by: https://github.com/zou3519
2023-03-27 14:20:57 +00:00
Edward Z. Yang
fa4c77e39b Rename PyOperator to HigherOrderOperator (#97493)
Twice this week I have had people confuse "operator defined with Python
operator registration aka torch.library" and "PyOperator which is used
to define control flow operators and other operators that cannot be
represented in JIT schema."  Renaming PyOperator for clarity.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97493
Approved by: https://github.com/SherlockNoMad
2023-03-24 05:04:02 +00:00
Huy Do
c5b65032ac Restore ROCm trunk jobs (#97354)
Move it back from unstable as the job looks stable now. The one remaining flaky test I have seen is `functorch/test_ops.py::TestOperatorsCUDA::test_vmapjvpvjp_svd_cuda_float32` b04363ead4, so I'll just skip that one on ROCm.

I will monitor the job a bit longer, and have this PR at the ready.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97354
Approved by: https://github.com/zou3519, https://github.com/ZainRizvi
2023-03-23 02:56:44 +00:00
Huy Do
244736a5a5 Mark ROCm tests as flaky (#97259)
Before https://github.com/pytorch/pytorch/pull/96464, ROCm tests in trunk are already quite flaky https://hud.pytorch.org/reliability/pytorch/pytorch?jobName=trunk%20%2F%20linux-focal-rocm5.4.2-py3.8%20%2F%20test%20(default).

After https://github.com/pytorch/pytorch/pull/96464, there is a new group of flaky failures coming from functorch.  So let's mark the test as flaky to monitor without impacting trunk.

Two flaky tests currently seeing in trunk are:

* https://github.com/pytorch/pytorch/issues/97256
* `functorch/test_memory_efficient_fusion.py` OOM

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97259
Approved by: https://github.com/malfet, https://github.com/zou3519
2023-03-21 16:55:00 +00:00
Richard Zou
5acf403088 Run functorch tests in default shards; delete functorch-specific shards (#96464)
Fixes #96347

This PR:

- Makes the functorch tests run as a part of the "default" shards
- Deletes the functorch CI shard from all CI job configurations (if it exists)
- Increases the "default" shard count by 1 for each job, unless it was
previously set to 1, to accommodate the new functorch tests and not
regress time-to-signal.
- Adds a bunch of skips for ROCM and torchdynamo configurations. We can
investigate them later.

NB: I might go through some more iterations to figure out what other
skips need to be added, but this iteration of the PR seems to pass most
of the CI suite.

Test Plan:
- wait for CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96464
Approved by: https://github.com/huydhn
2023-03-21 13:53:01 +00:00
Thomas Li
159145a19e Add support for torch.complex in functorch (#96032)
Fixes #91175

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96032
Approved by: https://github.com/Skylion007, https://github.com/kshitij12345, https://github.com/zou3519
2023-03-14 20:47:53 +00:00
Li-Huai (Allan) Lin
3326c14e86 Add a sample for index_fill to test framework (#91534)
Currently the index_fill test doesn't include a sample with tensor `value` input.

This PR adds one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91534
Approved by: https://github.com/ngimel
2023-03-07 08:36:04 +00:00
kshitij12345
3b966a6ce3 [autograd] disable backward/grad for complex scalar output (#92753)
Fixes https://github.com/pytorch/pytorch/issues/92750

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92753
Approved by: https://github.com/ezyang
2023-02-23 11:38:27 +00:00
Yanan Cao (PyTorch)
039b4c8809 Add meta function for _upsample_bilinear2d_aa (#94982)
Differential Revision: D43353000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94982
Approved by: https://github.com/ezyang
2023-02-19 07:11:20 +00:00
mingfeima
c620ece726 port sparse_mm.reduce to pytorch and optimize it on CPU (#83727)
### Motivation of this PR

This patch is to migrate `spmm_reduce` from `torch-sparse` (a 3rd party dependency for PyG) to `torch`, which is a response to the initial proposal for fusion of **Gather, Apply Scatter** in Message Passing of GNN inference/training. https://github.com/pytorch/pytorch/issues/71300

**GAS** is the major step of Message Passing. Its behavior can be classified into 2 kinds depending on the storage type of `EdgeIndex`, which records the connections of nodes:

* COO: the hotspot is `scatter_reduce`
* CSR: the hotspot is `spmm_reduce`

The reduce type can be chosen from: "sum", "mean", "max", "min".

Extend `torch.sparse.mm` with a `reduce` argument; it maps to `torch.sparse_mm.reduce` internally.
`sparse_mm_reduce` is registered under the TensorTypeId of `SparseCsrCPU`, and this operator requires an internal interface `_sparse_mm_reduce_impl` which has dual outputs:
* `out` - the actual output
* `arg_out` - records the output indices of the non-zero elements when the reduce type is "max" or "min"; this is only useful for training, so it is not computed for inference.
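A minimal usage sketch of the extended API (assuming the `reduce` keyword of `torch.sparse.mm` as described above; CSR and CPU only, per this patch):

```python
import torch

# Sparse CSR matrix times dense matrix, reducing over the non-zero
# entries of each row with one of "sum", "mean", "max", "min".
a = torch.randn(4, 4).relu().to_sparse_csr()
b = torch.randn(4, 3)
out = torch.sparse.mm(a, b, reduce="mean")
```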

### Performance

Benchmarked on GCN for ogbn-products on a single Xeon socket, the workload is improved by `4.3x` with this patch.

The performance benefit for training will be bigger: the original backward impl for `sum|mean` is sequential, and the original backward impl for `max|min` is not fused.

#### before:
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       torch_sparse::spmm_sum        97.09%       56.086s        97.09%       56.088s        6.232s             9
                 aten::linear         0.00%      85.000us         1.38%     795.485ms      88.387ms             9
                 aten::matmul         0.00%      57.000us         1.38%     795.260ms      88.362ms             9
                     aten::mm         1.38%     795.201ms         1.38%     795.203ms      88.356ms             9
                   aten::relu         0.00%      50.000us         0.76%     440.434ms      73.406ms             6
              aten::clamp_min         0.76%     440.384ms         0.76%     440.384ms      73.397ms             6
                   aten::add_         0.57%     327.801ms         0.57%     327.801ms      36.422ms             9
            aten::log_softmax         0.00%      23.000us         0.10%      55.503ms      18.501ms             3
```

#### after
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
               aten::spmm_sum        87.35%       11.826s        87.36%       11.827s        1.314s             9
                 aten::linear         0.00%      92.000us         5.87%     794.451ms      88.272ms             9
                 aten::matmul         0.00%      62.000us         5.87%     794.208ms      88.245ms             9
                     aten::mm         5.87%     794.143ms         5.87%     794.146ms      88.238ms             9
                   aten::relu         0.00%      53.000us         3.35%     452.977ms      75.496ms             6
              aten::clamp_min         3.35%     452.924ms         3.35%     452.924ms      75.487ms             6
                   aten::add_         2.58%     348.663ms         2.58%     348.663ms      38.740ms             9
                 aten::argmax         0.42%      57.473ms         0.42%      57.475ms      14.369ms             4
            aten::log_softmax         0.00%      22.000us         0.39%      52.605ms      17.535ms             3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83727
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch, https://github.com/rusty1s, https://github.com/pearu
2023-02-10 15:56:40 +00:00
albanD
496c0a207b Make segment_reduce properly private. (#93166)
I am attempting not to change the aten function to reduce the amount of BC issues on the torchscript side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93166
Approved by: https://github.com/ngimel
2023-02-06 18:32:23 +00:00
Ivan Yashchuk
fba13d94a1 Remove deprecated torch.symeig (#70988)
The time has come to remove deprecated linear algebra related functions. This PR removes `torch.symeig`.
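For reference, the documented replacement (a migration sketch, not part of this PR):

```python
import torch

A = torch.randn(4, 4)
A = A + A.mT  # symeig required a symmetric input
# torch.symeig(A, eigenvectors=True) is gone; use torch.linalg.eigh instead:
eigenvalues, eigenvectors = torch.linalg.eigh(A)
```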

- [x] XLA PR: https://github.com/pytorch/xla/pull/4498

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70988
Approved by: https://github.com/lezcano, https://github.com/kit1980, https://github.com/malfet
2023-01-31 11:59:11 +00:00
Li-Huai (Allan) Lin
5112f44dc4 Add vmap support for torch.index_fill (#91364)
Fixes #91177

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91364
Approved by: https://github.com/zou3519
2023-01-30 08:08:33 +00:00
Sean Ross-Ross
d354499faf adding some more missing ops to vmap (#92110)
removes some xfails that were a part of https://github.com/pytorch/functorch/issues/1009 and https://github.com/pytorch/functorch/issues/1087

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92110
Approved by: https://github.com/zou3519
2023-01-25 19:43:12 +00:00
Driss Guessous
a3715efd8b Remove windows check for cmake to build Fused kernels (#91909)
# Summary
Add support for fused attention kernels (FlashAttention and memory-efficient attention) on Windows. Previously we could not do this because the fixes required C++17, but we have since updated the PyTorch standard.

This PR:
- Changes invocations of unsigned long to the fixed width integer type
- Adds in the #define FP16_SWITCH(COND, ...) which has been added to the flash_attention main branch
- Changes some of the macros used within the mem-efficient attention code in order to work around the VA_ARGS discrepancy between clang/gcc and MSVC. An alternative would be setting the global flag /Zc:preprocessor
- Selectively applies /Zc:lambda to only the mem-efficient sources since applying this globally caused quantization files to not compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91909
Approved by: https://github.com/cpuhrsch
2023-01-25 01:21:12 +00:00
PyTorch MergeBot
acdd462b1a Revert "Remove deprecated torch.symeig (#70988)"
This reverts commit d70ed68162.

Reverted https://github.com/pytorch/pytorch/pull/70988 on behalf of https://github.com/kit1980 due to Failing XLA tests, forward fix unsuccessful
2023-01-24 19:03:40 +00:00
Ivan Yashchuk
d70ed68162 Remove deprecated torch.symeig (#70988)
The time has come to remove deprecated linear algebra related functions. This PR removes `torch.symeig`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70988
Approved by: https://github.com/lezcano, https://github.com/kit1980
2023-01-23 22:51:40 +00:00
Driss Guessous
df14650f0b [SDPA] Update SDPA API and make function Public (#92189)
# Summary
In preparation for the PyTorch 2.0 launch, this PR updates SDPA's API and makes the function a public `nn.functional` function.

## Changes
### API
Previously the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`

This PR removes the `need_attn_weights` optional boolean argument and updates the return type to a single tensor.
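A minimal usage sketch of the updated public API (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Returns a single Tensor; attention weights are never materialized.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
```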

#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels, e.g. FlashAttention. The fused kernels do not currently support arbitrary attn_mask or dropout, but there is a PR against mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.

The fused kernels save on memory usage by not materializing the weights, and it is unlikely that a fast fused implementation will enable this feature, so we are removing it.

Discussed with folks at FAIR/Xformers, who +1'd this API change.

#### Make function Public
In preparation for the PyTorch 2.0 launch, we make the function public to start generating user feedback.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
2023-01-23 20:50:46 +00:00
Henry Cheng
b6cfd62285 vmap support for torch.linalg.vander (#91749)
Adds vmap support for torch.linalg.vander in a similar manner to how view_as_complex is implemented.

#91700

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91749
Approved by: https://github.com/lezcano
2023-01-19 14:49:54 +00:00
Peter Bell
4058dedf21 Replace log(1 + x) with log1p(x) (#92114)
`log1p` offers better precision near zero since `(1 + x) - 1` truncates any
values less than the float epsilon to zero. For `soft_margin_loss` this also
requires one fewer kernel invocation which for numel=1e7 gives me a 1.2x speedup
on CUDA and a 1.1x speedup on CPU.
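A quick illustration of the precision difference in float32:

```python
import torch

x = torch.tensor([1e-10], dtype=torch.float32)
print(torch.log(1 + x))  # tensor([0.]) -- 1 + 1e-10 rounds to exactly 1.0 in float32
print(torch.log1p(x))    # tensor([1.0000e-10]) -- retains precision near zero
```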

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92114
Approved by: https://github.com/ngimel, https://github.com/lezcano
2023-01-18 10:43:56 +00:00
Richard Zou
81cc9bba5e [autograd.Function] Kill the extension feature flag (#92026)
This PR removes the autograd.Function extension feature flag. This was
previously used for development of the functorch <> autograd.Function
interaction.

It's been in master for long enough with the feature flag defaulting to
True, so it's time to remove it.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92026
Approved by: https://github.com/soulitzer
2023-01-17 13:36:42 +00:00
lezcano
484dd40022 Implement PReLU in a compositional way (#91238)
The PReLU implementation was all over the place. This led to a number
of bugs like https://github.com/pytorch/pytorch/issues/68760. We fix it by:
- Keeping the weird broadcasting logic it has (illustrated in the sketch after this list) as a CompositeImplicit kernel that calls into a second kernel
- This second kernel is just a good-ol' pointwise kernel.
- We implement the derivative for the pointwise kernel via TI as well for speed.
- We implement the second derivative for the pointwise kernel and the forward AD derivatives compositionally
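A small sketch of that broadcasting behavior and the pointwise core it lowers to (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)  # (N, C, H, W)
w = torch.rand(3)            # one slope per channel, broadcast over dim 1

out = F.prelu(x, w)
# The pointwise kernel is equivalent to:
ref = torch.where(x >= 0, x, w.view(1, 3, 1, 1) * x)
assert torch.allclose(out, ref)
```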

This fixes a number of issues:
- We don't perform copies any more when the inputs are not contiguous
- The derivatives are now correct
- We fix vmap and many other functorch-related issues.
- CPU and CUDA now share the relevant broadcasting logic
- The implementation is about 1/3 the length.

Fixes https://github.com/pytorch/pytorch/issues/68760
Fixes https://github.com/pytorch/pytorch/issues/89895

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91238
Approved by: https://github.com/kshitij12345, https://github.com/jbschlosser, https://github.com/albanD
2022-12-30 10:42:30 +00:00
lezcano
5b223c43ec Avoid calling allclose in the backward if there are tensor subclasses (#91444)
`allclose` is data-dependent (it returns a bool), so it does not play well
with functorch. We are skipping that check in the context of subclasses
to avoid hard errors.

Partially fixes https://github.com/pytorch/pytorch/issues/90499

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91444
Approved by: https://github.com/albanD
2022-12-28 19:12:50 +00:00
lezcano
4444138fae Add backward for complex numbers for diagonal_scatter (#91443)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91443
Approved by: https://github.com/soulitzer
2022-12-28 19:12:50 +00:00
Khushi Agrawal
f969834f68 [functorch] vmap: nansum & nanmean (#91372)
Fixes #91174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91372
Approved by: https://github.com/zou3519
2022-12-28 18:49:49 +00:00
soulitzer
1b2ee4d0e1 Update functorch supported autograd.Function to allow mark_dirty (#91222)
Fixes https://github.com/pytorch/pytorch/issues/90225
Uses what was originally in 32a57bcdb6

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91222
Approved by: https://github.com/zou3519
2022-12-28 03:53:47 +00:00
Richard Zou
e8393131ee [generate_vmap_rule] support for jvp (#91211)
Support for jvp is very similar to support for backward():
- We need to vmap over a version of the original autograd.Function's jvp
method that does not take ctx as input.
- On the output, we need to reductify to ensure the output tangent has
the same shape as the output. This reductify does not have the
extra reduction semantics, because PyTorch forward-mode AD requires the
output tangent to have the same exact shape as the output.
- setup_context needs to tell us the bdims of the saved_tensors
(necessary for vmap over jvp_no_context), as well
as the output shapes (necessary for reductify).

Test Plan:
- Added jvp support to the *GenVmapAutogradFunction
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91211
Approved by: https://github.com/soulitzer
2022-12-27 23:25:59 +00:00
Richard Zou
48e63bf69f [functorch] composition of three transform tests with jvp (#91206)
This PR adds the following tests. They will be useful as test cases for
generate_vmap_rule=True and jvp (to come soon)
- test_jvpvmap
- test_jvpvmapvmap
- test_vmapjvpvmap
- test_jvpjvpvmap
- test_jvpvjpvmap
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91206
Approved by: https://github.com/soulitzer
2022-12-27 23:25:59 +00:00
Brian Hirsh
c47bdd7522 *_scatter ops should preserve input stride/storage_offset (#91029)
It turns out that we *do* need to update *_scatter ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride debugging check. It only actually manifests as a silent correctness issue if you try to .backward() through such a function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91029
Approved by: https://github.com/ezyang
2022-12-22 19:41:53 +00:00
soulitzer
b66862ba87 [autograd Function] Don't materialize forward grad for non-differentiable types (#91183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91183
Approved by: https://github.com/zou3519
2022-12-21 05:05:44 +00:00
Peter Bell
e670c261c5 Decompose fill, zero, and zeros_like (#90968)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90968
Approved by: https://github.com/ngimel
2022-12-21 00:59:50 +00:00
Richard Zou
2f37804cae [generate_vmap_rule] Add generate_vmap_rule to autograd.Function (#90966)
Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation
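As a hedged sketch of what opting in looks like (the class and signatures here follow the current public autograd.Function API and are illustrative, not code from this PR):

```python
import torch
from torch.func import vmap  # functorch.vmap at the time of this commit

class MySquare(torch.autograd.Function):
    generate_vmap_rule = True  # opt in: the methods below only use PyTorch operations

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(4, 3)
out = vmap(MySquare.apply)(x)  # no vmap staticmethod needed; the rule is generated
```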

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90966
Approved by: https://github.com/soulitzer
2022-12-21 00:34:44 +00:00
Richard Zou
ed589dd8e4 [functorch] add composition-of-3-transform tests for autograd_function (#90962)
This PR adds the following OpInfo tests:
- vmap x vjp x vmap
- vjp x vmap x vmap
- vjp x vjp x vmap

These OpInfo tests only run for the autograd_function_db. In general,
testing composition of two transforms is sufficient to convince
ourselves that functorch works on a given operator.

The autograd.Function testing (especially the upcoming
generate_vmap_rule) didn't feel rigorous enough to me, so I added these
additional tests to convince myself.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90962
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-17 00:43:44 +00:00
Kshiteej K
e4de6ed6bb functorch: non-contig samples for test_grad (#90990)
Ref: https://github.com/pytorch/functorch/issues/1029

Before PR: (Time: ~30s)
```
================================================= 1052 passed, 264 skipped, 17373 deselected, 9 xfailed in 29.09s =================================================
```

After PR: (Time: ~43s)
```
================================================ 1042 passed, 264 skipped, 17373 deselected, 19 xfailed in 43.13s =================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90990
Approved by: https://github.com/zou3519
2022-12-16 21:27:44 +00:00
Kshiteej K
cdf4a80cc1 replace skipIf with xfailif (#90368)
Replace skips with xfails.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90368
Approved by: https://github.com/zou3519
2022-12-14 20:35:58 +00:00
Richard Zou
4809e838c1 functorch.jvp support for autograd.Function (#90077)
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.

For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin

The jvp rule for custom_function_call does just that. It constructs a
new autograd.Function (because the above logic already exists). Inside
the forward, it re-dispatches to custom_function_call. In the jvp rule,
it just calls whatever the jvp rule is supposed to be.

Since this logic is really close to the custom_function_call_grad, I
just put them together.

Test Plan:
- added jvp rules to the autograd.Function in autograd_function_db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077
Approved by: https://github.com/albanD, https://github.com/soulitzer
2022-12-14 16:20:53 +00:00
Richard Zou
3049d99027 autograd.Function supports vmap staticmethod (#90037)
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.

```py
class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # Move the batch dim to the end (or add a trailing singleton dim if the
        # argument is not being vmapped over) so the two inputs broadcast.
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors,
they've already been unpacked.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.

Semantics
- vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is
one and will never actually run NumpyMul.forward.
- In order for the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call
into operations that vmap understands (i.e. PyTorch operators or more
autograd.Function).
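A brief usage sketch of the example above (hypothetical: it assumes the NumpyMul class from the snippet has been defined, with `to_numpy` standing in for a small test helper):

```python
import torch
from torch.func import vmap  # functorch.vmap at the time of this commit

to_numpy = lambda t: t.detach().cpu().numpy()  # assumed helper used by NumpyMul

x = torch.randn(8, 3)
y = torch.randn(8, 3)
# Dispatches to the vmap staticmethod rather than running forward on BatchedTensors.
out = vmap(NumpyMul.apply)(x, y)
assert out.shape == (8, 3)
```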

At a high level, this PR:
- adds a vmap rule for custom_function_call

Testing
- Added some tests for in_dims and info
- Added vmap staticmethod to most of the autograd.Function in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that will add
~200LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
soulitzer
98a9235dce Fix prelu ref when a.ndim < 2 (#89809)
Fixes https://github.com/pytorch/pytorch/issues/89560

Previously the test case for "input is 1-D or scalar + weight is not scalar" did not exist; adding it introduced some failures:
- forward AD (fixed in this PR)
- vmap (filed https://github.com/pytorch/pytorch/issues/89895)
- ref/meta (fixed in this PR, though this also regresses nvFuser support)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89809
Approved by: https://github.com/ngimel
2022-12-12 23:55:31 +00:00
Richard Zou
7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to either call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.

Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00
Peter Bell
5caa27a3fd as_strided: Fix default storage_offset for reference implementation (#89513)
This fixes the default storage_offset to take it from the input. This was
previously untested, so I've also added a new OpInfo which includes samples with
non-zero storage_offsets on the input tensor.
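A small sketch of the eager semantics that the reference now matches (the default storage_offset comes from the input, not 0):

```python
import torch

base = torch.arange(10.)
t = base[2:]                   # t.storage_offset() == 2
v = t.as_strided((4,), (1,))   # no storage_offset given
print(v.storage_offset())      # 2 -- defaults to the input's offset, not 0
print(v)                       # tensor([2., 3., 4., 5.])
```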
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89513
Approved by: https://github.com/ezyang, https://github.com/ngimel
2022-12-06 22:39:21 +00:00
PyTorch MergeBot
e645771e95 Revert "as_strided: Fix default storage_offset for reference implementation (#89513)"
This reverts commit ba70a8be03.

Reverted https://github.com/pytorch/pytorch/pull/89513 on behalf of https://github.com/kit1980 due to Broke multiple workflows, 2 unexpected successes for autograd tests
2022-12-06 07:14:16 +00:00
Sean Ross-Ross
2b7fcfa399 fix: Moving operators to FuncTorchBatchedDecomposition (#89762)
Some of the easy to move operators I've moved over and removed an xfail.

I found this from the test that I implemented in https://github.com/pytorch/pytorch/pull/89465

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89762
Approved by: https://github.com/zou3519
2022-12-06 05:59:47 +00:00
Peter Bell
ba70a8be03 as_strided: Fix default storage_offset for reference implementation (#89513)
This fixes the default storage_offset to take it from the input. This was
previously untested, so I've also added a new OpInfo which includes samples with
non-zero storage_offsets on the input tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89513
Approved by: https://github.com/ezyang, https://github.com/ngimel
2022-12-06 04:07:16 +00:00
PyTorch MergeBot
8845a8f899 Revert "as_strided: Fix default storage_offset for reference implementation (#89513)"
This reverts commit eded97ac72.

Reverted https://github.com/pytorch/pytorch/pull/89513 on behalf of https://github.com/peterbell10 due to broke master
2022-12-05 17:53:23 +00:00
Peter Bell
eded97ac72 as_strided: Fix default storage_offset for reference implementation (#89513)
This fixes the default storage_offset to take it from the input. This was
previously untested, so I've also added a new OpInfo which includes samples with
non-zero storage_offsets on the input tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89513
Approved by: https://github.com/ezyang, https://github.com/ngimel
2022-12-05 15:52:49 +00:00
Richard Zou
4068c5467d [Reland] Move functorch/_src to torch/_functorch (#88756) (#90091)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
2022-12-03 14:17:15 +00:00
Kshiteej K
dfb533ca5b add vjp test with non-contig inputs (#89375)
Ref: https://github.com/pytorch/functorch/issues/1029

We update `test_vjp` to do contiguous and non-contiguous sample testing.

Prev Time: ~32s
New Time : ~50s
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89375
Approved by: https://github.com/zou3519
2022-12-01 14:43:30 +00:00
PyTorch MergeBot
218d9c6e09 Revert "Move functorch/_src to torch/_functorch (#88756)"
This reverts commit 52bc5c1cfe.

Reverted https://github.com/pytorch/pytorch/pull/88756 on behalf of https://github.com/clee2000 due to broke imports in tests 52bc5c1cfe https://github.com/pytorch/pytorch/actions/runs/3574742513/jobs/6010814968 probably a landrace
2022-11-29 17:17:11 +00:00
Richard Zou
52bc5c1cfe Move functorch/_src to torch/_functorch (#88756)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang
2022-11-29 13:55:42 +00:00
Jane Xu
8695f0cced Rectify native_batch_norm schema by splitting it into two legit schemas (#88697)
Using the same repro from the issue (but with BatchNorm2D)

Rectifies native_batch_norm schema by splitting the schema into 2:
1. one will have NON-optional alias-able running_mean and running_var inputs
2. the other will just not have those parameters at all (no_stats variation)

**Calling for name suggestions!**

## test plan
I've added tests in test_functionalization.py as well as an entry in common_method_invocations.py for `native_batch_norm_legit`
CI should pass.

## next steps
Because of bc/fc reasons, we reroute native_batch_norm to call our new schemas ONLY through the python dispatcher, but in 2 weeks or so, we should make `native_batch_norm_legit` the official batch_norm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88697
Approved by: https://github.com/albanD
2022-11-23 23:23:17 +00:00
Driss Guessous
1d9e1fca97 Update sdp dispatch logic to enable fused backward (#89154)
# Summary
Reorganizes how the SDP dispatch logic is done in order to enable backward for the fused kernels

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89154
Approved by: https://github.com/cpuhrsch
2022-11-21 20:02:09 +00:00
PyTorch MergeBot
e1d58b1928 Revert "Update sdp dispatch logic to enable fused backward (#89154)"
This reverts commit 2e72ec7982.

Reverted https://github.com/pytorch/pytorch/pull/89154 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but the new test_sdp_math_gradcheck test breaks periodic slow gradcheck, i.e. 419ef2cdcf
2022-11-20 22:14:38 +00:00
kshitij12345
7a2930b357 add jvp test with non-contig inputs (#89131)
Ref: https://github.com/pytorch/functorch/issues/1029

We update `test_jvp` to do contiguous and non-contiguous testing in a single test.

Prev time for `test_jvp` : ~28s
New time for `test_jvp`: ~45s

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89131
Approved by: https://github.com/zou3519
2022-11-19 04:09:29 +00:00
Driss Guessous
2e72ec7982 Update sdp dispatch logic to enable fused backward (#89154)
# Summary
Reorganizes how the SDP dispatch logic is done in order to enable backward for the fused kernels

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89154
Approved by: https://github.com/cpuhrsch
2022-11-19 02:06:27 +00:00
anjali411
dc40d3f93f Add meta impl for grid_sampler_2d_backward (#88745)
TODO: add an OpInfo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88745
Approved by: https://github.com/ezyang
2022-11-16 13:01:47 +00:00
Sherlock Huang
5faa2792fa Symintify decomps for split and upsample_bilinear; Fix decomp for _softmax_backward_data and native_dropout_backward (#88761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88761
Approved by: https://github.com/ezyang
2022-11-15 13:34:45 +00:00
Khushi Agrawal
f1a5044de0 [primTorch] _refs & opinfo alpha_dropout (#87989)
Add _refs and OpInfo for `nn.functional.alpha_dropout`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87989
Approved by: https://github.com/mruberry
2022-11-14 18:18:45 +00:00
PyTorch MergeBot
eea506aee1 Revert "Symintify decomps for split and upsample_bilinear; Fix decomp for _softmax_backward_data and native_dropout_backward (#88761)"
This reverts commit 9eabcc370f.

Reverted https://github.com/pytorch/pytorch/pull/88761 on behalf of https://github.com/suo due to much broken 9eabcc370f
2022-11-14 01:58:47 +00:00
Sherlock Huang
9eabcc370f Symintify decomps for split and upsample_bilinear; Fix decomp for _softmax_backward_data and native_dropout_backward (#88761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88761
Approved by: https://github.com/ezyang
2022-11-13 21:30:53 +00:00
Brian Hirsh
a16ced03c9 reland "fix as_strided_scatter_backward (#87646)" (#88342)
This reverts commit 71fb763e54.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88342
Approved by: https://github.com/zou3519
2022-11-07 15:00:58 +00:00
PyTorch MergeBot
71fb763e54 Revert "fix as_strided_scatter_backward (#87646)"
This reverts commit f9d7985851.

Reverted https://github.com/pytorch/pytorch/pull/87646 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but I think this one or one of the PR in the stack break bionic-cuda11.7 on trunk 70782981f0
2022-11-02 16:54:36 +00:00
Brian Hirsh
f9d7985851 fix as_strided_scatter_backward (#87646)
as_strided_scatter's derivative formula was broken - instead of making a "mask" of 1's and 0's, it would effectively make a mask of 1's and uninitialized memory.

Fixes https://github.com/pytorch/pytorch/issues/88105

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87646
Approved by: https://github.com/albanD
2022-11-02 14:36:49 +00:00
kshitij12345
96aac51717 [functorch] dont compute expected output multiple times (#86202)
Fixes https://github.com/pytorch/functorch/issues/1028

Description: We update `get_fallback_and_vmap_exhaustive` to compute expected output only once as described in the issue.

NOTE: This doesn't take care of the repeated computation in `test_vmap_exhaustive` and will be followed up later.

TODO:
* [x] Benchmark and see how much difference does this make. (Comparison Table Below: [Link](https://github.com/pytorch/pytorch/pull/86202#issuecomment-1285477653))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86202
Approved by: https://github.com/zou3519
2022-10-24 22:43:11 +00:00
Richard Zou
b805e1abef [functorch] Fix torch.cat batching rule (#86932)
The bug was discovered in https://github.com/pytorch/pytorch/pull/86842.

torch.cat has an edge case where it ignores all tensors of shape [0]. So
if any of the BatchedTensors have logical shape [0] but physical shape
[B, 0], then we coerce them to shape [0] by slicing them.

Why don't we just ignore those Tensors? We need to propagate
requires_grad-ness somehow (e.g. if the BatchedTensor wraps a Tensor of
shape [B, 0] that requires grad, then the output must require grad).
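A hedged sketch of the edge case (assuming the torch.func/functorch vmap API):

```python
import torch
from torch.func import vmap  # functorch.vmap at the time of this commit

def f(x, empty):
    # torch.cat ignores operands of shape [0], so this is legal even though
    # `empty` does not match x in dimensionality.
    return torch.cat([x, empty])

x = torch.randn(4, 2, 3)   # per-sample shape [2, 3]
empty = torch.randn(4, 0)  # per-sample logical shape [0], physical [B, 0]
out = vmap(f)(x, empty)    # the batching rule must coerce [B, 0] back to [0]
assert out.shape == (4, 2, 3)
```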

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86932
Approved by: https://github.com/Chillee
2022-10-20 18:01:31 +00:00
Peter Bell
6eeeb88172 OpInfo: Sample input cleanup (4/n) (#86324)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86324
Approved by: https://github.com/mruberry
2022-10-19 21:25:45 +00:00
Nikita Vedeneev
f2ec9fbd03 torch.ormqr: backward support (#86800)
Seems good to have, especially when neither `a` nor `tau` requires grads and/or they are pretty small in number.
Fixes https://github.com/pytorch/pytorch/issues/86267

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86800
Approved by: https://github.com/lezcano
2022-10-18 09:07:35 +00:00
Nikita Karetnikov
841995d53b [primTorch] Add refs for data conversion ops (#86561)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86561
Approved by: https://github.com/lezcano, https://github.com/mruberry, https://github.com/zou3519
2022-10-18 08:38:51 +00:00
PyTorch MergeBot
317eeb81c3 Revert "OpInfo: Sample input cleanup (4/n) (#86324)"
This reverts commit 2a6d37d23d.

Reverted https://github.com/pytorch/pytorch/pull/86324 on behalf of https://github.com/peterbell10 due to Caused tolerance issues in periodic test
2022-10-17 18:26:59 +00:00
Peter Bell
2a6d37d23d OpInfo: Sample input cleanup (4/n) (#86324)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86324
Approved by: https://github.com/mruberry
2022-10-16 19:12:44 +00:00
Peter Bell
5d6e831563 OpInfo: Sample input cleanup (3/n) (#86380)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86380
Approved by: https://github.com/mruberry
2022-10-15 22:14:09 +00:00
samdow
6ee94b572a [functorch] Add shard to run functorch tests with asan (#82164)
This adds ASAN testing for functorch. It was running really long (>4 hrs) with the op tests, so we decided those tests are probably redundant and skipped them. This brings the test time down to ~30 min.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82164
Approved by: https://github.com/zou3519, https://github.com/malfet, https://github.com/huydhn
2022-10-13 17:26:56 +00:00
Will Constable
b97ae59e29 Change legacy wrap_dim to work with symint == (#86842)
- previously, sizes == vector<T>({0}) failed to hit SymInt::operator==, causing the loop to bail out too early and make an invalid call to the downstream maybe_wrap_dim helper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86842
Approved by: https://github.com/Chillee, https://github.com/malfet, https://github.com/albanD
2022-10-13 15:10:46 +00:00
Nikita Shulga
1a87c25fe1 Add functorch shard to sm86-periodic workflow (#86820)
After https://github.com/pytorch/pytorch/pull/86799 was landed there shouldn't be a need to increase tolerances

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86820
Approved by: https://github.com/zou3519
2022-10-13 04:25:41 +00:00
Richard Zou
553eaaba7c Disable tf32 in functorch transform tests (#86799)
This PR applies a large hammer and disables TF32 in specific functorch transform tests. TF32 isn't precise enough to test correctness.

We could have applied a smaller hammer by disabling TF32 per-OpInfo, but that doesn't seem to have too much additional benefit (e.g. if a convolution batching rule is correct on fp32 then I would expect it to be correct under TF32 modulo precision issues because the actual sequence of PyTorch operators we invoke has not changed, only the backend did).

Test Plan:
- I tested this locally on a machine with A100 GPUs.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/86799
Approved by: https://github.com/malfet
2022-10-12 19:27:17 +00:00
Nikita Shulga
9eb4f9dd17 Tweak test tolerances to be compatible with A10G (#86538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86538
Approved by: https://github.com/ngimel
2022-10-11 23:31:48 +00:00
Richard Zou
109f4d4453 Move functorch tests from functorch/test/* to test/functorch/* (#86623)
This is the first step described in
https://github.com/pytorch/pytorch/issues/86618 . test/functorch/* is
the final location for these tests.

Test Plan:
- Check that the functorch shards in CI are still running tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86623
Approved by: https://github.com/huydhn
2022-10-11 17:20:45 +00:00