Commit Graph

28 Commits

Author SHA1 Message Date
Guilherme Leobas
0a580da582 Add batch decomposition for torch.linalg.eigh (#110640)
Closes https://github.com/pytorch/pytorch/issues/108481

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110640
Approved by: https://github.com/kshitij12345, https://github.com/zou3519
2023-10-09 21:36:49 +00:00
kshitij12345
b8a3998c23 add batch rule for missing inplace ops (#110692)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110692
Approved by: https://github.com/ezyang
2023-10-06 20:53:28 +00:00
kshitij12345
371d8ba599 vmap: decompose real and imag instead of registering batch rule (#110508)
Clean-up

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110508
Approved by: https://github.com/zou3519
2023-10-06 06:01:12 +00:00
vfdev-5
d9fe1713c3 Enabled batch rule decompositions for upsample*.vec ops (#110333)
Follow-up PR to https://github.com/pytorch/pytorch/pull/110172
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110333
Approved by: https://github.com/zou3519
2023-10-03 06:58:18 +00:00
vfdev-5
c62be12061 Added batch rules for _upsample_bi*2d_aa and _upsample_bi*2d_aa_backward (#110172)
Description:
- Added batch rules for `_upsample_bi*2d_aa` and `_upsample_bi*2d_aa_backward`
- Added few more test cases into `sample_inputs_upsample_aten`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110172
Approved by: https://github.com/kshitij12345, https://github.com/zou3519
2023-09-28 17:42:48 +00:00
SherlockNoMad
d997969b8b [Reland] Add sym_size/stride/numel/storage_offset to native_function.yaml (#103107)
Differential Revision: D46459100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103107
Approved by: https://github.com/angelayi, https://github.com/soulitzer
2023-06-12 19:18:49 +00:00
Nikita Shulga
20cf42de2c Revert "[Reland] Add sym_size/stride/numel/storage_offset to native_function.… (#100749)"
This reverts commit bb454891ed.
2023-05-16 18:17:02 -07:00
Sherlock Huang
bb454891ed [Reland] Add sym_size/stride/numel/storage_offset to native_function.… (#100749)
…yaml (#91… (#91919)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91919 Approved by: https://github.com/ezyang

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92402

Reviewed By: ezyang

Differential Revision: D42565586

Pulled By: SherlockNoMad

fbshipit-source-id: 1c2986e45307e076d239836a1b45441a9fa3c9d9
ghstack-source-id: 969f4928486e04c57aaf98e20e3c3ca946c51613

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100749
Approved by: https://github.com/zhxchen17, https://github.com/albanD
2023-05-12 22:57:42 +00:00
Li-Huai (Allan) Lin
c0674c439c [vmap] Add max_pool3d batch rule (#99522)
Also add a helper to integrate `max_pool2d_with_indices` and `max_pool3d_with_indices`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99522
Approved by: https://github.com/zou3519
2023-04-20 05:08:19 +00:00
Li-Huai (Allan) Lin
d31a00e713 [vamp] Add max_pool1d batch_rule (#99517)
Fixes #97558

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99517
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2023-04-20 05:08:17 +00:00
kshitij12345
5e014bfbbd [vmap] ldl_factor: batch rule (#97518)
Ref https://github.com/pytorch/pytorch/issues/96855

Will look into `ldl_solve` separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97518
Approved by: https://github.com/zou3519
2023-03-25 04:37:32 +00:00
Danni Li
7711d24717 vmap support for linalg.lu_factor (#94328)
Differential Revision: D43093457

Fix #91415

### Expected behaviour

No use warning.

```python
from functorch import vmap
x = torch.randn(4, 3, 2)
z = vmap(torch.linalg.lu_factor)(x)
```
Same behaviour as for-loop:

```python
x = torch.randn(4, 3, 2)
results = []
for xi in x:
  y = torch.linalg.lu_factor(xi)
  results.append(y)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94328
Approved by: https://github.com/zou3519, https://github.com/Skylion007, https://github.com/Chillee
2023-03-23 14:18:57 +00:00
Kshiteej K
24c49dbf14 [functorch] batch rule : few decomposition ops (#96744)
Fixes https://github.com/pytorch/pytorch/issues/96741

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96744
Approved by: https://github.com/zou3519
2023-03-15 18:55:05 +00:00
Richard Zou
13011afb87 Fix vmap registration for t, t_ (#96539)
- t, t_ are not CompositeImplicitAutograd
- They were previously registered in BatchRulesDecompositions.cpp.
- The only thing that should get registered in BatchRulesDecompositions.cpp
are CompositeImplicitAutograd
- This PR moves their registrations out of there and into
BatchRulesViews.cpp.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96539
Approved by: https://github.com/srossross, https://github.com/kshitij12345, https://github.com/Chillee
2023-03-13 16:08:32 +00:00
Sean Ross-Ross
6650aac8ce move more operators to BatchRulesDecompositions (#93164)
Moving operators over to `BatchRulesDecompositions.cpp` to remove xfails. I noticed that composite-compliant does not mean inductor or vmap compliant, so I added more `isTensorSubclassLike` checks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93164
Approved by: https://github.com/lezcano, https://github.com/kshitij12345
2023-02-03 16:36:05 +00:00
PyTorch MergeBot
f7bd5d0ccb Revert "[Reland] Add sym_size/stride/numel/storage_offset to native_function.yaml (#91… (#92402)"
This reverts commit 965f4ea3ba.

Reverted https://github.com/pytorch/pytorch/pull/92402 on behalf of https://github.com/zhxchen17 due to Caused a regression for an export model.
2023-02-03 03:12:43 +00:00
Sherlock Huang
965f4ea3ba [Reland] Add sym_size/stride/numel/storage_offset to native_function.yaml (#91… (#92402)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91919
Approved by: https://github.com/ezyang

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92402
Approved by: https://github.com/ezyang
2023-02-01 04:47:49 +00:00
Khushi Agrawal
4c074ddfd2 [functorch][reland] vmap: bitwise operators (#92836)
Previous PR: #91971

Fixes: https://github.com/pytorch/functorch/issues/1069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92836
Approved by: https://github.com/Chillee
2023-01-26 06:12:47 +00:00
Sean Ross-Ross
d354499faf adding some more missing ops to vmap (#92110)
removes some xfails that were a part of https://github.com/pytorch/functorch/issues/1009 and https://github.com/pytorch/functorch/issues/1087

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92110
Approved by: https://github.com/zou3519
2023-01-25 19:43:12 +00:00
PyTorch MergeBot
7ddcf4e0c3 Revert "[functorch] vmap: bitwise operators (#91971)"
This reverts commit e54f7b3edd.

Reverted https://github.com/pytorch/pytorch/pull/91971 on behalf of https://github.com/malfet due to Broke functorch bitwise, see e54f7b3edd
2023-01-23 14:52:16 +00:00
Khushi Agrawal
e54f7b3edd [functorch] vmap: bitwise operators (#91971)
Fixes https://github.com/pytorch/functorch/issues/1069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91971
Approved by: https://github.com/kshitij12345, https://github.com/Chillee
2023-01-23 09:03:13 +00:00
Henry Cheng
b6cfd62285 vmap support for torch.linalg.vander (#91749)
Adds vmap support for torch.linalg.vander in a similar manner to how view_as_complex is implemented.

#91700

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91749
Approved by: https://github.com/lezcano
2023-01-19 14:49:54 +00:00
PyTorch MergeBot
befe815466 Revert "Add sym_size/stride/numel/storage_offset to native_function.yaml (#91919)"
This reverts commit 0388400f3f.

Reverted https://github.com/pytorch/pytorch/pull/91919 on behalf of https://github.com/atalman due to Break internal build
2023-01-17 21:03:18 +00:00
Sherlock Huang
0388400f3f Add sym_size/stride/numel/storage_offset to native_function.yaml (#91919)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91919
Approved by: https://github.com/ezyang
2023-01-17 03:39:57 +00:00
Sean Ross-Ross
0100293a7b feat: adding greater_equal Scalar variant (#91324)
Fixes https://github.com/pytorch/functorch/issues/1080

```py
import torch
from functorch import vmap

def f(x):
    return torch.greater_equal(torch.cumsum(x, dim=0), .5 * 10)

x = torch.randn([10,10])
vmap(f)(x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91324
Approved by: https://github.com/zou3519
2023-01-05 20:25:38 +00:00
Joel Schlosser
1effabe257 Support per-parameter test decoration (#91658)
Continuation of #79979.

Fixes #79161

This PR does the following:
* Expands the `parametrize_fn()` signature from returning a 3-tuple of `(test, test_name, param_kwargs)` to returning a 4-tuple of `(test, test_name, param_kwargs, decorator_fn)`. Expected signature for the addition is `decorator_fn(param_kwargs) -> List[decorator]` i.e. given the full set of test params, return a list of decorators to apply.
    * `modules`, `ops`, and `parametrize` now fit the new signature, returning `decorator_fn`s instead of applying decorators themselves.
    * `instantiate_parametrized_tests()` and `instantiate_device_type_tests()` now call the returned `decorator_fn`, passing in the full set of `param_kwargs` (after composition + `device` / `dtype` additions) and applying the returned decorators.
    * Composing multiple `parametrize_fn`s also composes the corresponding `decorator_fn`s; the composed `decorator_fn` simply concatenates the decorator lists returned by the constituents.
* Expands `DecorateInfo.is_active` to support callables:
```python
DecorateInfo(
    unittest.expectedFailure, "TestOps", "test_python_ref_executor",
    device_type='cuda', active_if=lambda params: params['executor'] == 'nvfuser'
),
```
* Adds several tests to `test/test_testing.py` ensuring proper decoration using `@parametrize`, `@modules`, and `@ops`.
* (minor) Fixes a couple `ModuleInfo` naming oddities uncovered during testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91658
Approved by: https://github.com/malfet
2023-01-04 21:08:32 +00:00
Sean Ross-Ross
cb3204823e adding test to audit CompositeImplicitAutograd ops that do not have a batching rule (#91367)
Fixes https://github.com/pytorch/functorch/issues/1087

It looks like there are `306` rules that should be looked into
```
test/functorch/test_vmap_registrations.py .x.....xxxxxxx.x.x.x.x.x.x.x.x........xx.x.x..x.x.xxx...xxxx.x.x.x........x.........xxxxx..x..x.....xx...xx.....xxx.xxxxxxxxxxxxxxxxx.. [ 24%]
.........x.x......x.xxxxxx..x..xx.x.xxx.x.......x.xxx.xx..xxx.xxx...xxxxx.x....xxxxxxxxxxxxxxx....xx.xxx.xx.x...xx...xx...xxxxxx...xxxxx..x...xxxxxxxxxxxx..xx..xx.xx.x..xxxx..xx [ 56%]
.xx..x.x....xxxxxx.x.xx...xxxxx.xx...x..x.x.xx...xx.xxxxxx.xxxxxx..x........xxxxxxxx..xxxxxxxx..xx.xxxxxxxxxxxxxxxxxxxxxxx..........xxxx.xxxx.........xxxxxxxx..xxx..xxx.x.x.x.xx [ 88%]
xx.xxx.x......xxx.x.xxxxxxxx....x......xxxxxxxxx.xx.x.x.x.......xx                                                                                                                [100%]

=================================================================== 249 passed, 1185 deselected, 306 xfailed in 3.17s ===================================================================

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91367
Approved by: https://github.com/zou3519
2023-01-03 04:21:39 +00:00
Sean Ross-Ross
dcce5677fd Adding test when registering a batching rule for a CompositeImplicitAutograd operation (#89465)
This is a Follow on from https://github.com/pytorch/pytorch/pull/88771 which should close out https://github.com/pytorch/functorch/issues/1009 I've got another PR where I'm moving some operators over https://github.com/pytorch/pytorch/pull/89762

you can see that the new test file is being picked [run here](https://github.com/pytorch/pytorch/actions/runs/3617298059/jobs/6096218583#step:10:472)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89465
Approved by: https://github.com/zou3519
2022-12-12 16:21:07 +00:00