Commit Graph

205 Commits

Author SHA1 Message Date
PyTorch MergeBot
764eae9c4e Revert "Add Flash Attention support on ROCM (#121561)"
This reverts commit a37e22de70.

Reverted https://github.com/pytorch/pytorch/pull/121561 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this needs more work to be able to land in fbcode because https://github.com/ROCm/aotriton is not available there atm.  We are working to reland this change before 2.3 release ([comment](https://github.com/pytorch/pytorch/pull/121561#issuecomment-2007717091))
2024-03-19 17:14:28 +00:00
drisspg
42624bceb6 Fixes nan with large bf16 values (#122135)
Fixes #121558

Performance on main:
``` Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.608132004970683 | 65.90210803551601  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.75877740024589  | 64.83824399765581  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 16.465420153690506 |  67.6770955324173  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 17.398148600477725 | 68.19829455344006  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 29.053532000398263 | 99.58901099162175  |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 |  27.826815698063   | 98.05690299253911  |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 49.89655229728669  | 178.24282555375248 |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 48.840098950313404 | 174.5950729819015  |
|     1      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 505.66218036692584 | 1865.9265094902366 |
|     1      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 295.0534054543823  | 967.3831606050952  |
|     1      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.496030446141958 | 55.11070846114308  |
|     1      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.47399884648621  | 55.452342028729625 |
|     1      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 13.216444296995178 | 55.14447903260589  |
|     1      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 12.763233599252999 | 55.142355500720434 |
|     1      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 19.409965351223946 |  74.9107634765096  |
|     1      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 19.02470579952933  | 74.84168506925926  |
|     1      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 46.37695319834165  | 172.19150450546294 |
|     1      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 45.225963747361675 | 185.19691249821335 |
|     1      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 634.3090848531574  | 2249.057865119539  |
|     1      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 320.47313248040155 | 1053.0515247955916 |
|     4      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 13.448987301671878 | 63.63581650657579  |
|     4      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.509283400140703 | 63.059300999157124 |
|     4      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 19.71098779467866  | 105.55780201684684 |
|     4      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 18.264925852417946 | 105.12311349157244 |
|     4      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 45.218703348655254 | 222.87272597895935 |
|     4      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 43.55393464793451  | 230.63290398567915 |
|     4      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 134.02968645095825 | 514.6893998607993  |
|     4      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 157.13709802366793 | 624.5892751030624  |
|     4      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 1776.7079547047617 | 6353.551096981391  |
|     4      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1143.6000745743513 | 3811.8767354171723 |
|     4      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.717129248427227 | 55.35991647047922  |
|     4      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.746983398916198 | 55.76716404175386  |
|     4      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 17.255573300644752 | 106.47456656442955 |
|     4      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 16.46409669774584  | 108.07770595420152 |
|     4      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 46.63354124641045  | 213.74862996162847 |
|     4      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 47.01801469782367  | 240.78139301855117 |
|     4      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 127.76448752265424 | 508.08745552785695 |
|     4      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 168.6308984644711  | 667.2996102133766  |
|     4      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 2268.1598202325404 | 7727.2648515645415 |
|     4      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1242.8469699807465 | 4161.965740495361  |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 14.340955897932872 | 93.72280450770633  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 13.25262250029482  |  93.2030284893699  |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 27.598425600444898 | 183.23776399483904 |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 26.362583553418514 | 183.51862096460536 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 84.52303148806094  | 383.50319798337296 |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 89.41743348259479  | 432.5502900755964  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 217.76640450116247 | 943.9354750793427  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 303.0781910638325  | 1225.4394043702632 |
|     8      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 3470.8542854059488 | 12194.579601055011 |
|     8      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2268.1174043100327 | 7608.0941944383085 |
|     8      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.289720651460811 | 95.88620596332476  |
|     8      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.618648946750909 | 95.56685149436818  |
|     8      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 31.567946751601994 | 180.62468653079122 |
|     8      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 28.611703700153157 | 189.4215695792809  |
|     8      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 84.11306998459621  | 385.25596749968827 |
|     8      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 93.82540901424363  | 455.77428903197875 |
|     8      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 226.80530551588163 | 965.8026450779289  |
|     8      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 327.4116570246406  | 1312.5067745568228 |
|     8      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 4445.5064804060385 | 15020.768146496266 |
|     8      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2433.0302356975153 | 8300.016750581563  |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```

Performance on this branch:
```Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.783618393586949 | 65.59692794689909  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.064015300711617 | 56.99719698168337  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 16.629025398287922 | 68.65267595276237  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 17.462356004398313 | 68.35797848179936  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 |  29.5476081490051  | 101.22994752600789 |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 28.395320149138573 | 98.62275794148445  |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 50.50016101449728  | 181.4357690163888  |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 49.450615647947416 | 175.86063902126625 |
|     1      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 506.06461532879626 | 1866.0613044630736 |
|     1      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 299.9336270149797  | 976.4662646921353  |
|     1      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.45752210286446  | 58.79682704107836  |
|     1      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.407129396684468 | 58.14061599085107  |
|     1      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 13.822759891627355 | 56.56979401828722  |
|     1      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 13.39154909946956  |  56.7130644340068  |
|     1      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 20.282494352431968 | 77.29688903782517  |
|     1      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 19.899454596452415 |  75.4446149803698  |
|     1      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 48.494275606935844 | 177.5322465109639  |
|     1      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 46.84524350450374  | 189.1778860008344  |
|     1      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 635.1026654010639  | 2248.0451600858937 |
|     1      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 335.1591735263355  | 1080.4320796160027 |
|     4      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 13.63953539985232  | 65.50709309522063  |
|     4      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.858113402035087 | 63.021871959790595 |
|     4      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 19.98318645055406  | 105.87883047992364 |
|     4      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 18.619045056402683 | 104.90188701078296 |
|     4      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 45.91175540117546  | 226.00732848513871 |
|     4      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 44.39614630537107  | 232.39317198749632 |
|     4      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 135.5409600073472  | 522.7949097752571  |
|     4      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 158.79383607534692 | 628.5856699105352  |
|     4      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 1775.9978299727663 | 6343.203847063706  |
|     4      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1160.680354805663  | 3842.235009651631  |
|     4      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.553713708417488 | 65.50691701704638  |
|     4      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.486379051348194 |  56.9980075233616  |
|     4      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 17.56585600087419  | 107.89892700267956 |
|     4      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 16.828144202008843 | 109.05519902007653 |
|     4      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 48.23235589428805  | 217.8974545095116  |
|     4      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 49.09284680034033  | 244.73925953498107 |
|     4      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 134.77827049791813 | 522.7259948151186  |
|     4      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 176.60772847011688 | 681.5171707421541  |
|     4      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 2267.821540008299  | 7720.425300067291  |
|     4      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1295.3941145678982 | 4272.425139788538  |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 14.514714101096615 |  94.2192979855463  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 13.553097198018804 |  93.244242540095   |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 27.95821905019693  | 185.0469880155288  |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 26.709681446664035 | 184.22623950755226 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 85.85420495364815  | 388.3417735341937  |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 89.97473795898259  | 434.4228169647977  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 220.6919804448262  | 958.9654899900779  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 306.55586952343583 | 1233.2170095760375 |
|     8      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 3470.7326447824016 | 12183.611298678443 |
|     8      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2299.064100370742  | 7669.618452200666  |
|     8      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.427107692928985 | 96.96270158747211  |
|     8      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.856995843118057 | 96.38117247959599  |
|     8      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 |  32.9956392000895  | 182.52741603646427 |
|     8      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 29.397601098753512 | 191.0755339777097  |
|     8      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 89.06024845782667  | 392.2585004474967  |
|     8      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 97.78487798757851  | 462.07307645818213 |
|     8      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 |  240.521906001959  | 992.4693452194335  |
|     8      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 341.98952303268015 | 1339.2950996058062 |
|     8      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 4445.311005110853  | 15001.030603889374 |
|     8      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2535.9767401823774 | 8528.990152990447  |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```

```
{'avg_forward_time_nan_fix': 399.7900972732653,
 'avg_backward_time_nan_fix': 1409.652114014413,
 'avg_forward_time_main_branch': 394.6807206988645,
 'avg_backward_time_main_branch': 1399.4055472857629,
 'geo_mean_nan_fix': 150.95049601244946,
 'geo_mean_main_branch': 148.3381648508822}
 ```

The y-axis label is wrong (the values are actually in microseconds), but the relative comparison still holds.
<img width="790" alt="Screenshot 2024-03-18 at 3 34 15 PM" src="https://github.com/pytorch/pytorch/assets/32754868/ca278c15-b815-4535-bdcd-07e522055466">
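
A minimal, hypothetical reproduction sketch of the kind of failure described in the title (shapes and magnitudes are illustrative, not taken from #121558):

```python
import torch
import torch.nn.functional as F

# Hypothetical repro sketch: large-magnitude bf16 inputs previously produced
# NaNs in the SDPA output; with the fix the output should be finite.
torch.manual_seed(0)
q, k, v = (torch.randn(1, 16, 128, 128, device="cuda", dtype=torch.bfloat16) * 256
           for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.isnan(out).any())  # expected: tensor(False, device='cuda:0')
```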

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122135
Approved by: https://github.com/cpuhrsch
2024-03-19 16:32:00 +00:00
Xinya Zhang
a37e22de70 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton supports them.
- [x] Only supports power of two sequence lengths.
    * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
    * The varlen API will be supported in the next release of AOTriton.
- [x] Only support head dimension 16,32,64,128.
    * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
    * The kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix to #112997
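
A minimal usage sketch, assuming a ROCm build that satisfies the limits above; it only shows that the flash backend is reached through the same SDPA entry point as on CUDA:

```python
import torch
import torch.nn.functional as F

# Sketch: check that the flash backend is allowed, then run SDPA as usual.
print(torch.backends.cuda.flash_sdp_enabled())  # True when flash is enabled
q, k, v = (torch.randn(2, 8, 512, 128, device="cuda", dtype=torch.float16)
           for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```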

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
2024-03-12 01:16:53 +00:00
y-sq
393b4ab432 Fixes issue_119785 (#121048)
Fixes #119785

- Removed all sentinel files of `test_causal_variants_.*`.

- The `test_causal_variants_causal_variant_` tests pass after removing the dynamo_skips files.

- The `test_causal_variants_compile_causal_variant` tests fail under `PYTORCH_TEST_WITH_DYNAMO=1`. These tests already call torch.compile, so `@skipIfTorchDynamo` was added to skip them in that mode (see the sketch below).
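
A hedged sketch of the skip described above (the class and test body are illustrative, not the actual test code):

```python
from torch.testing._internal.common_utils import skipIfTorchDynamo

class TestAttnBias:
    # Tests that already exercise torch.compile are skipped when the suite is
    # run under PYTORCH_TEST_WITH_DYNAMO=1.
    @skipIfTorchDynamo("This test already calls torch.compile")
    def test_causal_variants_compile_causal_variant(self):
        ...
```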

**Tests**
```
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard

test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.7745s]                  [  3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.8020s]                  [  6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0385s] (Lower righ...) [  9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.5046s]                  [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.6483s]                   [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.8537s]                   [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.8388s]                   [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.4859s]                   [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu SKIPPED [0.0084s] (Th...) [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu SKIPPED [0.0086s] (Th...) [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0081s] (Th...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu SKIPPED [0.0085s] (Th...) [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu SKIPPED [0.0082s] (Thi...) [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu SKIPPED [0.0085s] (Thi...) [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu SKIPPED [0.0081s] (Thi...) [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu SKIPPED [0.0085s] (Thi...) [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.4185s]                [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.4273s]                [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0280s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [8.0999s]                [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.3785s]                 [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.3818s]                 [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.3864s]                 [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.7668s]                 [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda SKIPPED [0.0089s] (...) [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda SKIPPED [0.0087s] (...) [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0087s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda SKIPPED [0.0084s] (...) [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda SKIPPED [0.0087s] (T...) [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda SKIPPED [0.0087s] (T...) [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda SKIPPED [0.0084s] (T...) [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda SKIPPED [0.0087s] (T...) [100%]

=================================================== 14 passed, 18 skipped, 77218 deselected in 39.72s ===================================================
```
```
$ pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard

test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.2410s]                  [  3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.3984s]                  [  6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lower righ...) [  9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.0095s]                  [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.1749s]                   [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.2138s]                   [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.2715s]                   [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0108s]                   [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.4864s]          [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.5346s]          [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lo...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.1722s]          [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.2341s]           [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.4786s]           [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.4635s]           [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0861s]           [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.7579s]                [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.0044s]                [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0007s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [9.2065s]                [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0081s]                 [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0063s]                 [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0059s]                 [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.0055s]                 [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [0.1200s]        [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.1032s]        [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0010s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [0.1151s]        [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0705s]         [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0713s]         [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0696s]         [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.1516s]         [100%]

=================================================== 28 passed, 4 skipped, 77218 deselected in 39.23s ====================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121048
Approved by: https://github.com/zou3519
2024-03-05 20:19:02 +00:00
drisspg
2e6c08a14b Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-rolled on top of the updated kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since it would make the kernel mutate in place, which interacts badly with functionalization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-03-04 17:36:22 +00:00
PyTorch MergeBot
1458f1de66 Revert "Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)"
This reverts commit 4b7a521856.

Reverted https://github.com/pytorch/pytorch/pull/118935 on behalf of https://github.com/atalman due to Significantly increases build time. Optimization is needed ([comment](https://github.com/pytorch/pytorch/pull/118935#issuecomment-1971723284))
2024-02-29 18:42:21 +00:00
drisspg
4b7a521856 Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-rolled on top of the updated kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since it would make the kernel mutate in place, which interacts badly with functionalization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-02-28 19:31:15 +00:00
Eddie Yan
702e82da28 [cuDNN][Flash Attention] Minor cleanup for cuDNN SDPA (#120750)
Cleaning up before hopefully starting work on backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120750
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-02-28 17:32:07 +00:00
Eddie Yan
cd380c794f [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.
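
A hedged usage sketch of the gating described above (shapes are illustrative):

```python
# Run with the experimental path enabled, e.g.:
#   TORCH_CUDNN_MHA_ENABLED=1 python sdpa_example.py
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)  # may use the cuDNN kernel on SM80+
```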

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-02-14 22:02:06 +00:00
atalman
244b124bb8 Add linux cpu test for 3.12 (#117853)
This is a continuation of the work in: https://github.com/pytorch/pytorch/pull/113987

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117853
Approved by: https://github.com/albanD
2024-02-14 20:52:23 +00:00
CaoE
dfdbd73360 add Half support for flash attention (#119247)
Re-open for https://github.com/pytorch/pytorch/pull/118368.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119247
Approved by: https://github.com/drisspg, https://github.com/malfet
2024-02-07 05:57:41 +00:00
CK Luk
2ad3599a71 Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST (#118979)
Summary: Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST

Test Plan: See the one in D53154041
Reviewed By: yjhao, yanboliang, Yuzhen11

Differential Revision: D53154041

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118979
Approved by: https://github.com/yanboliang
2024-02-06 16:25:33 +00:00
Catherine Lee
f481835115 Revert "add Half support for flash attention on CPU (#118368)" (#119204)
This reverts commit a5a63db3bf.

Fixes #ISSUE_NUMBER

Reverts #118368

Got reverted internally, but the branch was deleted so the automation didn't work.

Mildly edited stack trace
```

...
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 635, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "torch/fx/experimental/proxy_tensor.py", line 995, in trace
    res = super().trace(root, concrete_args)
  File "torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/fx/_symbolic_trace.py", line 793, in trace
    (self.create_arg(fn(*args)),),
  File "torch/fx/experimental/proxy_tensor.py", line 665, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 357, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 68, in inner_fn
    outs = fn(*args)
  File "torch/_functorch/_aot_autograd/utils.py", line 161, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 618, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 593, in run_node
    result = super().run_node(n)
  File "torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "torch/fx/interpreter.py", line 274, in call_function
    return target(*args, **kwargs)
  File "torch/_ops.py", line 571, in __call__
    return self_._op(*args, **kwargs)
  File "torch/_subclasses/functional_tensor.py", line 380, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 744, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 779, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 423, in proxy_call
    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1225, in maybe_handle_decomp
    return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
  File "torch/_decomp/decompositions.py", line 4322, in scaled_dot_product_flash_attention_for_cpu
    torch._check(
  File "torch/__init__.py", line 1133, in _check
    _check_with(RuntimeError, cond, message)
  File "torch/__init__.py", line 1116, in _check_with
    raise error_type(message_evaluated)
RuntimeError: query must be FP32, FP64, BF16 but got torch.float16

While executing %_scaled_dot_product_flash_attention_for_cpu : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default](args = (%l_q_, %l_k_, %l_v_), kwargs = {attn_mask: %l_attn_mask_})
Original traceback:
  File "executorch/backends/xnnpack/partition/graphs/sdpa.py", line 34, in forward
    return torch.nn.functional.scaled_dot_product_attention(
```
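
A sketch mirroring the dtype check that appears at the bottom of the trace above (shapes are illustrative):

```python
import torch

# The decomposition rejected any query dtype outside {float32, float64, bfloat16};
# running this reproduces the message via torch._check, as in the trace.
query = torch.randn(1, 2, 16, 8, dtype=torch.float16)
torch._check(
    query.dtype in (torch.float32, torch.float64, torch.bfloat16),
    lambda: f"query must be FP32, FP64, BF16 but got {query.dtype}",
)
# -> RuntimeError: query must be FP32, FP64, BF16 but got torch.float16
```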

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119204
Approved by: https://github.com/kit1980
2024-02-05 18:24:53 +00:00
CaoE
a5a63db3bf add Half support for flash attention on CPU (#118368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118368
Approved by: https://github.com/jgong5, https://github.com/Valentine233, https://github.com/drisspg
ghstack dependencies: #118367
2024-02-02 01:08:39 +00:00
drisspg
126c1621ce Add Support for CausalBias to torch compile (#116071)
Fixes #115363
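
A minimal sketch of the supported pattern, assuming the `CausalBias` helpers in `torch.nn.attention.bias` (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

# Sketch: pass a CausalBias (here lower-right causal) as attn_mask inside a
# torch.compile'd function.
def attend(q, k, v):
    bias = causal_lower_right(q.size(-2), k.size(-2))
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

compiled = torch.compile(attend)
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out = compiled(q, k, v)
```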

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116071
Approved by: https://github.com/mlazos
2024-01-30 02:22:48 +00:00
Wei Wang
80cb6db90d [CUDA] [CI] Disable flash attention for sm87 architecture when the head dim > 192 (#117678)
Head dim > 192 requires A100/H100 (sm80 or sm90) per TORCH_CHECK [here](0c26565d5d/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp (L760)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117678
Approved by: https://github.com/eqy, https://github.com/malfet
2024-01-27 01:22:47 +00:00
drisspg
4e29f01bf2 Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)
# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`, which also changes the API.

With `sdp_kernel`, one specified the backend choice by turning off every kernel they did not want to run, rather than naming the one they did. The purpose of this backend manager was only to be a debugging tool: "turn off the math backend" and see if you can run one of the fused implementations.

Problems:
- This pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, users reaching for this context manager are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cudnn backend, and this API means more and more implementations need to be turned off whenever a user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat unintuitive that this backend manager lives in backends/cuda/init when it now also controls the CPU fused kernel behavior. Centralizing it in the attention namespace should help.

Other concerns:
- Typically, backends (kernels) for operators are entirely hidden from users as implementation details of the framework. We have already exposed this to users, albeit not by default and with beta warnings. Does making backend choices even more explicit cause problems if we later want to remove existing backends (perhaps when input shapes get covered by newer backends)?

A nice side effect: now that we aren't using the `BACKEND_MAP` in test_transformers, many, many dynamo failures are passing for the CPU tests.
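
A hedged before/after sketch of the API change (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Old (deprecated): negate the backends you do NOT want, e.g.
#   with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False): ...
# New: name the backend(s) you want to allow.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```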

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
2024-01-24 22:28:04 +00:00
PyTorch MergeBot
2f84a9d37c Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)"
This reverts commit 5aa92b5090.

Reverted https://github.com/pytorch/pytorch/pull/115663 on behalf of https://github.com/PaliC due to Unfortunately, this pr breaks cuda builds internally ([comment](https://github.com/pytorch/pytorch/pull/115663#issuecomment-1899388813))
2024-01-18 23:40:30 +00:00
Eddie Yan
5aa92b5090 [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-01-18 01:20:36 +00:00
Sun, Jiayi
d9b265adaf modify the conditions as PythonModuleVariable (#116856)
## Motivation
The current code `value in [torch.backends.cudnn, torch.ops]` requires `value` to implement `__eq__`. If `value` is a custom object that does not implement `__eq__`, dynamo throws an error. For example, with ConvolutionOpContext, the custom 'torch._C.ScriptClass' object registered in IPEX, dynamo throws the following error:

**torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext**

I think this is a common issue. To avoid it, this PR replaces the current code `value in [torch.backends.cudnn, torch.ops]` with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops))`.
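
A small illustration of the failure mode (the stub class below is hypothetical and only mimics an object without a usable `__eq__`):

```python
class OpContextStub:  # stands in for the IPEX ConvolutionOpContext
    def __eq__(self, other):
        raise NotImplementedError("'__eq__' is not implemented")

value = OpContextStub()
print(isinstance(value, OpContextStub))  # True; isinstance never calls __eq__
# value in [object(), object()]          # would raise, mirroring the dynamo error
```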

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116856
Approved by: https://github.com/jansel
2024-01-15 11:10:57 +00:00
drisspg
19e93b85b9 Fixes last_dim stride check for singleton dimensions (#117001)
Fixes #116333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117001
Approved by: https://github.com/cpuhrsch
2024-01-10 04:46:49 +00:00
Valentine233
20c2ec9a15 [CPU] Add flash attention mask version (#115913)
Add a masked version of flash attention for CPU.
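
A minimal sketch of the masked CPU path (shapes and mask are illustrative):

```python
import torch
import torch.nn.functional as F

# Sketch: the CPU flash attention path now accepts an explicit attn_mask
# (boolean or additive float) instead of only is_causal.
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))
mask = torch.tril(torch.ones(128, 128, dtype=torch.bool))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```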

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-07 04:58:23 +00:00
PyTorch MergeBot
2ccc7af028 Revert "[CPU] Add flash attention mask version (#115913)"
This reverts commit 76a3fbb709.

Reverted https://github.com/pytorch/pytorch/pull/115913 on behalf of https://github.com/zou3519 due to broke transformer test on dynamo shard ([comment](https://github.com/pytorch/pytorch/pull/115913#issuecomment-1878043389))
2024-01-05 02:39:12 +00:00
Valentine233
76a3fbb709 [CPU] Add flash attention mask version (#115913)
Add a masked version of flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-05 01:27:36 +00:00
Mikayla Gawarecki
d0cf2182ea Fix TransformerEncoderLayer for bias=False (#116760)
Fixes https://github.com/pytorch/pytorch/issues/116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`

`bias=False` was not something that `torch._transformer_encoder_layer_fwd`  was meant to work with, it was my bad that this wasn't tested as I approved https://github.com/pytorch/pytorch/pull/101687.

`bias=False` was causing the `tensor_args` in [`TransformerEncoder`](a17de2d645/torch/nn/modules/transformer.py (L663-L677)) to contain `None`s and error on checks for the fastpath like `t.requires_grad for t in tensor_args`.

Alternative fixes would be to:
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if either of these approaches is preferable.
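
A minimal sketch of the fixed case (dimensions are illustrative):

```python
import torch
import torch.nn as nn

# Sketch: constructing the layer with bias=False and running inference no
# longer trips the fastpath's checks on None bias tensors.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, bias=False,
                                   batch_first=True).eval()
with torch.no_grad():
    out = layer(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])
```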

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116760
Approved by: https://github.com/jbschlosser
2024-01-05 00:13:10 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Notes about the updates:

This PR:
1. Skips more flash attention related UTs on MI200
2. Fixes additional ATen compilation errors after hipification
3. Fixes the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block-level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull request adds initial Flash Attention support for the AMD/ROCM platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCM. This Triton submodule is not used at runtime and will not be shipped in the final pytorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Mikayla Gawarecki
0f6f582c0d Add config to disable TransformerEncoder/MHA fastpath (#112212)
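
A hedged sketch, assuming the config is exposed as `torch.backends.mha.{get,set}_fastpath_enabled` (the `get_fastpath_enabled` name also appears in the FUNC_INLINELIST commit earlier in this log):

```python
import torch

# Sketch: globally disable the TransformerEncoder/MultiheadAttention fastpath.
torch.backends.mha.set_fastpath_enabled(False)
print(torch.backends.mha.get_fastpath_enabled())  # False
```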
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112212
Approved by: https://github.com/jbschlosser
2024-01-02 23:59:30 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could; for a few I could not figure out the proper fix, so I left them with noqas.
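
For reference, a small illustration of what lint rule F821 ("undefined name") flags; the function is hypothetical:

```python
def scale_rows(matrix):
    return [row * factor for row in matrix]  # F821: undefined name 'factor'
```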

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
drisspg
1e834e0e50 Fix bug in mem_eff kernel with attention mask and MQA (#116234)
# Summary

Found using the repros mentioned in this issue: #112577

After many go-rounds with compute-sanitizer and eventual printf debugging, I feel pretty confident that this was the underlying issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116234
Approved by: https://github.com/malfet, https://github.com/danthe3rd, https://github.com/atalman
2023-12-21 21:52:21 +00:00
drisspg
65d3dde665 Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fixes a bug in detecting memory-efficient attention capability for CUDA devices below sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-21 01:56:38 +00:00
PyTorch MergeBot
af8a50e656 Revert "Fix allowed dtypes for mem_eff attention (#116026)"
This reverts commit fc58909bab.

Reverted https://github.com/pytorch/pytorch/pull/116026 on behalf of https://github.com/jeanschmidt due to breaking internal windows buck builds, check internal diff for more details ([comment](https://github.com/pytorch/pytorch/pull/116026#issuecomment-1864354665))
2023-12-20 12:01:34 +00:00
drisspg
fc58909bab Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fixes a bug in detecting memory-efficient attention capability for CUDA devices below sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-18 23:20:52 +00:00
Jeff Daily
e3aefe2970 Revert "Initial Flash Attention support on ROCM (#114309)" (#115975)
This reverts commit 5bddbed399.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115975
Approved by: https://github.com/atalman, https://github.com/malfet
2023-12-16 03:40:14 +00:00
Xinya Zhang
5bddbed399 Initial Flash Attention support on ROCM (#114309)
This pull request adds initial Flash Attention support for the AMD/ROCM platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCM. This Triton submodule is not used at runtime and will not be shipped in the final pytorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power of two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only support head dimension 16,32,64,128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
2023-12-14 08:52:57 -08:00
Fuzzkatt
661c1cf2aa numerical mismatch fix for test_mem_efficient_attention_attn_mask_vs_math_ref_grads in test_transformers.py (#115707)
Adjusts dropout_fudge_factor since the previous fudge factor was too small and led to numerical mismatches in NVIDIA internal CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115707
Approved by: https://github.com/drisspg
2023-12-14 01:04:39 +00:00
Valentine233
064846dbc2 [cpu] flash attention optimization (#115151)
### Modifications
- **EXP**: Add a fast version with a reduced accuracy (ULP20) to vec exp `exp_u20` and use it in flash attention.
- **FUSION**: Do fusion for `softmax` ops.
- **SCALE**: Move the calculation of `scaling_factor` after `gemm`.

### Performance
_Model: Stable Diffusion V2.1_

| Version | BF16 Kernel latency (s) | BF16 speedup | FP32 Kernel latency (s) | FP32 speedup |
| ----- | ----- | ----- | ----- | ----- |
| PT | 15.865 |  | 35.362 |  |
| PT + EXP | 12.518 | 21.10% | 19.327 | 45.35% |
| PT + EXP + FUSION | 11.774 | 25.79% | 18.306 | 48.23% |
| PT + EXP + FUSION + SCALE | 11.053 | 30.33% | 18.360 | 48.08% |
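
A small numeric sketch of the SCALE change: applying 1/sqrt(head_dim) to the gemm output is mathematically the same as pre-scaling q, so the multiply can be moved after the gemm:

```python
import torch

q, k = torch.randn(4, 64), torch.randn(8, 64)
scale = q.size(-1) ** -0.5
# Scaling before or after the gemm gives the same result (up to rounding).
assert torch.allclose((q * scale) @ k.T, (q @ k.T) * scale, atol=1e-5)
```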

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115151
Approved by: https://github.com/jgong5, https://github.com/drisspg
2023-12-12 01:09:55 +00:00
drisspg
d4c79a3078 Add an attention bias subclass for a lower right causal masking (#114823)
# Summary
This PR introduces a new Tensor subclass designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do causal masking without needing to actually create the "realized" attention bias and pass it into sdpa. We originally added this flag since there is native support in both fused kernels we support. This provides a big performance gain (the kernels only need to iterate over ~0.5x the sequence), and for very large sequence lengths it can provide very large memory improvements.

The flag was introduced early in the kernel development, and at the time it implicitly meant "upper_left" causal attention. This distinction only matters when the attention_bias is not square. For a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation, and unfortunately this would lead to a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.

The larger theme is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention_bias being used. This is the first of hopefully a few more attention_biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.

Results from benchmarking: I improved the meff_attention perf, hence the slightly decreased max speedup.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
|  Type   |      Speedup       | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim |     dtype      | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
|   Max   | 1.831672915579016  |    128     |    32     |   1024    |   2048    |   2048    | torch.bfloat16 |    64    |
|   Min   | 0.9430534166730135 |     1      |    16     |    256    |    416    |   2048    | torch.bfloat16 |   128    |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
2023-12-06 08:29:26 +00:00
drisspg
8556a09d44 Require less alignment for attn bias (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and to cause unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand once, on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
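
A hedged sketch of the pad-then-slice idea described above; the alignment constant is illustrative, not the kernel's actual requirement:

```python
import torch
import torch.nn.functional as F

ALIGN = 8  # illustrative alignment requirement for the last dim
attn_bias = torch.randn(2, 8, 128, 127, device="cuda", dtype=torch.float16)

orig_len = attn_bias.size(-1)
pad = (-orig_len) % ALIGN
if pad:
    # Pad the last dim up to the alignment, then slice a view back to the
    # original length so the underlying storage stays aligned.
    attn_bias = F.pad(attn_bias, (0, pad))[..., :orig_len]
```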

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-28 02:40:41 +00:00
PyTorch MergeBot
88a8a0daa4 Revert "Require less alignment for masking (#114173)"
This reverts commit f882c175d8.

Reverted https://github.com/pytorch/pytorch/pull/114173 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing some inductor tests f882c175d8 ([comment](https://github.com/pytorch/pytorch/pull/114173#issuecomment-1823552362))
2023-11-22 21:49:31 +00:00
drisspg
f882c175d8 Require less alignment for masking (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and to cause unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand once, on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-22 20:02:51 +00:00
drisspg
9b0f2f8d94 expose sdpa helpers to python (#110496)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110496
Approved by: https://github.com/jbschlosser
2023-11-15 07:34:34 +00:00
drisspg
14811d69d7 [BE] Cleanup sdpa test helper usage (#113294)
# Summary

standardizes usage of the rand_sdpa_tensor helper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113294
Approved by: https://github.com/soulitzer
2023-11-09 01:16:53 +00:00
drisspg
e509b162ed Disable FlashAttenion for is_causal=True when seqlen q not equal kv (#111007)
# Summary:
This pull request **removes** support for non-square sequence lengths in causal attention when using FlashAttention V2.

### Why are doing this
  // FlashAttention 2 updated the default mask meaning for causal in this PR:
  // 9e5e8bc91e it is now aligned to lower_right which would be a BC break
  // for non-square masks. We will not support non-square masks for causal w/ FAV2

 For more context see:
 https://github.com/pytorch/pytorch/issues/108108

 ### Followup
 A large number of people will likely want to use FAV2 with lower_right causal attention for non-equal sequence lengths. See this RFC: https://github.com/pytorch/pytorch/issues/110681

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111007
Approved by: https://github.com/cpuhrsch
2023-10-23 20:33:37 +00:00
drisspg
5183760ca5 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-10-10 18:08:17 +00:00
Fuzzkatt
c28bb46445 Fix test_mem_efficient_attention_vs_math_ref_grads tolerance from test_transformers.py (#108094)
Tolerance currently too low, triggering test failures via numerical mismatch in NVIDIA internal testing for certain H100, A16, A40 configs. cc: @ptrblck @eqy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108094
Approved by: https://github.com/eqy, https://github.com/msaroufim
2023-10-02 20:42:57 +00:00
PyTorch MergeBot
8d6479725a Revert "Adding Backward Support for NestedTensors and FlashAttention (#97485)"
This reverts commit 28d69d5256.

Reverted https://github.com/pytorch/pytorch/pull/97485 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but one of the tests test_fused_kernels_nested_broadcasting_requires_grad_failure_cuda is failing on Windows CUDA f7ba3e85e2 ([comment](https://github.com/pytorch/pytorch/pull/97485#issuecomment-1743474468))
2023-10-02 17:48:57 +00:00
drisspg
28d69d5256 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-09-29 21:34:47 +00:00
drisspg
ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmake
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward (Tri beat me to it: a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFLOPs reported here are on an A100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep, and for large compute-bound sizes we do see a ~2x performance increase for forward+backward.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00
Huy Do
a9c663c269 Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 07:43:04 +00:00
PyTorch MergeBot
e45b290127 Revert "Revert "Flash Attention v2 (#105602)" (#108827)"
This reverts commit 24e9bbe22a.

Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
2023-09-08 03:25:45 +00:00
Huy Do
24e9bbe22a Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 02:54:20 +00:00
Michael Gschwind
2a40fe2dbf [experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE (#108672)
Summary:
[experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE -- alternative implementation to D48997976 using preexisting PYTORCH_TESTING_DEVICE_EXCEPT_FOR facility and building remaining logic (for assert-positive listers like test_transformers)  on top of that.

Goal: save ~100 GPU (10% of capacity), enables us to fund more aggressive PyPer unit testing on GPU RE

Test Plan: sandcastle, github

Differential Revision: D48998582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108672
Approved by: https://github.com/bertmaher
2023-09-06 23:33:18 +00:00
drisspg
add45aea1c Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmake
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward (Tri beat me to it: a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFLOPs reported here are on an A100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep, and for large compute-bound sizes we do see a ~2x performance increase for forward+backward.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-01 22:14:44 +00:00
PyTorch MergeBot
d569e506ab Revert "Flash Attention v2 (#105602)"
This reverts commit 9df3d882c8.

Reverted https://github.com/pytorch/pytorch/pull/105602 on behalf of https://github.com/huydhn due to I think we miss a case here for sm80 build on inductor workflow as it is now OOM on trunk https://github.com/pytorch/pytorch/actions/runs/6042843139 ([comment](https://github.com/pytorch/pytorch/pull/105602#issuecomment-1701974862))
2023-09-01 01:15:01 +00:00
drisspg
9df3d882c8 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmake
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward (Tri beat me to it: a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFLOPs reported here are on an A100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep, and for large compute-bound sizes we do see a ~2x performance increase for forward+backward.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-08-31 16:02:20 +00:00
drisspg
42d60d012e Bias overflow fix mem eff bias (#107968)
Fixes #107959
This should have been fixed here https://github.com/pytorch/pytorch/pull/103201
Edit:
Looking at git blame, it appears the dropout revert squashed the changes from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107968
Approved by: https://github.com/cpuhrsch
2023-08-26 00:00:49 +00:00
Mikayla Gawarecki
48b1208e05 Disable nn.MHA fastpath for floating point masks (#107641)
Fixes https://github.com/pytorch/pytorch/issues/107084 by disabling the fast path when floating point masks (which should be additive) are passed

- [We claim in our docs for MHA that float masks will be added to the attention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (be it `key_padding_mask` or `attn_mask`)
- We always canonicalize any mask at the start of MHA in python by converting it to float
- my understanding from Driss is that SDPA properly supports additive masking, but there are many special cases for mask shape in MHA that don't currently work properly (BxT, TxT), so [we're turning this off for now](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L531-L532)
- More broadly, the problem isn't with the SDPA path, but that things are broken for the path it falls back to
- Right now mha "fast path" code with non-None masks is always going through [this path](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L554-L640) that has a call to `masked_softmax` that [converts the masks back to bool](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L154-L156)
- the implication here is that **additive floating point attn_mask and additive key_padding_mask to the nn.MHA fastpath are broken** (see the sketch of the intended additive-mask semantics after this list)
- This wasn't broken for the user in https://github.com/pytorch/pytorch/issues/107084 in 1.13.1 because of [this check which bypassed the fast path if attn_mask was defined](https://github.com/pytorch/pytorch/blob/v1.13.1/torch/nn/modules/activation.py#L1096-L1097) (as Driss pointed out, though, additive key_padding_mask with the fast path was probably broken in 1.13.1)
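
As a small standalone illustration of the additive-mask semantics that SDPA itself supports (a sketch, not the MHA fast-path code):

```python
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Boolean mask: True means "attend". A causal pattern keeps every row non-empty.
bool_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
# Equivalent additive float mask: 0 where attending, -inf where masked out.
float_mask = torch.zeros(L, L).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
torch.testing.assert_close(out_bool, out_float)
```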

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107641
Approved by: https://github.com/drisspg, https://github.com/jbschlosser
2023-08-23 15:08:18 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was probably accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Liao, Xuan
71632d4d24 [cpu] add sdpa choice and UT (#105131)
Feature RFC: https://github.com/pytorch/rfcs/pull/56.

Write an SDPA selection function for CPU that automatically chooses one SDPA implementation among several. There are two CPU implementations to choose from: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. For cases where flash attention is not applicable, such as when flash attention is manually disabled or the inputs are not 4-dimensional, the unfused SDPA is chosen.
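
A rough Python sketch of that selection logic (the function name and flag are hypothetical; the real choice is made in the C++ dispatcher):

```python
import torch

def choose_cpu_sdpa_impl(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         flash_enabled: bool = True) -> str:
    # Prefer flash attention; fall back to the unfused SDPA when flash is
    # disabled or the inputs are not 4-D (batch, num_heads, seq_len, head_dim).
    if flash_enabled and query.dim() == key.dim() == value.dim() == 4:
        return "flash_attention"
    return "unfused_sdpa"
```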

## Performance of the stack

### NanoGPT's SDPA kernel
Using benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), with one socket.
Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64.
Machine: SPR.

| Dtype    | Causal   | Mode      | SDPA            | Time (ms per iter) | Speedup |
| -------- | -------- | -------   | -------         | -------            | ------- |
| float32  | FALSE    | Inference | Unfused         | 3.081              |         |
|          |          |           | Flash attention | 1.665              | **1.85045** |
| float32  | TRUE     | Inference | Unfused         | 3.463              |         |
|          |          |           | Flash attention | 1.662              | **2.083634**|
| bfloat16 | FALSE    | Inference | Unfused         | 1.203              |         |
|          |          |           | Flash attention | 1.154              | **1.042461**|
| bfloat16 | TRUE     | Inference | Unfused         | 1.543              |         |
|          |          |           | Flash attention | 1.154              | **1.337088**|
| float32  | FALSE    | Training  | Unfused         | 54.938             |         |
|          |          |           | Flash attention | 23.029             | **2.385601**|
| float32  | TRUE     | Training  | Unfused         | 58.266             |         |
|          |          |           | Flash attention | 17.835             | **3.266947**|
| bfloat16 | FALSE    | Training  | Unfused         | 18.924             |         |
|          |          |           | Flash attention | 18.886             | **1.002012**|
| bfloat16 | TRUE     | Training  | Unfused         | 21.08              |         |
|          |          |           | Flash attention | 14.172             | **1.48744** |

### Stable Diffusion
Following model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md).
Mode: Inference; Machine: SPR.

| Dtype    | SDPA                    | Throughput (fps) | Speedup SDPA | Total Time (ms) | Speedup |
| -------- | --------                | -------          | -------      | -------         | ------- |
| float32  | Unfused                 | 1.63             |              | 1139            |         |
|          | Flash attention         | 1.983            | 1.216564     | 547.488         | **2.080411**|
| bfloat16 | Flash attention in IPEX | 4.784            |              | 429.051         |         |
|          | Flash attention         | 4.857            | 1.015259     | 408.823         | **1.049479**|

### LLM models of Torchbench

Dtype: float32; Mode: Inference, single socket; Machine: CPX.
Model   name | SDPA | Inductor_new | Inductor_old | Inductor   Ratio(old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.048629309 | 0.05591545 | **1.14983024**
hf_Bert | Unfused -> Flash attention | 0.053156243 | 0.060732115 | **1.142520841**
hf_Bert_large | Unfused -> Flash attention | 0.141089502 | 0.155190077 | **1.099940636**
llama | Unfused -> Flash attention | 0.033250106 | 0.033720745 | **1.01415451**

Dtype: bfloat16; Mode: Inference, single socket; Machine: SPR.
Model   name | SDPA | Inductor_new | Inductor_old | Inductor   Ratio(old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.020681298 | 0.020718282 | **1.001788324**
hf_Bert | Unfused -> Flash attention | 0.019932816 | 0.019935424 | **1.000130842**
hf_Bert_large | Unfused -> Flash attention | 0.047949174 | 0.048312502 | **1.007577355**
llama | Unfused -> Flash attention | 0.018528057 | 0.01861126 | **1.0044907**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105131
Approved by: https://github.com/drisspg
ghstack dependencies: #104583, #104584, #103826, #104693, #104863, #107128
2023-08-20 08:56:21 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Fuzzkatt
3c7331742a test_fused_sdp_choice in test_transformers.py fix (#106587)
sdp dispatcher prioritizes flash attention over efficient attention: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L684-L687, and flash attention is enabled for sm75+: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L625. Thus, the unit test `test_fused_sdp_choice` from `test_transformers.py` which is failing on T4 (sm75) should have this `SM80OrLater` check changed to `SM75OrLater`: https://github.com/pytorch/pytorch/blob/main/test/test_transformers.py#L1914-L1917.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106587
Approved by: https://github.com/drisspg
2023-08-04 03:43:56 +00:00
drisspg
cfa4edcde0 [SDPA] Update dispatch checks to catch last_dim_stride != 1. Also update mask padding logic (#106102)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at bb1fc29</samp>

This pull request simplifies and refactors the code for fused scaled dot product attention kernels in `attention.cu` and `sdp_utils.cpp`, and adds new input validation checks and tests. It also modifies the `sdp_params` struct to store optional mask tensors directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106102
Approved by: https://github.com/cpuhrsch
2023-08-01 19:13:01 +00:00
XiaobingSuper
55f9359d36 fix sdpa math accuracy issue when scale is negative (#105202)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105202
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/drisspg
2023-08-01 00:19:14 +00:00
Fuzzkatt
1cebfef8a4 sm90 efficient attention test fixes (#105978)
Fixes the following two test cases involving efficient attention on sm90:

Explanations:

functorch/test_ops.py: test_vjp_nn_functional_scaled_dot_product_attention_cuda_float32
* originally the test had xfail for all sm
* in https://github.com/pytorch/pytorch/issues/102029, we found that it was unexpectedly passing on sm90
* I made https://github.com/pytorch/pytorch/pull/102131 to update the test to let it pass
* @drisspg seems to have made changes to the behavior such that the original xfail was getting triggered (https://github.com/pytorch/pytorch/issues/102029#issuecomment-1560071148)
* the CI began complaining about the failure again: https://github.com/pytorch/pytorch/issues/102663
* I'm now reverting https://github.com/pytorch/pytorch/pull/102131 to bring back the original xfail now that the behavior has been fixed by @drisspg to trigger the xfail in sm90 similar to all other sm

test_transformers.py: test_mem_efficient_fail_sm90_cuda
* the test as it's currently written seems to expect the sdp dispatcher to fail for mem efficient attention on sm90; however, testing this on H100, it actually succeeds, so I'm disabling the test for now as the current expected result may be outdated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105978
Approved by: https://github.com/eqy, https://github.com/kshitij12345, https://github.com/zou3519
2023-07-31 17:59:40 +00:00
drisspg
cb9a4fbbf2 [BE] Improve test_transformers test structure (#105938)
# Summary

The vast majority of tests only run on CUDA. Decorating with @onlyCUDA causes pytest to instantiate 2x the tests and skip half of them. This overhead is non-trivial when the number of tests grows large, as it has for this file.

This breaks up the CUDA-only tests into a separate class
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105938
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
2023-07-26 22:16:20 +00:00
drisspg
c4b7311fc2 Meff Attn Bias (#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not a multiple of 16. This will cause a memory spike of ~2x the attn_mask size, which could in theory be big. It appears, though, that doing this + mem_eff is faster than no_pad + math, so it seems to be worth it.
- Using expand to view the attn_mask in 4D. This is a little different from how we enforce q, k, v to be viewed in 4D prior to calling. Also not supporting the (b*n_heads, seq_len_q, seq_len_kv) case.
- Should enable #96099

### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention. I added an attn_bias of shape (1, 1, seqlen_q, seqlen_k). For these experiments seqlen_q == seqlen_k. These were all run on an A100 80GB GPU.
Configs:
```
    # Run a bunch of experiments
    batch_sizes = [8, 16, 32]
    num_heads = [16, 32]
    max_seq_lens = [15, 64, 128, 512, 555, 1024]
    embed_dims = [32, 64, 128]
    dtypes = [torch.float16, torch.bfloat16, torch.float32]
    pad_percentages = [None]
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    run_backward = True
    attn_mask = True
```

   The function calls `sdpa(input**).sum().backward()`.

   I calculated the geomean speedup of the efficient attention path over the math path for all these configs:
   `Geomean Speedup: 1.977`

An example comparison with batch_size = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:
![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

 This was done using the current state of the branch, where we force alignment of the mask when the last dim is not divisible by 16, which shows up in the seq_len = 15 and 555 cases.

The full data can be found here:

[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
2023-07-26 15:51:59 +00:00
PyTorch MergeBot
340ec1f460 Revert "Meff Attn Bias (#104310)"
This reverts commit 5453508115.

Reverted https://github.com/pytorch/pytorch/pull/104310 on behalf of https://github.com/DanilBaibak due to PR introduced cuda OOM issue ([comment](https://github.com/pytorch/pytorch/pull/104310#issuecomment-1650171538))
2023-07-25 16:37:32 +00:00
drisspg
5453508115 Meff Attn Bias (#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not a multiple of 16. This will cause a memory spike of ~2x the attn_mask size, which could in theory be big. It appears, though, that doing this + mem_eff is faster than no_pad + math, so it seems to be worth it.
- Using expand to view the attn_mask in 4D. This is a little different from how we enforce q, k, v to be viewed in 4D prior to calling. Also not supporting the (b*n_heads, seq_len_q, seq_len_kv) case.
- Should enable #96099

### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention. I added an attn_bias of shape (1, 1, seqlen_q, seqlen_k). For these experiments seqlen_q == seqlen_k. These were all run on an A100 80GB GPU.
Configs:
```
    # Run a bunch of experiments
    batch_sizes = [8, 16, 32]
    num_heads = [16, 32]
    max_seq_lens = [15, 64, 128, 512, 555, 1024]
    embed_dims = [32, 64, 128]
    dtypes = [torch.float16, torch.bfloat16, torch.float32]
    pad_percentages = [None]
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    run_backward = True
    attn_mask = True
```

   The function calls `sdpa(input**).sum().backward()`.

   I calculated the geomean speedup of the efficient attention path over the math path for all these configs:
   `Geomean Speedup: 1.977`

An example comparison with batch_size = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:
![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

 This was done using the current state of the branch, where we force alignment of the mask when the last dim is not divisible by 16, which shows up in the seq_len = 15 and 555 cases.

The full data can be found here:

[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
2023-07-24 22:19:26 +00:00
Michael Gschwind
11b753af01 Refactor causal mask generation and detection for nn.transformer (#105265)
Summary:
* Create a private global-scope function _generate_subsequent because static class attribute member functions are not supported by TorchScript, resulting in torchscripting errors.
* Make TransformerEncoder and TransformerDecoder consistent w.r.t. is_causal handling by calling _detect_causal_mask
* Clarify documentation that is_causal is a hint
* Move causal mask detection into a method _detect_causal_mask
* only accept input-size compatible causal mask as causal mask
* update _generate_subsequent_causal_mask to include factory kwargs for dtype and device:
   avoid extra copies & conversions by passing directly to torch.full (see the sketch after this list).
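
A minimal sketch of generating the square subsequent (causal) mask with the factory kwargs forwarded straight to torch.full, in the spirit of the change above (treat this as an illustration, not the exact module code):

```python
import torch

def generate_square_subsequent_mask(sz: int, device=None, dtype=torch.float32) -> torch.Tensor:
    # -inf above the diagonal, 0 on and below it; dtype/device are passed
    # directly to torch.full, so no extra copy or conversion is needed later.
    return torch.triu(
        torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
        diagonal=1,
    )
```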

Test Plan: sandcastle & github CICD
Continuation of #101487 (due to a tooling issue) which is a continuation-in-part of https://github.com/pytorch/pytorch/pull/98327 by @janEbert

Differential Revision: D47427117

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105265
Approved by: https://github.com/mikaylagawarecki
2023-07-19 01:26:50 +00:00
Michael Gschwind
07a1c3f7ff Exercise subclass of TransformerEncoderLayer (#105297)
Summary: Exercise subclass of TransformerEncoderLayer
Additional unit tests for change in #102045 to show correct e2e operation (cf. issue #100188)

Also: remove batch_first from list of TS module constants where it is not used to resolve torchscripting warning

Test Plan: sandcastle, github

Differential Revision: D47503004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105297
Approved by: https://github.com/davidberard98
2023-07-17 16:03:10 +00:00
Driss Guessous
4a008d268a REDO of dropout support for mem eff #102038 (#103704)
THIS IS A new PR with the changes from #102038 + #103201, plus namespacing changes to fix a bug.

# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made:
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103704
Approved by: https://github.com/cpuhrsch
2023-06-26 23:05:03 +00:00
Yinghai Lu
4c3799447f
Back out "Dropout support for memory efficient attention (#102038)" & "Two small mem_eff bug fixes (#103201)" (#103464)
Summary:
Original commit changeset: 04c4473d8510

Original Phabricator Diff: D46584152 & D46582033

Test Plan: Already explained in summary.

Reviewed By: yinghai

Differential Revision: D46633283

fbshipit-source-id: c23c2945408988f3c4339dfd5cd40ae46261716c

Co-authored-by: Shenxiu Liu <shenxiu@meta.com>
2023-06-12 18:56:48 -07:00
Driss Guessous
606fb882c4 Dropout support for memory efficient attention (#102038)
# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made:
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102038
Approved by: https://github.com/cpuhrsch
2023-06-08 21:50:12 +00:00
Driss Guessous
2800a04a17 Add device range helper and remove sm86 specific check for memory efficient attention (#102985)
# Summary
Since we have upstreamed the latest changes of memory-efficient attention, we can remove the sm86/sm89-specific check. All head_sizes (assuming correct alignment) should work for sm86 and sm89 and don't have a max capability.

If head_size > 96 there will be a big drop in performance, but it should not error and will still maintain memory savings by not materializing the attention weights.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102985
Approved by: https://github.com/cpuhrsch
2023-06-07 00:28:40 +00:00
Michael Gschwind
4d89489df5 Move static checks of layers[0] (e.g., isinstance check) to model build time (#102045)
Summary: Move static checks of layers[0] (e.g., isinstance check) to model build time because isinstance() does not work for torchscripted code.  Because the validation is now performed while constructing the object, the isinstance() call is performed in eager mode at model build time, and we avoid needing to call  isinstance() at runtime to determine whether the layers in a model are an instance of the TransformerEncoderLayer class, or its derived classes.

Test Plan: sandcastle, github

Differential Revision: D46096222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102045
Approved by: https://github.com/mikaylagawarecki
2023-05-30 19:42:01 +00:00
Driss Guessous
ef13fde290 Increase mem eff backward performance (#101847)
# Summary
This is another upstream which is much smaller than the previous one.
This bumps the kernel versions from xformers:
Current: [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](6425fd0cac)
With this PR: [1d635e193e169fc677b2e7fa42dad7ebe88eec9e](1d635e193e)

### Notable Changes:
- Drastically improve the BW pass in multiple cases (especially when B*numHeads < 100)
- H100 Support: *Warning* While these kernels have been added, we don't have the CI/CD machines to test.
- Enables a deterministic mode.

## Specific Changes
- Updates to the backward kernel.
- Added num_splits_key, which we hard-code to -1. (This is another performance knob that we set to use the heuristic.)
- Update gen_code and kernels to produce h100 instantiations.

### Due Diligence Checks:
* CUDA_lib size: No changes in size

#### Performance
* Micro Benchmark: (batch_size: 1, num_heads=25, seq_len=4096, embed_dim = 64 | grid:[1,25,1]block: [128,1,1])
    * MemEfficientAttention Backward Kernel: 27.972 ms
    * After the updated Xformers code(https://github.com/pytorch/pytorch/pull/100583): 23.958 ms
    * With this PR: 4.085 ms
* Ran micro benchmarks on sdpa_forw().sum().backward() over a range of dtypes and input shapes
   * Geo_mean increase -> 1.17x
   * Max increase -> 2.95x
   * min_increase -> 0.8x

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101847
Approved by: https://github.com/cpuhrsch
2023-05-26 02:25:31 +00:00
Aaron Gokaslan
3e2ea32dab [BE]: Enable ruff rule TRY302 and apply fixes (#101874)
Removes useless try statements and unreachable code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
2023-05-19 17:30:52 +00:00
Driss Guessous
c9ba967c21 Upstream xformers code (#100583)
# Summary
Since the initial upstream of memory efficient attention from xformers (#86157), significant updates have been made to the kernel, including increased performance, bug fixes, and added functionality. This PR upstreams the latest version of this kernel as of version 0.0.20, commit [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](6425fd0cac)

## Future
Although this version of the kernel has support for dropout and arbitrary attention bias, I did not add this support to SDPA yet, and left the guards in sdp_utils. Those will be follow-up PRs, in order to reduce the scope creep of these substantial changes and ensure that nothing is broken.

## Specific Changes
### Minor Changes
* The build system work was done in the previous PR and so no changes were needed to CMAKE 🤞
* Adding the new files and re-arranging/creating folder structure
* Updating include paths
* Switching from xformer specific functions: `XFORMERS_CHECK -> TORCH_CHECK`
* Changes to xformer specific macros
* Updates to `generate_kernels.py` to account for the PyTorch file structure; also added an argparse interface that I could run on a test dir before creating the files in place.

### Bigger Changes
* Previous Kernel changes "Removed the chunk optimization: see discussion here: https://github.com/pytorch/pytorch/pull/96880"
* Increased the number of cuda kernels -> potentially affecting the cuda_lib size.
* Preemptively made changes to the dtypes of seed and offset in order to allow for cuda_graphs (#100196); this is not finished.
* Made VERY BC-breaking changes to the at::_efficient_attention_forward and at::_efficient_attention_backward function signatures.
    * I made these changes in part to allow this PR to land: https://github.com/pytorch/pytorch/pull/100196

### Due Diligence Checks:
* CUDA_lib size:
    * Before: 496 MiB
    * After:    496 MiB
* Performance Sweep:
    * I swept over 576 configs for forward-only inference and the geomean speedup was 0.98x, with a min speedup of 0.84x and a max speedup of 1.2x.
    * For forward+backward, running on 270 configs (to reduce memory), the geomean speedup was 1.02x, with a min speedup of 1.02x and a max speedup of 1.35x.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100583
Approved by: https://github.com/cpuhrsch
2023-05-18 16:15:34 +00:00
Driss Guessous
52363de2ec Clean up grad check in sdp_utils.h (#101435)
# Summary
The priority order was not being applied correctly because of a confusing function name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101435
Approved by: https://github.com/jbschlosser
2023-05-16 02:22:45 +00:00
Elias Ellison
7e333fe502 Fix cuda graphs & sdpa for dropout==0 (#101280)
Fixes cuda graph failures from https://github.com/pytorch/pytorch/pull/100931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101280
Approved by: https://github.com/ngimel
2023-05-12 19:06:45 +00:00
Natalia Gimelshein
bfe5f5bbe1 [WIP] enable cuda graphs support for flash attention with dropout (#100196)
Fixes #99905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100196
Approved by: https://github.com/drisspg
2023-05-08 16:19:18 +00:00
Michael Gschwind
e5b065525b Add unit test for nested_tensor input to nn.TransformerEncoder (#100650)
Summary: Add unit test for nested_tensor input & fix

Test Plan: sandcastle

Differential Revision: D45580393

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100650
Approved by: https://github.com/jbschlosser
2023-05-05 23:34:14 +00:00
Driss Guessous
2892c06e82 Ensure device arg is passed to test_transformers (#100260)
# Summary
Follow up to #100121 to actually make sure that test functions are accepting a device arg as input.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100260
Approved by: https://github.com/malfet, https://github.com/ngimel
2023-05-04 01:36:06 +00:00
PyTorch MergeBot
c3aa59c8f5 Revert "[WIP] enable cuda graphs support for flash attention with dropout (#100196)"
This reverts commit 32615618e4.

Reverted https://github.com/pytorch/pytorch/pull/100196 on behalf of https://github.com/clee2000 due to broke no ops build 32615618e4 https://github.com/pytorch/pytorch/actions/runs/4866578063/jobs/8678258318 ([comment](https://github.com/pytorch/pytorch/pull/100196#issuecomment-1532352810))
2023-05-03 01:41:56 +00:00
Michael Gschwind
8430430e94 Handle trailing masked column behavior for nested tensor (#100113)
Summary:
Handle trailing masked column behavior for nested tensor by padding during to_padded to the original tensor size

https://github.com/pytorch/pytorch/issues/97111

Test Plan: sandcastle & github

Differential Revision: D45167874

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100113
Approved by: https://github.com/bertmaher, https://github.com/cpuhrsch, https://github.com/drisspg
2023-05-03 00:30:17 +00:00
Natalia Gimelshein
32615618e4 [WIP] enable cuda graphs support for flash attention with dropout (#100196)
Fixes #99905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100196
Approved by: https://github.com/drisspg
2023-05-02 23:05:31 +00:00
Fuzzkatt
06bf5d4de7 enable headdims > 64 for flash attention on sm90 (#99776)
Follow-up to #99105, which disabled FlashAttention when using autograd and mem eff attention for the following cases:

- head_dim > 64
- sm86 or newer

We have tested enabling FlashAttention on sm90 and it works, so this PR re-enables it for sm90 and adds a test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99776
Approved by: https://github.com/malfet, https://github.com/drisspg
2023-05-02 19:11:48 +00:00
Driss Guessous
b8d7a28e1a refactor test_sdpa into two test classes to account for failure modes (#100121)
### Summary
This PR creates a new TestSDPAFailureModes test class in order to better separate what each test is trying to do.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100121
Approved by: https://github.com/malfet, https://github.com/ngimel
2023-04-27 21:42:40 +00:00
Michael Gschwind
36e1ae6778 De-select odd numbered heads from nn.MHA fastpath (#99672)
Summary:
https://github.com/pytorch/pytorch/issues/97128

* Add test for mha num_heads %2 != 0
* Fix test
* Add test for bias false
* show test passes

Test Plan: sandcastle

Differential Revision: D45161767

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99672
Approved by: https://github.com/ngimel
2023-04-25 00:27:18 +00:00
Horace He
547bef11ee tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy.  My new heuristic achieves 94% accuracy.
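
Expressed as a tiny helper (a sketch of the learned heuristic above, not the actual dispatcher code):

```python
def prefer_flash(batch_size: int, num_heads: int, seq_len: int) -> bool:
    # FlashAttention tends to win when the sequence length is large relative
    # to the total number of (batch, head) work items.
    return seq_len / (num_heads * batch_size) > 6
```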

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-04-21 23:28:44 +00:00
Christian Puhrsch
fdeee43650 Disable SDPA FlashAttention backward and mem eff attention on sm86+ for head_dim above 64 (#99105)
Expand sdpa_utils.h check to disable FlashAttention when using autograd and mem eff attention for the following cases
- head_dim > 64
- sm86 or newer

Previously we only disable these kernels on sm86 and for head_dim equal to 128.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99105
Approved by: https://github.com/malfet
2023-04-21 01:00:15 +00:00
PyTorch MergeBot
5cb788a9a5 Revert "[cuda rng] Making offset calculation independent of device properties (#98988)"
This reverts commit 26f318574f.

Reverted https://github.com/pytorch/pytorch/pull/98988 on behalf of https://github.com/anijain2305 due to Diagnosing if sebotnet has flakiness
2023-04-19 17:23:40 +00:00
Animesh Jain
26f318574f [cuda rng] Making offset calculation independent of device properties (#98988)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98988
Approved by: https://github.com/ngimel
2023-04-19 01:35:44 +00:00
Lucas Pasqualin
35c6547f02 Adds 3D attn_mask support to merge_masks() for Multihead Attention fast path (#98991)
Fixes #97409

Adds support for 3D attn_mask by always expanding attn_mask to 4D as per https://github.com/pytorch/pytorch/pull/98375#issuecomment-1499504721
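
Roughly, the expansion looks like this (a sketch with assumed shapes, not the exact merge_masks() code):

```python
import torch

bsz, num_heads, tgt_len, src_len = 2, 8, 5, 7
# 3D attn_mask as accepted by nn.MultiheadAttention: (bsz * num_heads, L, S)
attn_mask_3d = torch.zeros(bsz * num_heads, tgt_len, src_len)
# Reshape to the 4D layout the fast path expects: (bsz, num_heads, L, S)
attn_mask_4d = attn_mask_3d.view(bsz, num_heads, tgt_len, src_len)
```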

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98991
Approved by: https://github.com/jbschlosser
2023-04-13 20:29:57 +00:00
Michael Gschwind
c757647dd8 [Better Transformer] make is_causal a hint and force attn_mask to be set on is_causal=True in F.MHA (#97214)
Summary:
This fixes an issue raised in [is_causal parameter in torch.nn.TransformerEncoderLayer.forward does not work #96941](https://github.com/pytorch/pytorch/issues/96941) where results computed with is_causal do not properly reflect causal masking.

In PyTorch 2.0, Accelerated PT Transformers added the is_causal parameter to legacy nn.Transformer* and nn.MHA APIs aligned with and intended to engage the is_causal parameter of the new scaled_dot_product_attention (SDPA) operator.

At present is_causal works differently for Transformer* modules, the nn.MHA and F.MHA:
* The nn.Transformer* modules treat is_causal as an optional indicator about the format of attn_mask. (This is because some layers, such as the CLIP layer, use the attention mask in the layer, and thus attn_mask was a required feature.)
* Initially, nn.MHA and F.MHA were defined to align with F.SDPA in behavior: a user may specify either the attention mask, or is_causal, but not both. It seemed to make sense at the time to align SDPA and MHA, especially since there was a larger overlap of parameters, which have since changed, e.g., with the removal of need_weights from SDPA. (See below for why this makes sense.)

Unfortunately, this does not work because of how MHA was changed to handle the need_weights parameter.  When need_weights is present, we do not (any more) call SDPA because support for need_weights was removed from SDPA before the release.  The rationale is that need_weights defeats all optimization at the foundation of SDPA performance.  Having the flag might thus mislead users into thinking they get good performance and have them disappointed when they enable a legacy feature of MHA which massively degrades performance.  (They might not think anything of enabling that, because it is on by default in MHA today, which leads to more  issues.)

Since SDPA does not (no longer) support need_weights, we need to pick a separate path which implements attention using a set of discrete operations that allocates a tensor for weights.  Alas, this code path does not have support for is_causal, because attention is implemented as matmul and using the attention mask.  Thus, is_causal has no impact.  (A substantially similar situation arises with how kpm is implemented today because Nested Tensors are not supported by torch.compile() in 2.0)

This problem was masked because all uses of legacy nn.MHA (and F.MHA) come through nn.Transformer*, which called self-attention (i.e., nn.MHA) only ever with the attention mask attn_mask, and never with is_causal, a missed optimization opportunity that would have been addressed in a future performance update.

Regrettably, always calling nn.MHA with attn_mask prevented diagnosing of the issue of not having a suitable attention mask when need_weights support was dropped from SDPA and a discrete implementation of attention was added for that scenario, and for the execution path with key_padding_mask.

We have two options to address this issue:

Solution 1: Whenever nn.MHA and F.MHA are executed with is_causal set, we internally create a causal mask at significant expense of allocating a tensor and filling it with a triangular causal matrix.  This increases memory usage, and runtime, for allocating a causal mask.  To add insult to injury, in all current (and likely future) execution scenarios, MHA is called by a model using the nn.Transformer API which already has that matrix and passes it from nn.module to nn.module.  Then the passing in of attn_mask has to be suppressed by nn.TransformerEncoderLayer, only for nn.MHA to immediately allocate the very same tensor again to satisfy the requirement to have an attention mask for the computation. (We expect new use cases to use SDPA directly.)

Solution 2: We align the behavior of nn.MHA and F.MHA with the rest of the existing nn.Transformer API, and require the attention mask to be passed into nn.MHA in addition to is_causal as an optional indicator about the nature of the attention mask rather than as an alternative to attn_mask.  Then, when we choose the code path for processing MHA with need_weights or a key_padding_mask, we have the attn_mask passed down through the nn.Transformer* hierarchy, without the added overhead of allocating an attention mask as in scenario 1.

This PR implements solution 2 which offers better performance and in retrospect aligns MHA better with the rest of the Transformer modules as the definition of SDPA evolved into a more streamlined high-performance operator.  It ostensibly changes how is_causal works, by requiring the attention mask to be specified.  However, as described here, and as shown in the submitted issue, is_causal is not working as intended today, so it requires a change regardless.

In that sense, a change in API does not occur per se, as the current implementation is not working, and a change has to occur either way to resolve the submitted issue, breaking any use cases that depend on the current implementation. Checks exist (and more can be added) that flag any scenario where is_causal is passed as True but no attention mask is provided, ensuring that there is no quiet change from even the faulty behavior present in 2.0.

As an upside, the present implementation will improve performance by addressing the passing of the is_causal flag from Transformer modules to MHA, speeding up training for these examples, e.g., fine-tuning BERT, RoBERTa, and XLM-R models.

Differential Revision: D44245725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97214
Approved by: https://github.com/albanD
2023-03-25 01:36:30 +00:00
Peter Bell
4e054175d6 Fix uniform returning end point for BFloat16 and Half (#96962)
Fixes #96947

If we generate `1.0 - float_eps`, the BFloat16 and Half constructors will round this to 1.0 which is outside of the half-open range. Instead, we delay the bounds change until after the value has been rounded.
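
A quick way to see the rounding issue (shown here with the float32 eps; the same effect appears for Half):

```python
import torch

eps = torch.finfo(torch.float32).eps           # 2**-23
x = 1.0 - eps                                  # largest float32 strictly below 1.0
print(torch.tensor(x, dtype=torch.bfloat16))   # rounds up to 1.0, escaping [0, 1)
print(torch.tensor(x, dtype=torch.float32))    # stays strictly below 1.0
```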

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96962
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-03-21 14:01:29 +00:00
Michael Gschwind
92eb9d363a Decoder native functions join the dead code society (#96025)
Summary: Decoder native joins the dead code society

With the recent introduction of PT2, we no longer need native decoder operators:
1 - full-function SDPA kernels can be used to implement cross-attention efficiently without the (slower) decoder MHA blob.
2 - torch.compile() generates more efficient code across many platforms from the python implementation of decoders than the decoder layer blob, by tailoring code to the target

Test Plan: github & sandcastle

Differential Revision: D43811808

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96025
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-03-17 09:45:55 +00:00
Michael Gschwind
61cb544397 Align mask formatting of both masks more closely (#96286)
Summary: Align mask formatting of both masks more closely

Test Plan: sandcastle & github

Differential Revision: D43878634

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96286
Approved by: https://github.com/cpuhrsch
2023-03-11 02:18:05 +00:00
Driss Guessous
11aab72dc9 [SDPA] Add an optional scale kwarg (#95259)
# Summary
This PR adds an optional kwarg to torch.nn.functional.scaled_dot_product_attention().
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. Made updates to the efficient kernel to support it, and flash and math were minimally updated to support it as well.

Will reduce the complexity of: #94729 and has been asked for by a couple of users.

# Review Highlights
- As far as I know I did this the correct way, and this is both BC- and FC-compliant. However, I always seem to break internal workloads, so I would love it if someone could advise whether I did this right.
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make this change, but this is annoying and it will require someone deciding we should rename it.
- 'scale' is interpreted as `Q@K.T * (scale)` (see the usage sketch after this list)
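
A minimal usage sketch (passing the default softmax scale explicitly should match the default behavior; the shapes are assumptions):

```python
import math
import torch
import torch.nn.functional as F

B, H, L, D = 2, 8, 32, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

out_default = F.scaled_dot_product_attention(q, k, v)
out_scaled = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
torch.testing.assert_close(out_default, out_scaled)
```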

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
2023-03-08 18:07:40 +00:00