Mikayla Gawarecki
f49d5e30eb
Change owners of test/test_transformers.py to module: multi-headed-attention ( #132519 )
...
So flaky tests get tagged with `module: multi-headed-attention` instead of `module: nn`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132519
Approved by: https://github.com/Skylion007
2024-08-02 20:12:33 +00:00
PyTorch MergeBot
bcb4f7c172
Revert "Grouped Query Attention ( #128898 )"
...
This reverts commit 6b28af1b79 .
Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038 ))
2024-08-02 18:58:46 +00:00
PyTorch MergeBot
59b73079a0
Revert "Always use high precision for SDPA math backend ( #128922 )"
...
This reverts commit fbf3bc0a60 .
Reverted https://github.com/pytorch/pytorch/pull/128922 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR has a dependency on another PR (https://github.com/pytorch/pytorch/pull/128898 ) that has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/128922#issuecomment-2265949958 ))
2024-08-02 18:46:50 +00:00
Jianyu Huang
fbf3bc0a60
Always use high precision for SDPA math backend ( #128922 )
...
Summary:
feikou observed large numerical gaps when using the math backend on AMD and NV GPUs. This is mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized parts.
Since the math backend is expected to be slower anyway, and we expect it to generate the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, doing the computations in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.
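A minimal Python sketch of the upcast-compute-downcast pattern described above (the actual change lives in the C++ math backend; the helper name and explicit causal masking here are ours for illustration):
```python
import torch
import torch.nn.functional as F

def math_sdpa_high_precision(q, k, v, is_causal=False):
    # Upcast FP16/BF16 inputs to FP32, compute in FP32/TF32, downcast at the end.
    out_dtype = q.dtype
    q32, k32, v32 = q.float(), k.float(), v.float()
    scale = q32.size(-1) ** -0.5
    attn = q32 @ k32.transpose(-2, -1) * scale
    if is_causal:
        s_q, s_k = attn.shape[-2], attn.shape[-1]
        keep = torch.ones(s_q, s_k, dtype=torch.bool, device=attn.device).tril()
        attn = attn.masked_fill(~keep, float("-inf"))
    attn = F.softmax(attn, dim=-1)           # softmax accumulated in FP32
    return (attn @ v32).to(out_dtype)        # downcast the FP32 output
```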
Differential Revision: D58710805
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell , https://github.com/drisspg
2024-08-01 18:55:48 +00:00
jainapurva
6b28af1b79
Grouped Query Attention ( #128898 )
...
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It gives new meaning to the third-to-last (num_heads) dimension: the query may now have more heads than the key and value tensors.
Sample use cases this would enable:
LLama3
```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output Shape
(batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head count by using repeat_interleave when the numbers of heads are not equal, enabling correct and efficient attention computation (see the sketch below).
- By default the enable_gqa flag is set to False, which ensures that regular sdpa behavior remains unchanged.
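A hedged reference sketch of the design above; `sdpa_gqa_reference` is a hypothetical helper, not the PR's implementation, but it follows the same repeat_interleave expansion:
```python
import torch
import torch.nn.functional as F

def sdpa_gqa_reference(query, key, value, is_causal=False):
    q_heads, kv_heads = query.size(-3), key.size(-3)
    assert q_heads % kv_heads == 0, "Q_Heads % KV_Heads == 0 is required"
    if q_heads != kv_heads:
        # Expand K/V along the head dimension so each group of query heads
        # shares one key/value head.
        key = key.repeat_interleave(q_heads // kv_heads, dim=-3)
        value = value.repeat_interleave(q_heads // kv_heads, dim=-3)
    return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)

# Mirrors the Llama 3 example above (smaller sizes for brevity):
query = torch.rand(2, 32, 16, 64)
key = torch.rand(2, 8, 16, 64)
value = torch.rand(2, 8, 16, 64)
assert sdpa_gqa_reference(query, key, value, is_causal=True).shape == (2, 32, 16, 64)
```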
### Benchmarks:
- **sdpa.py: #130634**
Across different batch sizes, enable_gqa=True shows a substantial improvement in the runtime of sdpa
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458 **
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
PyTorch MergeBot
499ead96ff
Revert "Grouped Query Attention ( #128898 )"
...
This reverts commit d039b14207 .
Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481 ))
2024-07-30 13:11:24 +00:00
jainapurva
d039b14207
Grouped Query Attention ( #128898 )
...
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It gives new meaning to the third-to-last (num_heads) dimension: the query may now have more heads than the key and value tensors.
Sample use cases this would enable:
LLama3
```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output Shape
(batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head count by using repeat_interleave when the numbers of heads are not equal, enabling correct and efficient attention computation.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa behavior remains unchanged.
### Benchmarks:
- **sdpa.py: #130634**
Across different batch sizes, enable_gqa=True shows a substantial improvement in the runtime of sdpa
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458 **
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00
Yan Zhiwei
2a02b5cd22
[Intel GPU] Dispatch Stub support ( #130019 )
...
# Motivation
Structured codegen makes it easier to decouple tensor meta setting from the kernel implementation. At present, XPU operators have to handle tensor metas in a hand-written way.
We plan to leverage the codegen system to auto-generate structured operators. This PR adds `DispatchStub` support for Intel GPUs. Based on that, XPU operators will be able to register kernel functors to operator stubs.
This is a prerequisite of PR #130082 , where we will modify the codegen system to generate the source files and headers needed for XPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130019
Approved by: https://github.com/EikanWang , https://github.com/gujinghui , https://github.com/albanD
2024-07-29 02:18:52 +00:00
drisspg
1bfe7eb7e6
Update how we do sdpa testing ( #131743 )
...
## Motivation
This refactor aligns our testing methodology with the Flash Attention upstream repository while addressing several key issues:
1. **Standardized comparison**: We now compare fused kernels against float64 references, using the maximum of a calculated tolerance (based on the same-precision math implementation) and the standard float32 `atol`.
2. **Reduced redundancy**: Utilizing the same tensors for both same-precision math and fused kernel runs eliminates duplication.
3. **Improved maintainability**: The new approach simplifies tolerance adjustments across all affected tests.
4. **Consistency**: Standardizing tensor comparisons ensures a more uniform and reliable testing suite.
These changes collectively simplify our testing code, improve its maintainability, and provide a more robust framework for validating our attention mechanisms.
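A minimal sketch, under our own assumptions (the helper name, fudge factor, and float32 `atol` value are ours), of the comparison scheme in point 1:
```python
import torch
import torch.nn.functional as F

def check_fused_output(q, k, v, fused_out, fudge_factor=4.0, fp32_atol=1e-5):
    # float64 "golden" reference
    ref_fp64 = F.scaled_dot_product_attention(q.double(), k.double(), v.double())
    # same-precision math result sets the scale of the achievable error
    ref_low_prec = F.scaled_dot_product_attention(q, k, v)
    math_err = (ref_low_prec.double() - ref_fp64).abs().max().item()
    atol = max(fudge_factor * math_err, fp32_atol)
    torch.testing.assert_close(fused_out.double(), ref_fp64, atol=atol, rtol=0.0)
```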
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131743
Approved by: https://github.com/jainapurva , https://github.com/jbschlosser
2024-07-27 03:58:49 +00:00
Valentine233
868d9a4f12
[cpu][flash attention] fix nan issue ( #130014 )
...
Fixes #127055 .
NaNs are generated in flash attention because of the computations `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculations.
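For illustration (plain Python floats rather than the kernel's C++), these expressions do produce NaNs, which then propagate through the softmax accumulation:
```python
import math

neg_inf = float("-inf")
print(neg_inf - neg_inf)            # nan: (-inf) - (-inf) is undefined
print(math.exp(neg_inf - neg_inf))  # nan, instead of the intended exp(-inf) == 0.0
print(neg_inf * 0.0)                # nan: +/-inf * 0 is undefined
```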
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130014
Approved by: https://github.com/jgong5 , https://github.com/drisspg
2024-07-10 02:33:26 +00:00
eqy
86fb76e871
[SDPA] Clean up print in test/test_transformers.py ( #130302 )
...
Left this in #125343 , oops...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130302
Approved by: https://github.com/awgu
2024-07-09 09:20:52 +00:00
Huy Do
8f70bf7a94
Skip TestSDPAPrivateUse1Only on FBCODE ( #129997 )
...
Summary: The test is from D59181111, but I couldn't figure out a way to make it pass on FBCODE because loading a PyTorch C++ extension requires Ninja, which does not work with BUCK
Test Plan: `buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:transformers`
Differential Revision: D59304327
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129997
Approved by: https://github.com/drisspg
2024-07-03 06:48:51 +00:00
eqy
24b6c5a41f
[cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere ( #129587 )
...
Fix for #129579
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129587
Approved by: https://github.com/Skylion007 , https://github.com/drisspg
2024-06-30 19:37:44 +00:00
eqy
f845a7a91a
[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 ( #125343 )
...
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-30 19:22:16 +00:00
FEI
59e4e92556
sdp::SDPBackend::flash_attention support PrivateUse1 ( #126392 )
...
Fixes https://github.com/pytorch/pytorch/issues/124271
cc @cpuhrsch @drisspg @albanD @soulitzer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126392
Approved by: https://github.com/drisspg
2024-06-28 17:48:40 +00:00
PyTorch MergeBot
999eec8dea
Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 ( #125343 )"
...
This reverts commit b7e7a4cb01 .
Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some test_transformer running on internal A100 and V100 ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196202003 ))
2024-06-28 06:03:54 +00:00
PyTorch MergeBot
d21993bbb8
Revert "[cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere ( #129587 )"
...
This reverts commit 7854d84acb .
Reverted https://github.com/pytorch/pytorch/pull/129587 on behalf of https://github.com/huydhn due to Sorry for revert yet another of your change but I need to revert this to cleanly revert https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196187332 ([comment](https://github.com/pytorch/pytorch/pull/129587#issuecomment-2196198756 ))
2024-06-28 06:01:07 +00:00
eqy
7854d84acb
[cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere ( #129587 )
...
Fix for #129579
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129587
Approved by: https://github.com/Skylion007 , https://github.com/drisspg
2024-06-28 04:42:45 +00:00
Eddie Yan
b7e7a4cb01
[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 ( #125343 )
...
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-26 00:49:18 +00:00
iibrahimli
2db33054b3
Disable fast path in TransformerEncoderLayer when there are forward (pre-)hooks attached to modules ( #128415 )
...
Fixes #128413
Disable the fast path if there are forward hooks or pre-hooks attached.
An example failure case is given in the issue.
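A small repro sketch of the behavior this change enforces; the hook body is arbitrary, the point is that attaching any forward (pre-)hook now routes the call through the regular slow path:
```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True).eval()

def logging_hook(module, args, output):
    print("forward hook saw output of shape", tuple(output.shape))

layer.register_forward_hook(logging_hook)  # hook attached -> fast path disabled
with torch.inference_mode():
    out = layer(torch.randn(2, 10, 64))
```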
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128415
Approved by: https://github.com/mikaylagawarecki
2024-06-21 17:38:08 +00:00
Valentine233
5da428d9eb
[cpu][flash attention] fix attention mask issue ( #128816 )
...
For the attention mask in flash attention:
- Fix the issue of accessing illegal memory when the last dimension of the mask has size 1.
- Add unit tests of the attention mask for various shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128816
Approved by: https://github.com/jgong5 , https://github.com/drisspg
2024-06-21 01:12:48 +00:00
PyTorch MergeBot
817ce6835b
Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 ( #125343 )"
...
This reverts commit 4c971932e8 .
Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2163690162 ))
2024-06-12 18:47:52 +00:00
PyTorch MergeBot
7db501ba2b
Revert "[cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA ( #128350 )"
...
This reverts commit 45dccfddcd .
Reverted https://github.com/pytorch/pytorch/pull/128350 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/128350#issuecomment-2163669538 ))
2024-06-12 18:35:18 +00:00
eqy
45dccfddcd
[cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA ( #128350 )
...
CC @vedaanta-nvidia @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128350
Approved by: https://github.com/Skylion007
2024-06-11 19:22:21 +00:00
eqy
4c971932e8
[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 ( #125343 )
...
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-09 06:53:34 +00:00
Xinya Zhang
d34075e0bd
Add Efficient Attention support on ROCM ( #124885 )
...
This patch implements `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):` by reusing AOTriton's accelerated SDPA implementation.
Known limitations:
- Only supports MI200/MI300X GPUs
- Does not support varlen
- Does not support `CausalVariant`
- Optional arguments `causal_diagonal` and `seqlen_k` in `_efficient_attention_forward/backward` must be null
- Does not work well with inductor's SDPA rewriter. The rewriter has been updated to only use math and flash attention on ROCM.
This PR also takes a different approach of installing the AOTriton binary instead of building it from source in the base docker image. More details on motivation: https://github.com/pytorch/pytorch/pull/124885#issuecomment-2153229129
`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_transformers.py` yields "55028 passed, 20784 skipped" with this change. The [previous result](https://hud.pytorch.org/pr/127528 ) for `test_transformers.py` was 0 errors, 0 failures, and 55229 skipped out of 75517 tests in total (the XML report does not contain the total number of passed tests).
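Usage sketch of the context manager mentioned above (import path as in recent PyTorch releases; shapes and dtypes are arbitrary):
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```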
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124885
Approved by: https://github.com/malfet
2024-06-08 22:41:05 +00:00
PyTorch MergeBot
a9309502af
Revert "Refactoring to remove unused variable ( #125252 )"
...
This reverts commit b094622bc9 .
Reverted https://github.com/pytorch/pytorch/pull/125252 on behalf of https://github.com/drisspg due to going to land codev ([comment](https://github.com/pytorch/pytorch/pull/125252#issuecomment-2089394606 ))
2024-05-02 01:49:57 +00:00
Apurva Jain
b094622bc9
Refactoring to remove unused variable ( #125252 )
...
Summary: Removed unused variable for running encoder
Test Plan: buck test //caffe2/test:transformers
Differential Revision: D56771972
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125252
Approved by: https://github.com/drisspg
2024-05-01 15:17:45 +00:00
Fuzzkatt
1cf62e86a4
skip various unit tests for Jetson ( #122531 )
...
Skip multiprocessing, CUDA expandable segments, memory-efficient attention, and flash attention tests on Jetson due to hanging / SIGKILL issues found in NVIDIA internal testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy , https://github.com/malfet
2024-04-16 01:26:26 +00:00
Aaron Gokaslan
1d6c5972c1
[BE]: Optimize min/max/sum comprehensions C419 ( #123960 )
...
Automated fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are consumed immediately. This is preview functionality in ruff for rule C419 and it was applied automatically.
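For illustration, the kind of rewrite C419 performs (example values are ours):
```python
values = [3, 1, 4, 1, 5, 9]

total_list = sum([v * v for v in values])  # builds a throwaway intermediate list
total_gen = sum(v * v for v in values)     # generator expression, consumed immediately
assert total_list == total_gen
```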
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
William Wen
cbde0f048b
[dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support ( #123300 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123300
Approved by: https://github.com/jansel , https://github.com/malfet , https://github.com/zou3519
2024-04-05 20:13:17 +00:00
Xinya Zhang
b83c94339e
Fix performance regression and memory storage handling of Flash Attention on ROCM ( #122857 )
...
This PR fixes the two major issues that were discovered after the initial merge of PR #121561 :
1. The Flash Attention support added there has severe performance regressions on regular shapes (power-of-two head dimensions and sequence lengths) compared with PR #115981 . Its performance is worse than the math backend's and it only has numerical stability advantages. This PR fixes this problem.
2. There is a flaw in the memory storage handling of PR #121561 : it does not copy the gradients back to the designated output tensor. This PR removes the deprecated `TensorStorageSanitizer` class, which is unnecessary thanks to the more flexible backward kernel shipped in PR #121561
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122857
Approved by: https://github.com/jeffdaily , https://github.com/drisspg
2024-03-29 16:37:24 +00:00
Xinya Zhang
12116aee68
Add Flash Attention support on ROCM ( #121561 )
...
This patch addresses the major limitations in our previous [PR #115981 ](https://github.com/pytorch/pytorch/pull/115981 ) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton )
- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
* MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
* Now it supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
* The varlen API will be supported in a future release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, 128.
* Now it supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
* The kernel is selected according to autotune information from Triton.
Other improvements from AOTriton include:
* A more flexible Tensor storage layout
* A more flexible API
This is a more extensive fix to #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
2024-03-28 00:27:38 +00:00
FEI
e08cbc0d41
update comment of test_invalid_last_dim_stride in test_transformers.py ( #122679 )
...
Fixes #122594
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122679
Approved by: https://github.com/mikaylagawarecki
2024-03-26 15:40:24 +00:00
PyTorch MergeBot
764eae9c4e
Revert "Add Flash Attention support on ROCM ( #121561 )"
...
This reverts commit a37e22de70 .
Reverted https://github.com/pytorch/pytorch/pull/121561 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this needs more work to be able to land in fbcode because https://github.com/ROCm/aotriton is not available there atm. We are working to reland this change before 2.3 release ([comment](https://github.com/pytorch/pytorch/pull/121561#issuecomment-2007717091 ))
2024-03-19 17:14:28 +00:00
drisspg
42624bceb6
Fixes nan with large bf16 values ( #122135 )
...
Fixes #121558
Performance on main:
``` Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.608132004970683 | 65.90210803551601 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.75877740024589 | 64.83824399765581 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 16.465420153690506 | 67.6770955324173 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 17.398148600477725 | 68.19829455344006 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 29.053532000398263 | 99.58901099162175 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 27.826815698063 | 98.05690299253911 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 49.89655229728669 | 178.24282555375248 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 48.840098950313404 | 174.5950729819015 |
| 1 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 505.66218036692584 | 1865.9265094902366 |
| 1 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 295.0534054543823 | 967.3831606050952 |
| 1 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.496030446141958 | 55.11070846114308 |
| 1 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.47399884648621 | 55.452342028729625 |
| 1 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 13.216444296995178 | 55.14447903260589 |
| 1 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 12.763233599252999 | 55.142355500720434 |
| 1 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 19.409965351223946 | 74.9107634765096 |
| 1 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 19.02470579952933 | 74.84168506925926 |
| 1 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 46.37695319834165 | 172.19150450546294 |
| 1 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 45.225963747361675 | 185.19691249821335 |
| 1 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 634.3090848531574 | 2249.057865119539 |
| 1 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 320.47313248040155 | 1053.0515247955916 |
| 4 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 13.448987301671878 | 63.63581650657579 |
| 4 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.509283400140703 | 63.059300999157124 |
| 4 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 19.71098779467866 | 105.55780201684684 |
| 4 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 18.264925852417946 | 105.12311349157244 |
| 4 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 45.218703348655254 | 222.87272597895935 |
| 4 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 43.55393464793451 | 230.63290398567915 |
| 4 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 134.02968645095825 | 514.6893998607993 |
| 4 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 157.13709802366793 | 624.5892751030624 |
| 4 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 1776.7079547047617 | 6353.551096981391 |
| 4 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1143.6000745743513 | 3811.8767354171723 |
| 4 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.717129248427227 | 55.35991647047922 |
| 4 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.746983398916198 | 55.76716404175386 |
| 4 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 17.255573300644752 | 106.47456656442955 |
| 4 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 16.46409669774584 | 108.07770595420152 |
| 4 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 46.63354124641045 | 213.74862996162847 |
| 4 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 47.01801469782367 | 240.78139301855117 |
| 4 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 127.76448752265424 | 508.08745552785695 |
| 4 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 168.6308984644711 | 667.2996102133766 |
| 4 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 2268.1598202325404 | 7727.2648515645415 |
| 4 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1242.8469699807465 | 4161.965740495361 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 14.340955897932872 | 93.72280450770633 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 13.25262250029482 | 93.2030284893699 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 27.598425600444898 | 183.23776399483904 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 26.362583553418514 | 183.51862096460536 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 84.52303148806094 | 383.50319798337296 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 89.41743348259479 | 432.5502900755964 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 217.76640450116247 | 943.9354750793427 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 303.0781910638325 | 1225.4394043702632 |
| 8 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 3470.8542854059488 | 12194.579601055011 |
| 8 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2268.1174043100327 | 7608.0941944383085 |
| 8 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.289720651460811 | 95.88620596332476 |
| 8 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.618648946750909 | 95.56685149436818 |
| 8 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 31.567946751601994 | 180.62468653079122 |
| 8 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 28.611703700153157 | 189.4215695792809 |
| 8 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 84.11306998459621 | 385.25596749968827 |
| 8 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 93.82540901424363 | 455.77428903197875 |
| 8 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 226.80530551588163 | 965.8026450779289 |
| 8 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 327.4116570246406 | 1312.5067745568228 |
| 8 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 4445.5064804060385 | 15020.768146496266 |
| 8 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2433.0302356975153 | 8300.016750581563 |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```
Performance on this branch:
```Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.783618393586949 | 65.59692794689909 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.064015300711617 | 56.99719698168337 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 16.629025398287922 | 68.65267595276237 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 17.462356004398313 | 68.35797848179936 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 29.5476081490051 | 101.22994752600789 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 28.395320149138573 | 98.62275794148445 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 50.50016101449728 | 181.4357690163888 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 49.450615647947416 | 175.86063902126625 |
| 1 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 506.06461532879626 | 1866.0613044630736 |
| 1 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 299.9336270149797 | 976.4662646921353 |
| 1 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.45752210286446 | 58.79682704107836 |
| 1 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.407129396684468 | 58.14061599085107 |
| 1 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 13.822759891627355 | 56.56979401828722 |
| 1 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 13.39154909946956 | 56.7130644340068 |
| 1 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 20.282494352431968 | 77.29688903782517 |
| 1 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 19.899454596452415 | 75.4446149803698 |
| 1 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 48.494275606935844 | 177.5322465109639 |
| 1 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 46.84524350450374 | 189.1778860008344 |
| 1 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 635.1026654010639 | 2248.0451600858937 |
| 1 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 335.1591735263355 | 1080.4320796160027 |
| 4 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 13.63953539985232 | 65.50709309522063 |
| 4 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.858113402035087 | 63.021871959790595 |
| 4 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 19.98318645055406 | 105.87883047992364 |
| 4 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 18.619045056402683 | 104.90188701078296 |
| 4 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 45.91175540117546 | 226.00732848513871 |
| 4 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 44.39614630537107 | 232.39317198749632 |
| 4 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 135.5409600073472 | 522.7949097752571 |
| 4 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 158.79383607534692 | 628.5856699105352 |
| 4 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 1775.9978299727663 | 6343.203847063706 |
| 4 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1160.680354805663 | 3842.235009651631 |
| 4 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.553713708417488 | 65.50691701704638 |
| 4 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.486379051348194 | 56.9980075233616 |
| 4 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 17.56585600087419 | 107.89892700267956 |
| 4 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 16.828144202008843 | 109.05519902007653 |
| 4 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 48.23235589428805 | 217.8974545095116 |
| 4 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 49.09284680034033 | 244.73925953498107 |
| 4 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 134.77827049791813 | 522.7259948151186 |
| 4 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 176.60772847011688 | 681.5171707421541 |
| 4 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 2267.821540008299 | 7720.425300067291 |
| 4 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1295.3941145678982 | 4272.425139788538 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 14.514714101096615 | 94.2192979855463 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 13.553097198018804 | 93.244242540095 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 27.95821905019693 | 185.0469880155288 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 26.709681446664035 | 184.22623950755226 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 85.85420495364815 | 388.3417735341937 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 89.97473795898259 | 434.4228169647977 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 220.6919804448262 | 958.9654899900779 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 306.55586952343583 | 1233.2170095760375 |
| 8 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 3470.7326447824016 | 12183.611298678443 |
| 8 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2299.064100370742 | 7669.618452200666 |
| 8 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.427107692928985 | 96.96270158747211 |
| 8 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.856995843118057 | 96.38117247959599 |
| 8 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 32.9956392000895 | 182.52741603646427 |
| 8 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 29.397601098753512 | 191.0755339777097 |
| 8 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 89.06024845782667 | 392.2585004474967 |
| 8 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 97.78487798757851 | 462.07307645818213 |
| 8 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 240.521906001959 | 992.4693452194335 |
| 8 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 341.98952303268015 | 1339.2950996058062 |
| 8 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 4445.311005110853 | 15001.030603889374 |
| 8 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2535.9767401823774 | 8528.990152990447 |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```
```
{'avg_forward_time_nan_fix': 399.7900972732653,
'avg_backward_time_nan_fix': 1409.652114014413,
'avg_forward_time_main_branch': 394.6807206988645,
'avg_backward_time_main_branch': 1399.4055472857629,
'geo_mean_nan_fix': 150.95049601244946,
'geo_mean_main_branch': 148.3381648508822}
```
The y-axis label is wrong (the unit is microseconds), but the relative comparison still holds.
<img width="790" alt="Screenshot 2024-03-18 at 3 34 15 PM" src="https://github.com/pytorch/pytorch/assets/32754868/ca278c15-b815-4535-bdcd-07e522055466 ">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122135
Approved by: https://github.com/cpuhrsch
2024-03-19 16:32:00 +00:00
Xinya Zhang
a37e22de70
Add Flash Attention support on ROCM ( #121561 )
...
This patch addresses the major limitations in our previous [PR #115981 ](https://github.com/pytorch/pytorch/pull/115981 ) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton )
- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
* MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
* Now it supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
* The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, 128.
* Now it supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
* The kernel is selected according to autotune information from Triton.
Other improvements from AOTriton include:
* A more flexible Tensor storage layout
* A more flexible API
This is a more extensive fix to #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet , https://github.com/atalman
2024-03-12 01:16:53 +00:00
y-sq
393b4ab432
Fixes issue_119785 ( #121048 )
...
Fixes #119785
- Removed all sentinel files of `test_causal_variants_.*`.
- The `test_causal_variants_causal_variant_` tests pass after removing the dynamo_skips files.
- The `test_causal_variants_compile_causal_variant` tests fail with `PYTORCH_TEST_WITH_DYNAMO=1`. These tests already call torch.compile, so @skipIfTorchDynamo was added to skip them under `PYTORCH_TEST_WITH_DYNAMO` (see the sketch after this list).
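A sketch of the skip pattern described in the last bullet; the decorator comes from PyTorch's internal test utilities, while the class and test names here are hypothetical:
```python
import unittest
from torch.testing._internal.common_utils import skipIfTorchDynamo

class ExampleCausalVariantTests(unittest.TestCase):
    @skipIfTorchDynamo("test already calls torch.compile")
    def test_causal_variants_compile_example(self):
        self.assertTrue(True)  # placeholder body

if __name__ == "__main__":
    unittest.main()
```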
**Tests**
```
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.7745s] [ 3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.8020s] [ 6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0385s] (Lower righ...) [ 9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.5046s] [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.6483s] [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.8537s] [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.8388s] [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.4859s] [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu SKIPPED [0.0084s] (Th...) [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu SKIPPED [0.0086s] (Th...) [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0081s] (Th...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu SKIPPED [0.0085s] (Th...) [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu SKIPPED [0.0082s] (Thi...) [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu SKIPPED [0.0085s] (Thi...) [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu SKIPPED [0.0081s] (Thi...) [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu SKIPPED [0.0085s] (Thi...) [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.4185s] [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.4273s] [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0280s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [8.0999s] [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.3785s] [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.3818s] [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.3864s] [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.7668s] [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda SKIPPED [0.0089s] (...) [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda SKIPPED [0.0087s] (...) [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0087s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda SKIPPED [0.0084s] (...) [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda SKIPPED [0.0087s] (T...) [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda SKIPPED [0.0087s] (T...) [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda SKIPPED [0.0084s] (T...) [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda SKIPPED [0.0087s] (T...) [100%]
=================================================== 14 passed, 18 skipped, 77218 deselected in 39.72s ===================================================
```
```
$ pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.2410s] [ 3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.3984s] [ 6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lower righ...) [ 9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.0095s] [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.1749s] [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.2138s] [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.2715s] [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0108s] [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.4864s] [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.5346s] [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lo...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.1722s] [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.2341s] [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.4786s] [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.4635s] [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0861s] [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.7579s] [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.0044s] [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0007s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [9.2065s] [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0081s] [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0063s] [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0059s] [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.0055s] [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [0.1200s] [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.1032s] [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0010s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [0.1151s] [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0705s] [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0713s] [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0696s] [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.1516s] [100%]
=================================================== 28 passed, 4 skipped, 77218 deselected in 39.23s ====================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121048
Approved by: https://github.com/zou3519
2024-03-05 20:19:02 +00:00
drisspg
2e6c08a14b
Update flash_attention kernel from 2.3.6 to 2.5.5 ( #118935 )
...
# Summary
Updates the FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6 ) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5 ).
The usual changes were then re-rolled on top of the modified kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since it would make the kernel mutate in place, which interacts badly with functionalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-03-04 17:36:22 +00:00
PyTorch MergeBot
1458f1de66
Revert "Update flash_attention kernel from 2.3.6 to 2.5.5 ( #118935 )"
...
This reverts commit 4b7a521856 .
Reverted https://github.com/pytorch/pytorch/pull/118935 on behalf of https://github.com/atalman due to Significantly increases build time. Optimization is needed ([comment](https://github.com/pytorch/pytorch/pull/118935#issuecomment-1971723284 ))
2024-02-29 18:42:21 +00:00
drisspg
4b7a521856
Update flash_attention kernel from 2.3.6 to 2.5.5 ( #118935 )
...
# Summary
Updates the FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6 ) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5 ).
The usual changes were then re-rolled on top of the modified kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since it would make the kernel mutate in place, which interacts badly with functionalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-02-28 19:31:15 +00:00
Eddie Yan
702e82da28
[cuDNN][Flash Attention] Minor cleanup for cuDNN SDPA ( #120750 )
...
Cleaning up before hopefully starting work on backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120750
Approved by: https://github.com/Skylion007 , https://github.com/drisspg
2024-02-28 17:32:07 +00:00
Eddie Yan
cd380c794f
[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference ( #115663 )
...
#113713
Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.
CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-02-14 22:02:06 +00:00
atalman
244b124bb8
Add linux cpu test for 3.12 ( #117853 )
...
This is continuation of work: https://github.com/pytorch/pytorch/pull/113987
Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117853
Approved by: https://github.com/albanD
2024-02-14 20:52:23 +00:00
CaoE
dfdbd73360
add Half support for flash attention ( #119247 )
...
Re-open for https://github.com/pytorch/pytorch/pull/118368 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119247
Approved by: https://github.com/drisspg , https://github.com/malfet
2024-02-07 05:57:41 +00:00
CK Luk
2ad3599a71
Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST ( #118979 )
...
Summary: Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST
Test Plan: See the one in D53154041
Reviewed By: yjhao, yanboliang, Yuzhen11
Differential Revision: D53154041
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118979
Approved by: https://github.com/yanboliang
2024-02-06 16:25:33 +00:00
Catherine Lee
f481835115
Revert "add Half support for flash attention on CPU ( #118368 )" ( #119204 )
...
This reverts commit a5a63db3bf .
Reverts #118368
Got reverted internally, but the branch got deleted so the automation didn't work
Mildly edited stack trace
```
...
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 635, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "torch/fx/experimental/proxy_tensor.py", line 995, in trace
res = super().trace(root, concrete_args)
File "torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/fx/_symbolic_trace.py", line 793, in trace
(self.create_arg(fn(*args)),),
File "torch/fx/experimental/proxy_tensor.py", line 665, in wrapped
out = f(*tensors)
File "<string>", line 1, in <lambda>
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 357, in _functionalized_f_helper
f_outs = fn(*f_args)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 68, in inner_fn
outs = fn(*args)
File "torch/_functorch/_aot_autograd/utils.py", line 161, in flat_fn
tree_out = fn(*args, **kwargs)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 618, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "torch/fx/interpreter.py", line 145, in run
self.env[node] = self.run_node(node)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 593, in run_node
result = super().run_node(n)
File "torch/fx/interpreter.py", line 202, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "torch/fx/interpreter.py", line 274, in call_function
return target(*args, **kwargs)
File "torch/_ops.py", line 571, in __call__
return self_._op(*args, **kwargs)
File "torch/_subclasses/functional_tensor.py", line 380, in __torch_dispatch__
outs_unwrapped = func._op_dk(
File "torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 744, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 779, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 423, in proxy_call
r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 1225, in maybe_handle_decomp
return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
File "torch/_decomp/decompositions.py", line 4322, in scaled_dot_product_flash_attention_for_cpu
torch._check(
File "torch/__init__.py", line 1133, in _check
_check_with(RuntimeError, cond, message)
File "torch/__init__.py", line 1116, in _check_with
raise error_type(message_evaluated)
RuntimeError: query must be FP32, FP64, BF16 but got torch.float16
While executing %_scaled_dot_product_flash_attention_for_cpu : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default](args = (%l_q_, %l_k_, %l_v_), kwargs = {attn_mask: %l_attn_mask_})
Original traceback:
File "executorch/backends/xnnpack/partition/graphs/sdpa.py", line 34, in forward
return torch.nn.functional.scaled_dot_product_attention(
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119204
Approved by: https://github.com/kit1980
2024-02-05 18:24:53 +00:00
CaoE
a5a63db3bf
add Half support for flash attention on CPU ( #118368 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118368
Approved by: https://github.com/jgong5 , https://github.com/Valentine233 , https://github.com/drisspg
ghstack dependencies: #118367
2024-02-02 01:08:39 +00:00
drisspg
126c1621ce
Add Support for CausalBias to torch compile ( #116071 )
...
Fixes #115363
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116071
Approved by: https://github.com/mlazos
2024-01-30 02:22:48 +00:00
Wei Wang
80cb6db90d
[CUDA] [CI] Disable flash attention for sm87 architecture when the head dim > 192 ( #117678 )
...
Head dim > 192 requires A100/H100 (sm80 or sm90) per TORCH_CHECK [here](0c26565d5d/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp (L760) ).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117678
Approved by: https://github.com/eqy , https://github.com/malfet
2024-01-27 01:22:47 +00:00