Commit Graph

326 Commits

Author SHA1 Message Date
Masaki Kozuki
7a3503dfd8 Add _foreach_sign (#106343)
Rel:
- #106221

Should we add foreach of [`torch.sgn`](https://pytorch.org/docs/stable/generated/torch.sgn.html) as well?
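
For reference, a minimal usage sketch of the new op once this lands (assuming the usual foreach convention of a tensor list in, a tensor list out):

```
import torch

# Hedged sketch: _foreach_sign applies torch.sign to each tensor in a list and returns a
# new list; the in-place variant would be torch._foreach_sign_.
tensors = [torch.tensor([-2.0, 0.0, 3.0]), torch.tensor([1.5, -0.5])]
signs = torch._foreach_sign(tensors)
# signs[0] == tensor([-1., 0., 1.]); signs[1] == tensor([1., -1.])
```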

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106343
Approved by: https://github.com/janeyx99
2023-08-01 22:33:34 +00:00
Nikita Karetnikov
f23d755e1f [pt2] add meta for ormqr (#106278)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106278
Approved by: https://github.com/ezyang
2023-08-01 06:47:48 +00:00
Nikita Karetnikov
0ee3b84021 [pt2] add meta for cholesky_inverse (#106120)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106120
Approved by: https://github.com/ezyang
2023-07-29 17:16:20 +00:00
Nikita Karetnikov
80755884be [pt2] add meta for cholesky (#106115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106115
Approved by: https://github.com/Skylion007, https://github.com/ezyang
2023-07-29 17:16:20 +00:00
Nikita Karetnikov
b812e35a75 [pt2] add meta for argsort.stable, use sort samples in OpInfo (#106025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106025
Approved by: https://github.com/ezyang, https://github.com/zou3519
2023-07-27 03:49:17 +00:00
drisspg
c4b7311fc2 Meff Attn Bias (#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not a multiple of 16. This causes a memory spike of roughly 2x the attn_mask size, which could in theory be big. It appears, though, that doing this + mem_eff is faster than no_pad + math, so it seems to be worth it.
- Using expand to view the attn_mask in 4d. This is a little different from how we enforce q, k, v to be viewed in 4d prior to calling. Also not supporting the (b*n_heads, seq_len_q, seq_len_kv) case.
- Should enable #96099.

### Profiling
I ran a bunch of comparisons between SDPBackend.MATH and SDPBackend.EFFICIENT_ATTENTION. I added an attn_bias of shape (1, 1, seqlen_q, seqlen_k). For these experiments seqlen_q == seqlen_k. These were all run on an A100 80GB GPU.
Configs:
```
    # Run a bunch of experiments
    batch_sizes = [8, 16, 32]
    num_heads = [16, 32]
    max_seq_lens = [15, 64, 128, 512, 555, 1024]
    embed_dims = [32, 64, 128]
    dtypes = [torch.float16, torch.bfloat16, torch.float32]
    pad_percentages = [None]
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    run_backward = True
    attn_mask = True
```

   The function calls `sdpa(input**).sum().backward()`.

   I calculated the geomean speedup of the efficient attention path over the math path for all these configs:
   `Geomean Speedup: 1.977`

An example comparison with batch_size = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:
![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

This was done using the current state of the branch, where we force alignment of the mask when the last dim is not divisible by 16, which shows up in the seq_len = 15 and 555 cases.

The full data can be found here:

[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)
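
For context, a minimal sketch of the benchmarked call pattern; the tensor shapes and the backend-selection context manager are assumptions and may differ across versions:

```
import torch
import torch.nn.functional as F

# Hedged sketch: SDPA with a broadcastable attn_mask, restricted to the memory-efficient
# backend. seqlen_kv = 555 is deliberately not a multiple of 16 to exercise the padding path.
q = torch.randn(8, 32, 555, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.randn(1, 1, 555, 555, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out.sum().backward()
```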

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
2023-07-26 15:51:59 +00:00
Nikita Karetnikov
0c65a2d58f [pt2] add meta for _adaptive_avg_pool3d_backward (#105816)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105816
Approved by: https://github.com/ezyang
2023-07-26 09:30:17 +00:00
Edward Z. Yang
4af9a914ab Improve FakeTensor to work with mixed meta-cpu embedding bag arguments (#105924)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105924
Approved by: https://github.com/mikaylagawarecki, https://github.com/eellison
2023-07-26 01:19:08 +00:00
Nikita Karetnikov
a4cffaae67 [pt2] add metas for _cholesky_solve_helper and cholesky_solve (#105867)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105867
Approved by: https://github.com/ezyang
2023-07-25 20:21:47 +00:00
PyTorch MergeBot
340ec1f460 Revert "Meff Attn Bias (#104310)"
This reverts commit 5453508115.

Reverted https://github.com/pytorch/pytorch/pull/104310 on behalf of https://github.com/DanilBaibak due to PR introduced cuda OOM issue ([comment](https://github.com/pytorch/pytorch/pull/104310#issuecomment-1650171538))
2023-07-25 16:37:32 +00:00
Jane Xu
5fec1f93dc Add meta registration for foreach_maximum_.List (#105864)
Will fix issues when compiling with amsgrad=True for Adam(W); see related failures in https://github.com/pytorch/benchmark/actions/runs/5628705163/job/15252867793

Also did some refactoring where common registrations could be deduplicated.
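
For reference, a hedged sketch of what the meta function for the in-place List overload has to express (the real registration is wired up through the private helpers in torch/_meta_registrations.py, which are omitted here):

```
import torch
from typing import List

# In-place foreach ops return nothing, so the meta function only needs to validate the
# list structure; no output tensors are fabricated.
def meta__foreach_maximum__list(self: List[torch.Tensor], other: List[torch.Tensor]) -> None:
    torch._check(
        len(self) == len(other),
        lambda: f"expected tensor lists of equal length, got {len(self)} and {len(other)}",
    )
```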

Test plan:
python test/inductor/test_compiled_optimizers.py -k test_adam

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105864
Approved by: https://github.com/albanD, https://github.com/mlazos
2023-07-25 00:39:13 +00:00
drisspg
5453508115 Meff Attn Bias (#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not a multiple of 16. This causes a memory spike of roughly 2x the attn_mask size, which could in theory be big. It appears, though, that doing this + mem_eff is faster than no_pad + math, so it seems to be worth it.
- Using expand to view the attn_mask in 4d. This is a little different from how we enforce q, k, v to be viewed in 4d prior to calling. Also not supporting the (b*n_heads, seq_len_q, seq_len_kv) case.
- Should enable #96099.

### Profiling
I ran a bunch of comparisons between SDPBackend.MATH and SDPBackend.EFFICIENT_ATTENTION. I added an attn_bias of shape (1, 1, seqlen_q, seqlen_k). For these experiments seqlen_q == seqlen_k. These were all run on an A100 80GB GPU.
Configs:
```
    # Run a bunch of experiments
    batch_sizes = [8, 16, 32]
    num_heads = [16, 32]
    max_seq_lens = [15, 64, 128, 512, 555, 1024]
    embed_dims = [32, 64, 128]
    dtypes = [torch.float16, torch.bfloat16, torch.float32]
    pad_percentages = [None]
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    run_backward = True
    attn_mask = True
```

   The function calls `sdpa(input**).sum().backward()`.

   I calculated the geomean speedup of the efficient attention path over the math path for all these configs:
   `Geomean Speedup: 1.977`

An example comparison with batch_size = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:
![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

This was done using the current state of the branch, where we force alignment of the mask when the last dim is not divisible by 16, which shows up in the seq_len = 15 and 555 cases.

The full data can be found here:

[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
2023-07-24 22:19:26 +00:00
Nikita Karetnikov
45e4706aff [pt2] add decomps for multilabel_margin_loss_forward ops (#105302)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105302
Approved by: https://github.com/ezyang
2023-07-23 02:16:29 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Those were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated: to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to squash the older libstdc++ from the conda environment in favor of the one from the OS in `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Karetnikov
7e72126487 [pt2] add decomps for multi_margin_loss ops (#104578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104578
Approved by: https://github.com/ezyang, https://github.com/lezcano
2023-07-14 21:16:09 +00:00
Nikita Karetnikov
0a6888243b multi_margin_loss: check weight shape, make contiguous on CPU, add tests (#104852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104852
Approved by: https://github.com/ezyang
2023-07-14 21:16:09 +00:00
Nikita Karetnikov
de67b52a88 Unify multi_margin_loss_shape_check on CPU and CUDA (#104851)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104851
Approved by: https://github.com/ezyang
2023-07-14 21:16:09 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Those were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Brian Hirsh
c6b9c31a2c [inductor] fix incorrect strides in copy() decomp, fix hf_LongFormer + hf_BigBird errors (#100115)
Fixes https://github.com/pytorch/pytorch/issues/100067, https://github.com/pytorch/pytorch/issues/98268 and https://github.com/pytorch/pytorch/issues/93428.

See the comment [here](https://github.com/pytorch/pytorch/issues/100067#issuecomment-1523856970) for details. The bug was that the decomposition that inductor uses for `aten.copy` doesn't respect the strides of the input in all cases. The fixes that I added should work, but will be pretty slow - we allocate a tensor (potentially larger than `self` if `self` is a slice), and perform an `as_strided_scatter` + `as_strided`. Longer term, stride-agnostic IR should let us remove this decomp?  cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @anijain2305 @soumith @desertfire
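
A hedged sketch of the stride-respecting idea described above (illustrative only, not the actual inductor decomposition): allocate a base tensor covering `self`'s storage, scatter `src` into it at `self`'s layout, and view the result with `self`'s size/stride/offset.

```
import torch

def copy_respecting_strides(self: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # Assumes a non-empty tensor; compute how much flat storage the strided view touches.
    needed = self.storage_offset() + sum((s - 1) * st for s, st in zip(self.size(), self.stride())) + 1
    base = torch.empty(needed, dtype=self.dtype, device=self.device)
    # Embed src (broadcast to self's shape) at self's (size, stride, offset) ...
    base = torch.as_strided_scatter(
        base, src.to(self.dtype).expand_as(self),
        self.size(), self.stride(), self.storage_offset(),
    )
    # ... then view the result with exactly self's layout, so strides are preserved.
    return base.as_strided(self.size(), self.stride(), self.storage_offset())
```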

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100115
Approved by: https://github.com/albanD, https://github.com/ngimel
2023-07-13 14:40:57 +00:00
Michael Lazos
b99d605a30 Add meta registration for foreach_mul_ (#105107)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105107
Approved by: https://github.com/Chillee, https://github.com/voznesenskym
2023-07-13 04:45:22 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure how it worked before, but arguments that default to None must be annotated as Optional.

Towards enabling mypy-1.4.1 in lintrunner
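
An illustrative (not from the PR) example of the kind of change this requires:

```
from typing import Optional

# Before (implicit Optional, rejected by newer mypy):
#     def load(path: str = None) -> None: ...

# After (PEP 484 compliant):
def load(path: Optional[str] = None) -> None:
    ...
```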

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>

> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
Michael Lazos
9861c4a3f8 Add lerp decomps + meta registrations (#104866)
as title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104866
Approved by: https://github.com/janeyx99
2023-07-10 22:07:57 +00:00
Jane Xu
e25f5732c8 Add meta registrations and distributed decomps: _foreach_div_.Scalar, sqrt_.default (#104779)
This PR unblocks #104780 by resolving spmd tracing test issues and by adding meta registrations for foreach inplace ops (div_ and sqrt_)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104779
Approved by: https://github.com/fegin, https://github.com/albanD
2023-07-10 17:38:46 +00:00
Nikita Karetnikov
c00dd43e43 [pt2] add metas for multilabel_margin_loss ops (#104388)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104388
Approved by: https://github.com/ezyang
2023-07-05 13:42:22 +00:00
Nikita Karetnikov
a3aa4da154 [pt2] add metas for multi_margin_loss ops (#104236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104236
Approved by: https://github.com/ezyang
2023-07-05 13:40:05 +00:00
Nikita Karetnikov
ad58aba932 [pt2] add metas for adaptive_max_pool ops (#104167)
Fixes #103892.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104167
Approved by: https://github.com/ezyang
2023-07-05 07:02:07 +00:00
Nikita Karetnikov
b1c31b1d26 [pt2] metas and SymInt support for max_pool ops (#103951)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103951
Approved by: https://github.com/Chillee, https://github.com/kulinseth
2023-07-01 01:33:35 +00:00
Nikita Karetnikov
c4a6f86062 [pt2] add metas for max_unpool2d and max_unpool3d (#103821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103821
Approved by: https://github.com/Skylion007, https://github.com/Chillee
2023-07-01 01:33:35 +00:00
Yanbo Liang
77642da3b8 Fix broken meta registration for torch.full (#104451)
Fixes #104117

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104451
Approved by: https://github.com/eellison
2023-06-30 05:14:52 +00:00
Driss Guessous
4a008d268a REDO of dropout support for mem eff #102038 (#103704)
This is a new PR with the changes from #102038 + #103201, plus namespacing changes to fix a bug.

# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so, roughly three changes were made (a usage sketch follows the list):
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755
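
A minimal usage sketch of the resulting behavior (shapes and the backend-selection context manager are assumptions and may differ across versions):

```
import torch
import torch.nn.functional as F

# Hedged sketch: dropout + causal masking routed to the memory-efficient backend, with
# gradients flowing through the kernel.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
out.sum().backward()
```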

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103704
Approved by: https://github.com/cpuhrsch
2023-06-26 23:05:03 +00:00
xuanqi
344bab2669 [RFC]: Functionalize assertions (#103757)
The idea here is to do a graph mutation to:
* Create an initial dependency token at the beginning of the program.
* Replace the non-functional version of each assertion statement with a functional version.
* The functional version of an assertion statement will:
  * Accept a dependency token from the output of the previous functional assertion statement (or the initial dependency token if there isn't one).
  * Generate a dependency token as the output of the assertion statement.
  * Augment the graph output to include the dependency token generated by the last assertion statement.

The goal here is to:
* Form an explicit dependency chain and avoid potential reordering during other compilation passes.
* Make the assertions part of the overall execution graph so they affect the final output (otherwise they could potentially be DCEed). A minimal sketch of the token-chaining idea follows.
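
A framework-agnostic sketch of the token-chaining idea (plain Python, not the actual pass):

```
# Each functional assertion consumes the previous dependency token and produces a new one,
# so the data dependencies keep assertions ordered and prevent them from being DCEed.
def make_dependency_token():
    return object()  # placeholder token

def functional_assert(dep_token, condition, message):
    assert condition, message          # the side-effecting check
    return make_dependency_token()     # token threaded into the next assertion / graph output

token = make_dependency_token()        # initial token at the beginning of the "program"
token = functional_assert(token, True, "a >= 4")
token = functional_assert(token, True, "a <= 7")
# The final token is appended to the graph outputs so the whole chain stays live.
```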

**NOTE:**
* Currently this only covers `constrain_range`; support for other assertions is WIP. Sending out this PR to collect feedback first.
* This focuses only on the implementation itself; integration with the current export flow will come in a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103757
Approved by: https://github.com/avikchaudhuri
2023-06-24 00:23:35 +00:00
Nikita Karetnikov
e9705c52ac [pt2] add metas for _pdist_forward and _pdist_backward (#103817)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103817
Approved by: https://github.com/ezyang
2023-06-22 11:18:05 +00:00
Nikita Karetnikov
e48851033a [pt2] add metas for pad ops (#103815)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103815
Approved by: https://github.com/ezyang
2023-06-22 11:18:05 +00:00
Nikita Karetnikov
c40fa8b614 [inductor] remove fft and svd ops from fake_incorrect_kernels (#103616)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103616
Approved by: https://github.com/eellison
2023-06-22 03:01:43 +00:00
Kurt Mohler
ee83c646bb Replace _prims_common.check with torch._check* (#103240)
This relands most of the changes from #102219, which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and that `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage.

Part of #72948
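
A minimal usage sketch of the replacement API (the helper function name is illustrative):

```
import torch

# torch._check raises on a false condition; the message is a lazily evaluated callable.
def validate_length(length: int) -> int:
    torch._check(length > 0, lambda: f"expected a positive length, got {length}")
    return length

validate_length(3)  # passes silently; validate_length(0) would raise
```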

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
2023-06-21 00:46:17 +00:00
xuanqi
b27c3558a4 [RFC]: Create aten native op for constrain_range (#103346)
At a high level, the current implementation of the constraint functions (`constrain_as_*`) will raise an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```

The reason is that the current constraint logic:
1) Is purely Python, so it won't survive AOT export (the node is gone after AOT export since AOT export only maintains aten-level ops).
2) Relies on a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) tries to append assertion nodes based on the range constraints extracted from the symbol's shape env during another interpretation round.
4) However, because of 1), the range-constraint logic does not run for symbols generated during the AOT export round, so no range-constraint information is available for the assertion round, which causes the issue.
5) As a result, it fails at `torch.empty((a, 4))` (there is no constraint that `a` must be positive).

The fix here is to implement the range-constraint logic as a native aten op (with the CPU implementation as a no-op) so that it can survive AOT export.

**NOTE:**
The [logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and it is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If it directly called `sym_constrain_range`, the C++ version would be called, which is a no-op.
* So in this case it calls `constrain_range_int` instead, to be able to catch issues like the user providing an input whose tensor shape could be out of range during export, e.g. the following for the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,))  # immediately raises an error
```

Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
2023-06-16 14:55:40 +00:00
Driss Guessous
155691a7d9 Implement meta functions for rshift and lshift (#103637)
Fixes #103606

I was using this script to exercise the new code, because I can never remember which test it is.
```
import torch

@torch.compile(fullgraph=True, dynamic=True)
def shift_right(tensor: torch.Tensor) -> torch.Tensor:
    return (tensor >> 2).to(torch.long)

def main():
    sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
    print(shift_right(sample_input))

if __name__ == "__main__":
    main()
```
I then iterated through the error messages.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103637
Approved by: https://github.com/ezyang
2023-06-15 21:49:22 +00:00
Michael Lazos
00546333a5 Register more foreach op lowerings (#102654)
Adds the necessary foreach op lowerings for Adam

Adds two decomps for addcdiv and addcmul (need to verify that type promotion works correctly here)
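
For reference, the mathematical form of those two decompositions, as a hedged sketch that glosses over type promotion (which is exactly the open question above):

```
import torch

def addcdiv_decomp(self, tensor1, tensor2, *, value=1):
    return self + value * (tensor1 / tensor2)

def addcmul_decomp(self, tensor1, tensor2, *, value=1):
    return self + value * (tensor1 * tensor2)

# Quick sanity check against the eager ops (same dtype everywhere, so no promotion involved):
x, t1, t2 = torch.ones(3), torch.full((3,), 2.0), torch.full((3,), 4.0)
assert torch.allclose(addcdiv_decomp(x, t1, t2, value=2), torch.addcdiv(x, t1, t2, value=2))
assert torch.allclose(addcmul_decomp(x, t1, t2, value=2), torch.addcmul(x, t1, t2, value=2))
```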

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102654
Approved by: https://github.com/jansel
2023-06-15 02:52:17 +00:00
PyTorch MergeBot
6ff6b49039 Revert "Register more foreach op lowerings (#102654)"
This reverts commit 05c01b9bfc.

Reverted https://github.com/pytorch/pytorch/pull/102654 on behalf of https://github.com/ZainRizvi due to This is breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/102654#issuecomment-1591639478))
2023-06-14 16:49:30 +00:00
Nikita Karetnikov
4a76fb49f3 [pt2] add metas for avg_pool3d and avg_pool3d_backward (#103392)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103392
Approved by: https://github.com/ezyang
2023-06-13 21:23:46 +00:00
Michael Lazos
05c01b9bfc Register more foreach op lowerings (#102654)
Adds the necessary foreach op lowerings for Adam

Adds two decomps for addcdiv and addcmul (need to verify that type promotion works correctly here)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102654
Approved by: https://github.com/jansel
2023-06-13 17:30:03 +00:00
Yinghai Lu
4c3799447f
Back out "Dropout support for memory efficient attention (#102038)" & "Two small mem_eff bug fixes (#103201)" (#103464)
Summary:
Original commit changeset: 04c4473d8510

Original Phabricator Diff: D46584152 & D46582033

Test Plan: Already explained in summary.

Reviewed By: yinghai

Differential Revision: D46633283

fbshipit-source-id: c23c2945408988f3c4339dfd5cd40ae46261716c

Co-authored-by: Shenxiu Liu <shenxiu@meta.com>
2023-06-12 18:56:48 -07:00
Bearnardd
2abad0c184 Add dtype check baddbmm (#102659)
Fixes the part of #100838 related to disallowing non-matching dtypes for the input/batches of the `baddbmm` operator (a sketch of the check follows the checklist).

* [x] added dtype checks
* [x] added test case
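
A hedged Python sketch of the kind of check added (the real check lives in the operator's C++/meta implementation):

```
import torch

def check_baddbmm_dtypes(self, batch1, batch2):
    torch._check(
        self.dtype == batch1.dtype and batch1.dtype == batch2.dtype,
        lambda: f"baddbmm expects self, batch1 and batch2 to share a dtype, "
                f"got {self.dtype}, {batch1.dtype}, {batch2.dtype}",
    )
```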

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102659
Approved by: https://github.com/ngimel
2023-06-13 00:31:06 +00:00
Nikita Shulga
4cfa06f706 [BE] Deprecate has_XYZ attributes (#103279)
Use [`__getattr__`](https://peps.python.org/pep-0562/) to raise a warning when one tries to access the `has_XYZ` attributes, and recommend the appropriate `torch.backends.XYZ` methods instead.
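
A hedged sketch of the PEP 562 pattern (names are illustrative, not the exact torch code):

```
import warnings
import torch

_DEPRECATED = {"has_cudnn": lambda: torch.backends.cudnn.is_available()}

def __getattr__(name):  # module-level __getattr__, per PEP 562
    if name in _DEPRECATED:
        warnings.warn(f"{name} is deprecated, please use torch.backends instead",
                      DeprecationWarning, stacklevel=2)
        return _DEPRECATED[name]()
    raise AttributeError(f"module has no attribute {name!r}")
```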

Make respective properties in `torch._C` private (by prefixing them with underscore), to exclude from `from torch._C import *`.

Added `warnings.simplefilter` to work around the Python-3.11 torch.compile lineinfo issue.

Fixes https://github.com/pytorch/pytorch/issues/102484

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103279
Approved by: https://github.com/janeyx99, https://github.com/Skylion007
2023-06-10 05:17:17 +00:00
Nikita Karetnikov
2b3d955ffd [pt2] add meta and SymInt support for linalg_matrix_exp (#102945)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102945
Approved by: https://github.com/lezcano
2023-06-09 22:45:16 +00:00
Nikita Karetnikov
3a0f37735c [pt2] bug fix: invert condition in checkFloatingOrComplex (#102944)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102944
Approved by: https://github.com/lezcano
2023-06-09 22:45:16 +00:00
Driss Guessous
606fb882c4 Dropout support for memory efficient attention (#102038)
# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made:
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102038
Approved by: https://github.com/cpuhrsch
2023-06-08 21:50:12 +00:00
Yanbo Liang
686d7e4c48 [Inductor] Fix x.view(dtype) decomp and make inductor support it (#102920)
Fixes #99804

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102920
Approved by: https://github.com/jansel, https://github.com/ngimel
2023-06-07 17:10:54 +00:00
Ivan Zaitsev
821493715c Back out "Remove check from _prims_common, replace with torch._check* (#102219)", Back out "Forwatd fix for D46427687" (#103128)
Test Plan: revertitparrot

Reviewed By: malfet

Differential Revision: D46506433

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103128
Approved by: https://github.com/malfet
2023-06-07 01:41:41 +00:00
Nikita Karetnikov
ec0aa965da [pt2] add meta for _linalg_solve_ex (#102454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102454
Approved by: https://github.com/lezcano
2023-06-06 08:06:55 +00:00
Nikita Karetnikov
4bda4a7e4d [pt2] add meta for lu_unpack (#102937)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102937
Approved by: https://github.com/lezcano
2023-06-06 08:06:53 +00:00
Nikita Karetnikov
6ac3352a37 [pt2] add meta for _linalg_slogdet (#102464)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102464
Approved by: https://github.com/ezyang
2023-06-05 03:17:08 +00:00
Kurt Mohler
a84bb2709a Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-03 02:23:21 +00:00
Shunting Zhang
86c7652503 [inductor] layout optimization for conv (#99773)
A convolution kernel with channels-last inputs runs much faster than a kernel with contiguous inputs. This PR leverages that to optimize tensor layouts so that we provide channels-last inputs to convolutions. Some care needs to be taken not to convert tensor layouts back and forth between contiguous and channels last, since those extra copies hurt performance quite a bit.

Latest perf number [here](https://hud.pytorch.org/benchmark/compilers?startTime=Wed%2C%2024%20May%202023%2023%3A40%3A37%20GMT&stopTime=Wed%2C%2031%20May%202023%2023%3A40%3A37%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=shunting-layout-opt-19&lCommit=baa797fc100688dfb044fbcbdebcfd2591710f78&rBranch=main&rCommit=999bae0f54108ffc5b7cf2524a02a83901554b16)
- TB: 1.64x -> 1.69x
- HF: 1.79x -> 1.78x (random noise)
- TIMM: 1.51x -> 1.65x

Right now we disable layout optimization for dynamic shapes since there is a perf loss in that combination. Here is a GH issue to follow up: https://github.com/pytorch/pytorch/issues/102670
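
An illustrative example of the layout being targeted (the speedup itself comes from GPU convolution kernels; the snippet just shows the memory-format conversion):

```
import torch

conv = torch.nn.Conv2d(16, 32, kernel_size=3, padding=1).to(memory_format=torch.channels_last)
x = torch.randn(8, 16, 56, 56).to(memory_format=torch.channels_last)
y = conv(x)
# y is expected to stay channels-last, so downstream convolutions avoid
# contiguous <-> channels-last copies.
print(y.is_contiguous(memory_format=torch.channels_last))
```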

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99773
Approved by: https://github.com/jansel
2023-06-02 21:08:18 +00:00
PyTorch MergeBot
a7efa0ce35 Revert "Remove check from _prims_common, replace with torch._check* (#102219)"
This reverts commit fb79d43649.

Reverted https://github.com/pytorch/pytorch/pull/102219 on behalf of https://github.com/malfet due to Broke lint, see https://github.com/pytorch/pytorch/actions/runs/5158949959/jobs/9293466925 ([comment](https://github.com/pytorch/pytorch/pull/102219#issuecomment-1574245414))
2023-06-02 20:00:48 +00:00
Kurt Mohler
fb79d43649 Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-02 19:13:45 +00:00
Nikita Karetnikov
0f1621df1a [pt2] fix typos in checkFloatingOrComplex errors (#102456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102456
Approved by: https://github.com/lezcano
2023-05-30 11:18:50 +00:00
Nikita Karetnikov
c3ea8cc58b [pt2] convert out params in register_meta (#101344)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101344
Approved by: https://github.com/lezcano
2023-05-27 18:38:52 +00:00
Michael Lazos
69c7f710ba Add meta registrations for some foreach ops (#102225)
as title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102225
Approved by: https://github.com/ngimel
2023-05-25 02:59:11 +00:00
Peter Bell
ce42010722 [inductor][decomp] Add aten._unsafe_index_put for unchecked indexing (#101812)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101812
Approved by: https://github.com/lezcano
2023-05-24 22:17:32 +00:00
Nikita Karetnikov
42b974e8f7 [pt2] add meta for linalg_lu_solve (#101836)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101836
Approved by: https://github.com/lezcano
2023-05-24 00:21:50 +00:00
PyTorch MergeBot
5147fe4969 Revert "[inductor][decomp] Add aten._unsafe_index_put for unchecked indexing (#101812)"
This reverts commit b9721bd705.

Reverted https://github.com/pytorch/pytorch/pull/101812 on behalf of https://github.com/osalpekar due to Causing test_nn_cuda tests to crash during runtime. More details at [D46093942](https://www.internalfb.com/diff/D46093942) ([comment](https://github.com/pytorch/pytorch/pull/101812#issuecomment-1560238085))
2023-05-23 23:06:21 +00:00
Peter Bell
b9721bd705 [inductor][decomp] Add aten._unsafe_index_put for unchecked indexing (#101812)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101812
Approved by: https://github.com/lezcano
2023-05-22 20:39:18 +00:00
drisspg
6f13d6892a Add meta support for multinomial (#101324)
# Summary
Found this when trying to compile the text gen loop of nanogpt here: b33289942b/torchbenchmark/models/nanogpt_generate/model.py (L322)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101324
Approved by: https://github.com/ngimel
2023-05-19 00:04:26 +00:00
Angela Yi
72a73ef67b Add aten.searchsorted.Tensor meta kernel (#101637)
Test Plan: CI

Differential Revision: D45933187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101637
Approved by: https://github.com/ezyang
2023-05-18 06:55:11 +00:00
Peter Bell
66e398951a [inductor/decomp] Add aten._unsafe_index to disable range checks (#101602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101602
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-05-17 23:36:24 +00:00
Nikita Karetnikov
42e65a2587 [pt2] add meta for linalg_lu_factor_ex (#101375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101375
Approved by: https://github.com/lezcano
2023-05-16 20:56:54 +00:00
kshitij12345
afea1a9fe9 [meta] error checking for inplace ops (#101532)
Fixes #100753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101532
Approved by: https://github.com/lezcano
2023-05-16 17:26:59 +00:00
Nikita Karetnikov
9eb1748b2b [pt2] add meta and SymInt support for linalg_lu (#101372)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101372
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-05-15 20:25:00 +00:00
Nikita Karetnikov
ac4cc63ae2 [pt2] add meta for linalg_ldl_solve (#101367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101367
Approved by: https://github.com/lezcano
2023-05-15 20:25:00 +00:00
Nikita Karetnikov
7dd8e08817 [pt2] add meta for linalg_ldl_factor_ex (#101362)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101362
Approved by: https://github.com/lezcano
2023-05-15 02:56:49 +00:00
Nikita Karetnikov
a8964d6377 [pt2] add meta and SymInt support for linalg_householder_product (#101315)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101315
Approved by: https://github.com/lezcano
2023-05-15 02:56:49 +00:00
Natalia Gimelshein
15a51e2012 simplify sdpa backward meta registration (#101128)
Per title.

There's an off chance that `query_reshaped` etc. were actually discontiguous after the reshape, but even in that case I'm pretty sure the computed gradients would still be contiguous, and we are properly transposing output gradients to produce correct strides.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101128
Approved by: https://github.com/drisspg
2023-05-11 03:30:07 +00:00
Nikita Karetnikov
c0d33f66c9 [pt2] remove unused meta_linalg_eigh (#100965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100965
Approved by: https://github.com/ezyang
2023-05-10 15:45:36 +00:00
Nikita Karetnikov
6abde61f8e [pt2] add meta function for _linalg_eigh (#100964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100964
Approved by: https://github.com/ezyang
2023-05-10 15:45:15 +00:00
Natalia Gimelshein
bfe5f5bbe1 [WIP] enable cuda graphs support for flash attention with dropout (#100196)
Fixes #99905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100196
Approved by: https://github.com/drisspg
2023-05-08 16:19:18 +00:00
Nikita Karetnikov
1e591a8b64 [pt2] add meta function for solve_triangular (#100829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100829
Approved by: https://github.com/ezyang
2023-05-08 13:48:15 +00:00
Nikita Karetnikov
266c84e3ab [pt2] add meta function for linalg_qr (#100714)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100714
Approved by: https://github.com/ezyang, https://github.com/lezcano
2023-05-06 15:04:02 +00:00
Nikita Karetnikov
37f1be041a [pt2] enable svd in fake_tensor (#100130)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100130
Approved by: https://github.com/ezyang, https://github.com/lezcano
2023-05-05 06:27:59 +00:00
Michael Voznesensky
fe3ecfe0cf Add AotAutogradFallbackTests to dynamic suite (#100454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100454
Approved by: https://github.com/ezyang
2023-05-04 04:28:45 +00:00
PyTorch MergeBot
c3aa59c8f5 Revert "[WIP] enable cuda graphs support for flash attention with dropout (#100196)"
This reverts commit 32615618e4.

Reverted https://github.com/pytorch/pytorch/pull/100196 on behalf of https://github.com/clee2000 due to broke no ops build 32615618e4 https://github.com/pytorch/pytorch/actions/runs/4866578063/jobs/8678258318 ([comment](https://github.com/pytorch/pytorch/pull/100196#issuecomment-1532352810))
2023-05-03 01:41:56 +00:00
Natalia Gimelshein
32615618e4 [WIP] enable cuda graphs support for flash attention with dropout (#100196)
Fixes #99905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100196
Approved by: https://github.com/drisspg
2023-05-02 23:05:31 +00:00
Justin Chu
e779a30d50 [BE] Fix SIM109 compare-with-tuple (#100337)
Use a tuple membership test instead of multiple equality comparisons; an illustrative example is below.
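
```
mode = "train"

# Before (flagged by SIM109): if mode == "train" or mode == "eval": ...
# After: a single membership test against a tuple.
if mode in ("train", "eval"):
    print("known mode")
```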

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100337
Approved by: https://github.com/Skylion007
2023-04-30 19:51:32 +00:00
Tugsbayasgalan Manlaibaatar
d4bf76c2a4 Persist torch.assert in aten graph (#100101)
This PR introduces a new operator called aten._assert_async.msg, which allows passing a tensor value and an assertion message as inputs. As part of TorchDynamo, we're replacing the use of torch._assert with this new operator so that make_fx also knows how to handle assertions. This is a subset of https://github.com/pytorch/pytorch/pull/98878; refer there for historic reviews.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100101
Approved by: https://github.com/jansel
2023-04-28 07:31:43 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.JIT allow simple generator expressions, which lets us enable rules that replace unnecessary list comprehensions with generators in any/all. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
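
An illustrative before/after for the rule:

```
values = [0, 1, 2, 3]

# Before (flagged by C419): any([v > 2 for v in values])  # builds the whole list first
# After: a generator expression, which short-circuits on the first match.
has_large = any(v > 2 for v in values)
print(has_large)
```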

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Xiaodong Wang
cc01568efd [pt2] Register meta func to randperm.default (#99593)
Summary:
Looks like we're missing the meta func for randperm.default. I get complaints like this when I compile randperm with dynamic shapes, which I think is because it gets into the real implementation rather than the meta func.

```
RuntimeError: expected int but got s0
Exception raised from expect_int at fbcode/caffe2/c10/core/SymInt.h:128 (most recent call first):
# 0  c10::get_backtrace[abi:cxx11](unsigned long, unsigned long, bool)
# 1  std::_Function_handler<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > (), c10::(anonymous namespace)::GetFetchStackTrace()::$_1>::_M_invoke(std::_Any_data const&)
# 2  c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
# 3  c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
# 4  c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__randperm>, at::Tensor, c10::guts::typelist::typelist<c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool> > >, at::Tensor (c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)
# 5  at::Tensor c10::Dispatcher::redispatch<at::Tensor, c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool> >(c10::TypedOperatorHandle<at::Tensor (c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)> const&, c10::DispatchKeySet, c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>) const
# 6  at::_ops::randperm::redispatch(c10::DispatchKeySet, c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)
# 7  c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>), &at::(anonymous namespace)::randperm>, at::Tensor, c10::guts::typelist::typelist<c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool> > >, at::Tensor (c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)
# 8  c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>), &at::(anonymous namespace)::randperm>, at::Tensor, c10::guts::typelist::typelist<c10::SymInt, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool> > >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)

```

Differential Revision: D45137851

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99593
Approved by: https://github.com/ezyang
2023-04-25 08:55:43 +00:00
Wanchao Liang
ca24a96216 minor fix to fused adam meta registration (#99436)
This PR fixes the registration by adding `max_exp_avg_sqs` to the output shape list too, and fixes some type-check issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99436
Approved by: https://github.com/mrshenli
2023-04-24 22:50:02 +00:00
Edward Z. Yang
10c938abef Handle meta['val'] for tuple of lists. (#99724)
Fixes https://github.com/pytorch/pytorch/issues/99356

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99724
Approved by: https://github.com/wanchaol
2023-04-21 22:33:21 +00:00
Rodrigo Kumpera
38e964056b Reland python ops (#99170)
Waiting for the revert to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99170
Approved by: https://github.com/albanD
2023-04-18 15:15:46 +00:00
PyTorch MergeBot
1c042a2137 Revert "Reland python ops (#99170)"
This reverts commit d4de64ae8d.

Reverted https://github.com/pytorch/pytorch/pull/99170 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-04-18 11:37:43 +00:00
Rodrigo Kumpera
d4de64ae8d Reland python ops (#99170)
Waiting for the revert to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99170
Approved by: https://github.com/albanD
2023-04-17 21:53:41 +00:00
Nikita Karetnikov
106ccf4a2a [pt2] add meta function for linalg.cross (#99279)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99279
Approved by: https://github.com/ezyang
2023-04-17 21:21:45 +00:00
PyTorch MergeBot
f957334c2b Revert "[pt2] add meta function for linalg.cross (#99279)"
This reverts commit efc3887ea5.

Reverted https://github.com/pytorch/pytorch/pull/99279 on behalf of https://github.com/ezyang due to Apparently this is breaking inductor on master? So weird
2023-04-17 19:33:16 +00:00
Nikita Karetnikov
efc3887ea5 [pt2] add meta function for linalg.cross (#99279)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99279
Approved by: https://github.com/ezyang
2023-04-17 03:05:20 +00:00
Rodrigo Kumpera
a910045add [PATCH] Back out "Move functional collectives implementation to python. (#98595) (#99168)
Summary:
Original commit changeset: ba36f8751adc

Original Phabricator Diff: D44788697

Test Plan: model loading is fine after reverting the diff

Reviewed By: zyan0, sayitmemory

Differential Revision: D44921259
---

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99168
Approved by: https://github.com/izaitsevfb
2023-04-14 23:48:19 +00:00
XiaobingSuper
9c98f2ceb7 inductor: rewrite mkldnn fx fusion using pattern_matcher(binary) (#97141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97141
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
2023-04-12 06:23:03 +00:00
XiaobingSuper
c214c50355 inductor: rewrite mkldnn fx fusion using pattern_matcher(conv_unary) (#97007)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97007
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
2023-04-12 05:52:54 +00:00
Guang Yang
c377a8590b Add nonzero_static() op to pytorch to unblock export (#97417)
Summary: Add a new experimental Python op (`torch.nonzero_static`) for export. There is NO CUDA impl included in this PR.

Example:

Say the input tensor is `x = torch.tensor([[1, 0], [3, 2]])`:

Calling regular `nonzero()` on x gives the tensor `tensor([[0, 0], [1, 0], [1, 1]])`.
Calling `nonzero_static(x, size=4)` on x gives the tensor `tensor([[0, 0], [1, 0], [1, 1], [fill_value, fill_value]])` (padded).
Calling `nonzero_static(x, size=2)` on x gives the tensor `tensor([[0, 0], [1, 0]])` (truncated).

Test Plan:
**Unit Tests**
```
buck test @mode/dev-nosan //caffe2/test:test_dynamo -- 'caffe2/test:test_dynamo - test_export.py::ExportTests::test_export_with_nonzero_static' -- 'caffe2/test:test_dynamo - test_misc.py::MiscTests::test_nonzero_static'
```

**PT2 Export with `nonzero_static()`**
Example of `GraphModule` in the exported graph
```
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    nonzero_static_default = torch.ops.aten.nonzero_static.default(arg0, size = 4);  arg0 = None
    return pytree.tree_unflatten([nonzero_static_default], self._out_spec)
```

Differential Revision: D44324808

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97417
Approved by: https://github.com/ezyang
2023-04-11 05:13:36 +00:00