Commit Graph

174 Commits

Author SHA1 Message Date
PyTorch MergeBot
999eec8dea Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit b7e7a4cb01.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some test_transformer running on internal A100 and V100 ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196202003))
2024-06-28 06:03:54 +00:00
PyTorch MergeBot
d21993bbb8 Revert "[cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere (#129587)"
This reverts commit 7854d84acb.

Reverted https://github.com/pytorch/pytorch/pull/129587 on behalf of https://github.com/huydhn due to Sorry for revert yet another of your change but I need to revert this to cleanly revert https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196187332 ([comment](https://github.com/pytorch/pytorch/pull/129587#issuecomment-2196198756))
2024-06-28 06:01:07 +00:00
eqy
7854d84acb [cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere (#129587)
Fix for #129579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129587
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-06-28 04:42:45 +00:00
Eddie Yan
b7e7a4cb01 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-26 00:49:18 +00:00
iibrahimli
2db33054b3 Disable fast path in TransformerEncoderLayer when there are forward (pre-)hooks attached to modules (#128415)
Fixes #128413

Disable fast-path if there are forward hooks or pre-hooks.

Example failure case given in the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128415
Approved by: https://github.com/mikaylagawarecki
2024-06-21 17:38:08 +00:00
Valentine233
5da428d9eb [cpu][flash attention] fix attention mask issue (#128816)
For attention mask in flash attention:

- Fix the issue of accessing illegal memory when the last size of mask is 1.
- Add UT of attention mask for various shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128816
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-06-21 01:12:48 +00:00
PyTorch MergeBot
817ce6835b Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit 4c971932e8.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2163690162))
2024-06-12 18:47:52 +00:00
PyTorch MergeBot
7db501ba2b Revert "[cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA (#128350)"
This reverts commit 45dccfddcd.

Reverted https://github.com/pytorch/pytorch/pull/128350 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/128350#issuecomment-2163669538))
2024-06-12 18:35:18 +00:00
eqy
45dccfddcd [cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA (#128350)
CC @vedaanta-nvidia @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128350
Approved by: https://github.com/Skylion007
2024-06-11 19:22:21 +00:00
eqy
4c971932e8 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-09 06:53:34 +00:00
Xinya Zhang
d34075e0bd Add Efficient Attention support on ROCM (#124885)
This patch implements `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):` by reusing AOTriton's accelerated SDPA implementation

Known limitations:
- Only supports MI200/MI300X GPUs
- Does not support varlen
- Does not support `CausalVariant`
- Optional arguments `causal_diagonal` and `seqlen_k` in `_efficient_attention_forward/backward` must be null
- Does not work well with inductor's SDPA rewriter. The rewriter has been updated to only use math and flash attention on ROCM.

This PR also uses a different approach of installing AOTriton binary instead of building it from source in the base docker image. More details on motivation: https://github.com/pytorch/pytorch/pull/124885#issuecomment-2153229129

`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_transformers.py` yields "55028 passed, 20784 skipped" results with this change.  [Previous result](https://hud.pytorch.org/pr/127528) of `test_transformers.py` was 0 error, 0 failure, 55229 skipped out of 75517 tests in total (the XML report does not contain total number of passed tests).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124885
Approved by: https://github.com/malfet
2024-06-08 22:41:05 +00:00
PyTorch MergeBot
a9309502af Revert "Refactoring to remove unused variable (#125252)"
This reverts commit b094622bc9.

Reverted https://github.com/pytorch/pytorch/pull/125252 on behalf of https://github.com/drisspg due to going to land codev ([comment](https://github.com/pytorch/pytorch/pull/125252#issuecomment-2089394606))
2024-05-02 01:49:57 +00:00
Apurva Jain
b094622bc9 Refactoring to remove unused variable (#125252)
Summary: Removed unused variable for running encoder

Test Plan: buck test //caffe2/test:transformers

Differential Revision: D56771972

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125252
Approved by: https://github.com/drisspg
2024-05-01 15:17:45 +00:00
Fuzzkatt
1cf62e86a4 skip various unit tests for Jetson (#122531)
skip multiprocessing, cuda expandable segments, mem eff and flash attention tests on Jetson due to hanging / sigkill issues from nvidia internal testing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy, https://github.com/malfet
2024-04-16 01:26:26 +00:00
Aaron Gokaslan
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
William Wen
cbde0f048b [dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support (#123300)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123300
Approved by: https://github.com/jansel, https://github.com/malfet, https://github.com/zou3519
2024-04-05 20:13:17 +00:00
Xinya Zhang
b83c94339e Fix performance regression and memory storage handling of Flash Attention on ROCM (#122857)
This PR fixes the two major issues that was discovered after the initial merge of PR #121561
1. The Flash Attention support added by has severe performance regressions on regular shapes (power of two head dimensions and sequence lengths) compared with PR #115981. Its performance is worse than the math backend and only has numerical stability advantages. This PR fixes this problem.
2. There is a flaw of memory storage handling in PR #121561 which does not copy the gradients back to the designated output tensor. This PR removes the deprecated `TensorStorageSanitizer` class which is unnecessary due to the more flexible backward kernel shipped by PR #121561

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122857
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
2024-03-29 16:37:24 +00:00
Xinya Zhang
12116aee68 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton support them.
- [x] Only supports power of two sequence lengths.
    * Now it support arbitrary sequence length
- [ ] No support for varlen APIs.
    * varlen API will be supported in future release of AOTriton
- [x] Only support head dimension 16,32,64,128.
    * Now it support arbitrary head dimension <= 256
- [x] Performance is still being optimized.
    * Kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
2024-03-28 00:27:38 +00:00
FEI
e08cbc0d41 update comment of test_invalid_last_dim_stride in test_transformers.py (#122679)
Fixes #122594

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122679
Approved by: https://github.com/mikaylagawarecki
2024-03-26 15:40:24 +00:00
PyTorch MergeBot
764eae9c4e Revert "Add Flash Attention support on ROCM (#121561)"
This reverts commit a37e22de70.

Reverted https://github.com/pytorch/pytorch/pull/121561 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this needs more work to be able to land in fbcode because https://github.com/ROCm/aotriton is not available there atm.  We are working to reland this change before 2.3 release ([comment](https://github.com/pytorch/pytorch/pull/121561#issuecomment-2007717091))
2024-03-19 17:14:28 +00:00
drisspg
42624bceb6 Fixes nan with large bf16 values (#122135)
Fixes #121558

Performance on main:
``` Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.608132004970683 | 65.90210803551601  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.75877740024589  | 64.83824399765581  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 16.465420153690506 |  67.6770955324173  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 17.398148600477725 | 68.19829455344006  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 29.053532000398263 | 99.58901099162175  |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 |  27.826815698063   | 98.05690299253911  |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 49.89655229728669  | 178.24282555375248 |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 48.840098950313404 | 174.5950729819015  |
|     1      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 505.66218036692584 | 1865.9265094902366 |
|     1      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 295.0534054543823  | 967.3831606050952  |
|     1      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.496030446141958 | 55.11070846114308  |
|     1      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.47399884648621  | 55.452342028729625 |
|     1      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 13.216444296995178 | 55.14447903260589  |
|     1      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 12.763233599252999 | 55.142355500720434 |
|     1      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 19.409965351223946 |  74.9107634765096  |
|     1      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 19.02470579952933  | 74.84168506925926  |
|     1      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 46.37695319834165  | 172.19150450546294 |
|     1      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 45.225963747361675 | 185.19691249821335 |
|     1      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 634.3090848531574  | 2249.057865119539  |
|     1      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 320.47313248040155 | 1053.0515247955916 |
|     4      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 13.448987301671878 | 63.63581650657579  |
|     4      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.509283400140703 | 63.059300999157124 |
|     4      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 19.71098779467866  | 105.55780201684684 |
|     4      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 18.264925852417946 | 105.12311349157244 |
|     4      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 45.218703348655254 | 222.87272597895935 |
|     4      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 43.55393464793451  | 230.63290398567915 |
|     4      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 134.02968645095825 | 514.6893998607993  |
|     4      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 157.13709802366793 | 624.5892751030624  |
|     4      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 1776.7079547047617 | 6353.551096981391  |
|     4      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1143.6000745743513 | 3811.8767354171723 |
|     4      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.717129248427227 | 55.35991647047922  |
|     4      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.746983398916198 | 55.76716404175386  |
|     4      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 17.255573300644752 | 106.47456656442955 |
|     4      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 16.46409669774584  | 108.07770595420152 |
|     4      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 46.63354124641045  | 213.74862996162847 |
|     4      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 47.01801469782367  | 240.78139301855117 |
|     4      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 127.76448752265424 | 508.08745552785695 |
|     4      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 168.6308984644711  | 667.2996102133766  |
|     4      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 2268.1598202325404 | 7727.2648515645415 |
|     4      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1242.8469699807465 | 4161.965740495361  |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 14.340955897932872 | 93.72280450770633  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 13.25262250029482  |  93.2030284893699  |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 27.598425600444898 | 183.23776399483904 |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 26.362583553418514 | 183.51862096460536 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 84.52303148806094  | 383.50319798337296 |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 89.41743348259479  | 432.5502900755964  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 217.76640450116247 | 943.9354750793427  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 303.0781910638325  | 1225.4394043702632 |
|     8      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 3470.8542854059488 | 12194.579601055011 |
|     8      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2268.1174043100327 | 7608.0941944383085 |
|     8      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.289720651460811 | 95.88620596332476  |
|     8      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.618648946750909 | 95.56685149436818  |
|     8      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 31.567946751601994 | 180.62468653079122 |
|     8      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 28.611703700153157 | 189.4215695792809  |
|     8      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 84.11306998459621  | 385.25596749968827 |
|     8      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 93.82540901424363  | 455.77428903197875 |
|     8      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 226.80530551588163 | 965.8026450779289  |
|     8      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 327.4116570246406  | 1312.5067745568228 |
|     8      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 4445.5064804060385 | 15020.768146496266 |
|     8      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2433.0302356975153 | 8300.016750581563  |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```

Performance on this branch:
```Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.783618393586949 | 65.59692794689909  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.064015300711617 | 56.99719698168337  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 16.629025398287922 | 68.65267595276237  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 17.462356004398313 | 68.35797848179936  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 |  29.5476081490051  | 101.22994752600789 |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 28.395320149138573 | 98.62275794148445  |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 50.50016101449728  | 181.4357690163888  |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 49.450615647947416 | 175.86063902126625 |
|     1      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 506.06461532879626 | 1866.0613044630736 |
|     1      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 299.9336270149797  | 976.4662646921353  |
|     1      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.45752210286446  | 58.79682704107836  |
|     1      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.407129396684468 | 58.14061599085107  |
|     1      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 13.822759891627355 | 56.56979401828722  |
|     1      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 13.39154909946956  |  56.7130644340068  |
|     1      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 20.282494352431968 | 77.29688903782517  |
|     1      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 19.899454596452415 |  75.4446149803698  |
|     1      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 48.494275606935844 | 177.5322465109639  |
|     1      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 46.84524350450374  | 189.1778860008344  |
|     1      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 635.1026654010639  | 2248.0451600858937 |
|     1      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 335.1591735263355  | 1080.4320796160027 |
|     4      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 13.63953539985232  | 65.50709309522063  |
|     4      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.858113402035087 | 63.021871959790595 |
|     4      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 19.98318645055406  | 105.87883047992364 |
|     4      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 18.619045056402683 | 104.90188701078296 |
|     4      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 45.91175540117546  | 226.00732848513871 |
|     4      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 44.39614630537107  | 232.39317198749632 |
|     4      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 135.5409600073472  | 522.7949097752571  |
|     4      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 158.79383607534692 | 628.5856699105352  |
|     4      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 1775.9978299727663 | 6343.203847063706  |
|     4      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1160.680354805663  | 3842.235009651631  |
|     4      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.553713708417488 | 65.50691701704638  |
|     4      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.486379051348194 |  56.9980075233616  |
|     4      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 17.56585600087419  | 107.89892700267956 |
|     4      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 16.828144202008843 | 109.05519902007653 |
|     4      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 48.23235589428805  | 217.8974545095116  |
|     4      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 49.09284680034033  | 244.73925953498107 |
|     4      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 134.77827049791813 | 522.7259948151186  |
|     4      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 176.60772847011688 | 681.5171707421541  |
|     4      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 2267.821540008299  | 7720.425300067291  |
|     4      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1295.3941145678982 | 4272.425139788538  |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 14.514714101096615 |  94.2192979855463  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 13.553097198018804 |  93.244242540095   |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 27.95821905019693  | 185.0469880155288  |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 26.709681446664035 | 184.22623950755226 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 85.85420495364815  | 388.3417735341937  |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 89.97473795898259  | 434.4228169647977  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 220.6919804448262  | 958.9654899900779  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 306.55586952343583 | 1233.2170095760375 |
|     8      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 3470.7326447824016 | 12183.611298678443 |
|     8      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2299.064100370742  | 7669.618452200666  |
|     8      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.427107692928985 | 96.96270158747211  |
|     8      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.856995843118057 | 96.38117247959599  |
|     8      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 |  32.9956392000895  | 182.52741603646427 |
|     8      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 29.397601098753512 | 191.0755339777097  |
|     8      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 89.06024845782667  | 392.2585004474967  |
|     8      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 97.78487798757851  | 462.07307645818213 |
|     8      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 |  240.521906001959  | 992.4693452194335  |
|     8      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 341.98952303268015 | 1339.2950996058062 |
|     8      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 4445.311005110853  | 15001.030603889374 |
|     8      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2535.9767401823774 | 8528.990152990447  |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```

```
{'avg_forward_time_nan_fix': 399.7900972732653,
 'avg_backward_time_nan_fix': 1409.652114014413,
 'avg_forward_time_main_branch': 394.6807206988645,
 'avg_backward_time_main_branch': 1399.4055472857629,
 'geo_mean_nan_fix': 150.95049601244946,
 'geo_mean_main_branch': 148.3381648508822}
 ```

The y axis is wrong and is micro seconds but the relative comparison still works
<img width="790" alt="Screenshot 2024-03-18 at 3 34 15 PM" src="https://github.com/pytorch/pytorch/assets/32754868/ca278c15-b815-4535-bdcd-07e522055466">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122135
Approved by: https://github.com/cpuhrsch
2024-03-19 16:32:00 +00:00
Xinya Zhang
a37e22de70 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton support them.
- [x] Only supports power of two sequence lengths.
    * Now it support arbitrary sequence length
- [ ] No support for varlen APIs.
    * varlen API will be supported in the next release of AOTriton
- [x] Only support head dimension 16,32,64,128.
    * Now it support arbitrary head dimension <= 256
- [x] Performance is still being optimized.
    * Kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
2024-03-12 01:16:53 +00:00
y-sq
393b4ab432 Fixes issue_119785 (#121048)
Fixes #ISSUE_119785

- Removed all sentinel files of `test_causal_variants_.*`.

- The `test_causal_variants_causal_variant_` tests could pass after removing the dynamo_skips files.

- The `test_causal_variants_compile_causal_variant` fails with `PYTORCH_TEST_WITH_DYNAMO=1`. These tests already call torch.compile, so added @skipIfTorchDynamo to skip them for `PYTORCH_TEST_WITH_DYNAMO`.

**Tests**
```
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard

test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.7745s]                  [  3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.8020s]                  [  6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0385s] (Lower righ...) [  9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.5046s]                  [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.6483s]                   [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.8537s]                   [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.8388s]                   [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.4859s]                   [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu SKIPPED [0.0084s] (Th...) [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu SKIPPED [0.0086s] (Th...) [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0081s] (Th...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu SKIPPED [0.0085s] (Th...) [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu SKIPPED [0.0082s] (Thi...) [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu SKIPPED [0.0085s] (Thi...) [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu SKIPPED [0.0081s] (Thi...) [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu SKIPPED [0.0085s] (Thi...) [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.4185s]                [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.4273s]                [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0280s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [8.0999s]                [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.3785s]                 [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.3818s]                 [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.3864s]                 [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.7668s]                 [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda SKIPPED [0.0089s] (...) [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda SKIPPED [0.0087s] (...) [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0087s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda SKIPPED [0.0084s] (...) [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda SKIPPED [0.0087s] (T...) [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda SKIPPED [0.0087s] (T...) [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda SKIPPED [0.0084s] (T...) [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda SKIPPED [0.0087s] (T...) [100%]

=================================================== 14 passed, 18 skipped, 77218 deselected in 39.72s ===================================================
```
```
$ pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard

test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.2410s]                  [  3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.3984s]                  [  6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lower righ...) [  9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.0095s]                  [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.1749s]                   [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.2138s]                   [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.2715s]                   [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0108s]                   [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.4864s]          [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.5346s]          [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lo...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.1722s]          [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.2341s]           [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.4786s]           [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.4635s]           [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0861s]           [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.7579s]                [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.0044s]                [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0007s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [9.2065s]                [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0081s]                 [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0063s]                 [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0059s]                 [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.0055s]                 [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [0.1200s]        [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.1032s]        [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0010s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [0.1151s]        [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0705s]         [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0713s]         [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0696s]         [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.1516s]         [100%]

=================================================== 28 passed, 4 skipped, 77218 deselected in 39.23s ====================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121048
Approved by: https://github.com/zou3519
2024-03-05 20:19:02 +00:00
drisspg
2e6c08a14b Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.3](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-rellod on top of the modified kernel, changing how dropout saved for backward, removing the head_dim_pad since this would make the kernel inplace mutate and that has a bad interaction with functionalization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-03-04 17:36:22 +00:00
PyTorch MergeBot
1458f1de66 Revert "Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)"
This reverts commit 4b7a521856.

Reverted https://github.com/pytorch/pytorch/pull/118935 on behalf of https://github.com/atalman due to Significantly increases build time. Optimization is needed ([comment](https://github.com/pytorch/pytorch/pull/118935#issuecomment-1971723284))
2024-02-29 18:42:21 +00:00
drisspg
4b7a521856 Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.3](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-rellod on top of the modified kernel, changing how dropout saved for backward, removing the head_dim_pad since this would make the kernel inplace mutate and that has a bad interaction with functionalization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-02-28 19:31:15 +00:00
Eddie Yan
702e82da28 [cuDNN][Flash Attention] Minor cleanup for cuDNN SDPA (#120750)
Cleaning up before hopefully starting work on backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120750
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-02-28 17:32:07 +00:00
Eddie Yan
cd380c794f [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-02-14 22:02:06 +00:00
atalman
244b124bb8 Add linux cpu test for 3.12 (#117853)
This is continuation of work: https://github.com/pytorch/pytorch/pull/113987

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117853
Approved by: https://github.com/albanD
2024-02-14 20:52:23 +00:00
CaoE
dfdbd73360 add Half support for flash attention (#119247)
Re-open for https://github.com/pytorch/pytorch/pull/118368.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119247
Approved by: https://github.com/drisspg, https://github.com/malfet
2024-02-07 05:57:41 +00:00
CK Luk
2ad3599a71 Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST (#118979)
Summary: Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST

Test Plan: See the one in D53154041
Reviewed By: yjhao, yanboliang, Yuzhen11

Differential Revision: D53154041

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118979
Approved by: https://github.com/yanboliang
2024-02-06 16:25:33 +00:00
Catherine Lee
f481835115 Revert "add Half support for flash attention on CPU (#118368)" (#119204)
This reverts commit a5a63db3bf.

Fixes #ISSUE_NUMBER

Reverts #118368

Got reverted internally but branch got deleted to automation didn't work

Mildly edited stack trace
```

...
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 635, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "torch/fx/experimental/proxy_tensor.py", line 995, in trace
    res = super().trace(root, concrete_args)
  File "torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/fx/_symbolic_trace.py", line 793, in trace
    (self.create_arg(fn(*args)),),
  File "torch/fx/experimental/proxy_tensor.py", line 665, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 357, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 68, in inner_fn
    outs = fn(*args)
  File "torch/_functorch/_aot_autograd/utils.py", line 161, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 618, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 593, in run_node
    result = super().run_node(n)
  File "torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "torch/fx/interpreter.py", line 274, in call_function
    return target(*args, **kwargs)
  File "torch/_ops.py", line 571, in __call__
    return self_._op(*args, **kwargs)
  File "torch/_subclasses/functional_tensor.py", line 380, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 744, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 779, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 423, in proxy_call
    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1225, in maybe_handle_decomp
    return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
  File "torch/_decomp/decompositions.py", line 4322, in scaled_dot_product_flash_attention_for_cpu
    torch._check(
  File "torch/__init__.py", line 1133, in _check
    _check_with(RuntimeError, cond, message)
  File "torch/__init__.py", line 1116, in _check_with
    raise error_type(message_evaluated)
RuntimeError: query must be FP32, FP64, BF16 but got torch.float16

While executing %_scaled_dot_product_flash_attention_for_cpu : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default](args = (%l_q_, %l_k_, %l_v_), kwargs = {attn_mask: %l_attn_mask_})
Original traceback:
  File "executorch/backends/xnnpack/partition/graphs/sdpa.py", line 34, in forward
    return torch.nn.functional.scaled_dot_product_attention(
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119204
Approved by: https://github.com/kit1980
2024-02-05 18:24:53 +00:00
CaoE
a5a63db3bf add Half support for flash attention on CPU (#118368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118368
Approved by: https://github.com/jgong5, https://github.com/Valentine233, https://github.com/drisspg
ghstack dependencies: #118367
2024-02-02 01:08:39 +00:00
drisspg
126c1621ce Add Support for CausalBias to torch compile (#116071)
Fixes #115363

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116071
Approved by: https://github.com/mlazos
2024-01-30 02:22:48 +00:00
Wei Wang
80cb6db90d [CUDA] [CI] Disable flash attention for sm87 architecture when the head dim > 192 (#117678)
Head dim > 192 requires A100/H100 (sm80 or sm90) per TORCH_CHECK [here](0c26565d5d/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp (L760)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117678
Approved by: https://github.com/eqy, https://github.com/malfet
2024-01-27 01:22:47 +00:00
drisspg
4e29f01bf2 Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)
# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends/cuda/sdp_kernel` context manager and replaces it with a new context manager `torch.nn.attention.sdpa_kernel`. This context manager also changes the api for this context manager.

For `sdp_kernel` one would specify the backend choice by taking the negation of what kernel they would like to run. The purpose of this backend manager was to only to be a debugging tool, "turn off the math backend" and see if you can run one of the fused implementations.

Problems:
- This pattern makes sense if majority of users don't care to know anything about the backends that can be run. However, if users are seeking to use this context manager then they are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cudnn backend and this API makes it so so that more implementations will need to be turned off if user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat un-intutive that this backend manager is in backends/cuda/init when this now also controls the CPU fused kernel behavior. I think centralizing to attention namespace will be helpful.

Other concerns:
- Typically backends (kernels) for operators are entirely hidden from users and implementation details of the framework. We have exposed this to users already, albeit not by default and with beta warnings. Does making backends choices even more explicit lead to problems when we potentially want to remove existing backends, (perhaps inputs shapes will get covered by newer backends).

A nice side effect is now that we aren't using the `BACKEND_MAP` in test_transformers many, many dynamo failures are passing for CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
2024-01-24 22:28:04 +00:00
PyTorch MergeBot
2f84a9d37c Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)"
This reverts commit 5aa92b5090.

Reverted https://github.com/pytorch/pytorch/pull/115663 on behalf of https://github.com/PaliC due to Unfortunately, this pr breaks cuda builds internally ([comment](https://github.com/pytorch/pytorch/pull/115663#issuecomment-1899388813))
2024-01-18 23:40:30 +00:00
Eddie Yan
5aa92b5090 [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-01-18 01:20:36 +00:00
Sun, Jiayi
d9b265adaf modify the conditions as PythonModuleVariable (#116856)
## Motivation
The current code of `value in [torch.backends.cudnn, torch.ops]` requires `value` to have the implementation of `__eq__`. If the value is a custom object and does not implement `__eq__`, dynamo will throw error. For example, ConvolutionOpContext, the custom 'torch._C.ScriptClass' object registered in IPEX, dynamo will throw the following error:

**torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext**

I think this is a common issue, To avoid this issue, the PR replaces the current code `value in [torch.backends.cudnn, torch.ops]`with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116856
Approved by: https://github.com/jansel
2024-01-15 11:10:57 +00:00
drisspg
19e93b85b9 Fixes last_dim stride check for singleton dimensions (#117001)
Fixes #116333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117001
Approved by: https://github.com/cpuhrsch
2024-01-10 04:46:49 +00:00
Valentine233
20c2ec9a15 [CPU] Add flash attention mask version (#115913)
Add a masked-version flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-07 04:58:23 +00:00
PyTorch MergeBot
2ccc7af028 Revert "[CPU] Add flash attention mask version (#115913)"
This reverts commit 76a3fbb709.

Reverted https://github.com/pytorch/pytorch/pull/115913 on behalf of https://github.com/zou3519 due to broke transformer test on dynamo shard ([comment](https://github.com/pytorch/pytorch/pull/115913#issuecomment-1878043389))
2024-01-05 02:39:12 +00:00
Valentine233
76a3fbb709 [CPU] Add flash attention mask version (#115913)
Add a masked-version flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-05 01:27:36 +00:00
Mikayla Gawarecki
d0cf2182ea Fix TransformerEncoderLayer for bias=False (#116760)
Fixes https://github.com/pytorch/pytorch/issues/116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`

`bias=False` was not something that `torch._transformer_encoder_layer_fwd`  was meant to work with, it was my bad that this wasn't tested as I approved https://github.com/pytorch/pytorch/pull/101687.

`bias=False` was causing the `tensor_args` in [`TransformerEncoder`](a17de2d645/torch/nn/modules/transformer.py (L663-L677)) to contain `None`s and error on checks for the fastpath like `t.requires_grad for t in tensor_args`.

Alternative fix would be to
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if these approaches are preferable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116760
Approved by: https://github.com/jbschlosser
2024-01-05 00:13:10 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compiling errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Mikayla Gawarecki
0f6f582c0d Add config to disable TransformerEncoder/MHA fastpath (#112212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112212
Approved by: https://github.com/jbschlosser
2024-01-02 23:59:30 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could, a few I could not figure out what the proper fix for them was though and so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
drisspg
1e834e0e50 Fix bug in mem_eff kernel with attention mask and MQA (#116234)
# Summary

Found using the repros mentioned in this issue: #112577

After many go rounds with compute-sanitizer and eventual printf debugging I feel pretty confident that this was the underlying issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116234
Approved by: https://github.com/malfet, https://github.com/danthe3rd, https://github.com/atalman
2023-12-21 21:52:21 +00:00
drisspg
65d3dde665 Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fix issue bug in detecting mem eff capability for cuda devices less than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-21 01:56:38 +00:00
PyTorch MergeBot
af8a50e656 Revert "Fix allowed dtypes for mem_eff attention (#116026)"
This reverts commit fc58909bab.

Reverted https://github.com/pytorch/pytorch/pull/116026 on behalf of https://github.com/jeanschmidt due to breaking internal windows buck builds, check internal diff for more details ([comment](https://github.com/pytorch/pytorch/pull/116026#issuecomment-1864354665))
2023-12-20 12:01:34 +00:00