Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00), at commit b46eb1ccaf
55 Commits

SHA1 | Message

b46eb1ccaf | [dynamo] control one_graph behavior additionally through config (#154283)
`torch.compile` now always goes through `torch._dynamo._optimize`. fullgraph is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283
Approved by: https://github.com/jansel, https://github.com/anijain2305

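As a rough illustration of the `fullgraph` behavior described above (not from the PR; the graph break is forced artificially and the shapes are arbitrary):

```python
import torch

def fn(x):
    x = x * 2
    torch._dynamo.graph_break()  # deliberately force a graph break
    return x + 1

# Default compile: the graph break is tolerated and the function is split into two graphs.
torch.compile(fn)(torch.randn(4))

# fullgraph=True (now routed through config.error_on_graph_break per this PR)
# should instead turn the graph break into a hard error.
try:
    torch.compile(fn, fullgraph=True)(torch.randn(4))
except Exception as e:
    print("graph break escalated to an error:", type(e).__name__)
```
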
ae53510b9e | Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 + adds a test.

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`. Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) One reason this matters is that it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated `TestFooCUDA` does not derive from `TestFoo`, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

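A minimal sketch of the pattern this fix enables (illustrative only; `instantiate_device_type_tests` lives in PyTorch's internal testing helpers, so the import path and exact behavior are assumptions of this sketch rather than a stable API):

```python
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class TestFoo(TestCase):
    @classmethod
    def setUpClass(cls):
        # With this fix TestFooCPU/TestFooCUDA really derive from TestFoo,
        # so this super() call resolves correctly for every generated variant,
        # and the device base's own setup (primary device selection) still runs.
        super().setUpClass()
        cls.shared_resource = "expensive one-time setup"

    def test_bar(self, device):
        self.assertEqual(self.shared_resource, "expensive one-time setup")

# Generates TestFooCPU.test_bar_cpu, TestFooCUDA.test_bar_cuda, ...
instantiate_device_type_tests(TestFoo, globals())

if __name__ == "__main__":
    run_tests()
```
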
12cb11a268 | [Inductor UT] Refactor FlexAttention UT and add CPU tests (#144953)
This PR extends and refines all remaining UTs for CPU and more devices in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py`, as a follow-up to https://github.com/pytorch/pytorch/pull/141453
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144953
Approved by: https://github.com/drisspg

5e9f792479 | [ROCm] Unskip flex attention UTs after triton 3.3 bump (#148327)
Enable `test_flex_attention.py::TestLearnableBiases` unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148327
Approved by: https://github.com/jeffdaily

98b3f1db9f | [Flex Attention] support num_heads > 1 in block_mask (#148857)
Previously flex decoding errored when the block mask had num_heads > 1, so users had to use num_heads=1 or explicitly set `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}`.
This PR fixes this issue. When not using grouped query attention (GQA, i.e., Hq == Hkv), we support block masks with num_heads = 1 and num_heads = num_query_heads (i.e., Hq). This is the same setting as the flex attention kernel.
When using GQA (i.e., Hq != Hkv), we support block masks with num_heads = 1. When num_heads = Hq, we fall back to the flex attention kernel, so users don't need to explicitly set `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}` anymore.
Why fall back? In the current flex decoding triton kernel, grouped query heads for the same kv head are handled by the same thread block. Supporting num_heads = Hq with GQA would require supporting different kv num blocks for different query heads in the same thread block, leading to lots of redundant work. So it is better to use the main flex_attention kernel, where each query head is handled by a separate block.
Fixes #148527
Fixes #147267
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148857
Approved by: https://github.com/drisspg

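A small decode-shaped sketch of the new behavior (shapes, mask, and the CUDA device are illustrative assumptions; `create_block_mask` / `flex_attention` are the public `torch.nn.attention.flex_attention` entry points):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, KV_LEN, D = 2, 8, 1024, 64  # Hq == Hkv, i.e. no GQA
q = torch.randn(B, H, 1, D, device="cuda", dtype=torch.float16)       # single query token (decoding)
k = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)

def head_specific(b, h, q_idx, kv_idx):
    # even heads attend everywhere, odd heads only to the last 256 tokens
    return (h % 2 == 0) | (kv_idx >= KV_LEN - 256)

# A per-head block mask (H = num_query_heads); previously flex decoding required H=1
# or the kernel_options={"FORCE_USE_FLEX_ATTENTION": True} escape hatch.
block_mask = create_block_mask(head_specific, B=B, H=H, Q_LEN=1, KV_LEN=KV_LEN)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```
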
ba9ed856e0 | [FlexAttention] Improve error msg for embedding < 16 (#147765)
flex_attention uses tl.dot, which [does not support embedding < 16](https://github.com/triton-lang/triton/issues/2266) in its input shapes. This PR adds an explicit error message for users who are prototyping with small tensors.
Fixes #147701
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147765
Approved by: https://github.com/drisspg

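A tiny repro-style sketch of the case that now fails with a readable error instead of an opaque Triton assertion (shapes and device are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

q = k = v = torch.randn(1, 1, 128, 8, device="cuda")  # head_dim = 8 < 16
try:
    torch.compile(flex_attention)(q, k, v)
except Exception as e:
    print(type(e).__name__, e)  # expected: an explicit message about the minimum head_dim
```
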
c6707734de | Enable non power of 2 head_dim for FlexAttention (#133495)
# Summary
- Adds support for non-power-of-2 headdim by launching blocks with head_dim rounded up to the next valid power of 2.
- The other option I considered was building up the final dot products with smaller blocks; that would probably work, but for the sake of code complexity I am going with this option for now.
### Corollary
We had a bug in our backwards kernel where we were using index_k instead of index_v. This should have shown up for the qk_head_dim != v_head_dim cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133495
Approved by: https://github.com/Chillee

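A minimal sketch of the newly supported shape (a non-power-of-2 head_dim; sizes and device are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 512, 96  # 96 is not a power of 2; blocks are padded to 128 internally
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = torch.compile(flex_attention)(q, k, v)
out.sum().backward()  # the backward kernel exercises the same padded-head_dim path
```
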
99dbc5b0e2 | PEP585 update - test (#145176)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93

577708e6de | Unskipped multiple inductor tests for ROCm (#143581)
All of them should be fine to run now after the triton fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143581
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>

069419569d | [PagedAttention] Support different input position for each batch index (#144693)
In LLM inference, each request usually has a different prefill length, leading to a different input position for each batch index. This PR adds such support for paged attention.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144693
Approved by: https://github.com/drisspg

7d9f26de05 | Revert "Unskipped multiple inductor tests for ROCm (#143581)"
This reverts commit

e05d67790e | Unskipped multiple inductor tests for ROCm (#143581)
All of them should be fine to run now after the triton fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143581
Approved by: https://github.com/jataylo, https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>

d8c8ba2440 | Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD

ed9931e6ee | Add tests for non divisible inputs for flex decoding (#143214)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143214
Approved by: https://github.com/drisspg

424156c26c | [ROCm] Update to AOTriton 0.8b (#140172)
Notable new features for SDPA operators on AMD systems from AOTriton 0.8b:
1. Nested tensor support;
2. MQA/GQA support;
3. Restored Efficient Attention support for causal=True and seqlen_q != seqlen_k cases;
   + The kernel should use top-left alignment; bottom-right alignment will be added later
4. Moves gfx1100 (RX7900/W7800/W7900) out of experimental support status.
   However, users are strongly recommended to update to ROCm 6.2.4, notably for its firmware updates.
Related unit tests are enabled as well.
Notable related changes from AOTriton 0.8b:
1. AOTriton 0.8b moves the GPU kernels out of libaotriton.so into a separate directory `aotriton.images`;
2. LZMA replaces ZSTD as the GPU kernel compression algorithm for a better compression ratio: the AOTriton 0.8b .so + aotriton.images take 350MB, compared to 800MB for the AOTriton 0.7b .so;
3. The compression cannot be disabled now, and `liblzma` is a hard run-time dependency.
   + This should not be a problem, since `lzma` is part of the Python Standard Library.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>

af88326250 | Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)
Fixes https://github.com/pytorch/pytorch/issues/141435
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625
Approved by: https://github.com/drisspg
ghstack dependencies: #138788

a34a56f69f | Revert "Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)"
This reverts commit

795f28ac55 | Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)
Fixes https://github.com/pytorch/pytorch/issues/141435
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625
Approved by: https://github.com/drisspg
ghstack dependencies: #138788

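A sketch of what the check means in practice (illustrative shapes; the exact error type is an assumption):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

S = 2048
q = torch.randn(1, 8, S, 64, device="cuda")
k = torch.randn(1, 8, S, 64, device="cuda")
v = torch.randn(1, 8, S, 64, device="cuda")

# OK: the BlockMask is built for the same Q_LEN/KV_LEN as the tensors.
mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
flex_attention(q, k, v, block_mask=mask)

# A mask built for a different length should now be rejected rather than silently reused.
short_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
try:
    flex_attention(q, k, v, block_mask=short_mask)
except Exception as e:
    print("rejected:", e)
```
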
cb71bcc542 | Replace clone.detach with detach.clone (#140264)
Fixes #64532. As stated in the issue, replace `clone.detach` with `detach.clone`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer

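For reference, the preferred spelling (detaching first means the copy is never recorded by autograd):

```python
import torch

x = torch.randn(3, requires_grad=True)

y = x.detach().clone()  # preferred: detach first, then copy
z = x.clone().detach()  # same values, but the clone op is tracked by autograd before being detached
```
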
540f3ef9b1 | Fix flex_decode to build offsets off of strides (#139516)
Fixes: https://github.com/pytorch/pytorch/issues/139462
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139516
Approved by: https://github.com/Chillee

68134a320e | [Flex Attention] Paged Attention (#137164)
This PR adds paged attention for flex attention.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137164
Approved by: https://github.com/drisspg

34c18887ad | [FlexAttention] Remove restriction on QK headdim > V headdim (#135884)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135884
Approved by: https://github.com/Chillee

28224329ad | [Flex Attention] fix block size order (#136657)
`create_block_mask` currently gives the wrong BLOCK_SIZE and shape when using a block size other than the default `(128, 128)`. This PR fixes the issue by using the BLOCK_SIZE order `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136657
Approved by: https://github.com/Chillee, https://github.com/drisspg

f7ab0e9989 | Revert "[Flex Attention] fix block size order (#136657)"
This reverts commit

b42f1e3641 | [Flex Attention] fix block size order (#136657)
`create_block_mask` currently gives the wrong BLOCK_SIZE and shape when using a block size other than the default `(128, 128)`. This PR fixes the issue by using the BLOCK_SIZE order `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136657
Approved by: https://github.com/Chillee, https://github.com/drisspg

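A small sketch of the now-consistent ordering (sizes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# BLOCK_SIZE is interpreted as (Q_BLOCK_SIZE, KV_BLOCK_SIZE).
mask = create_block_mask(causal, B=None, H=None, Q_LEN=2048, KV_LEN=2048,
                         BLOCK_SIZE=(128, 64))
print(mask.BLOCK_SIZE, mask.shape)  # expected: (128, 64) and a (1, 1, 2048, 2048) mask
```
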
ccca3de0cd | [ROCm] Enable Flex attention tests on AMD gpus (#136245)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136245
Approved by: https://github.com/malfet

74fd1bf965 | [ROCm] Update to AOTriton 0.7b (#134498)
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. Users should enable Navi31 support with the env var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
Known problem:
1. `test/test_transformers.py` will have massive failures and/or NaN outputs with `--use-pytest`
   + Update: Confirmed that skipping `class TestSDPAPrivateUse1Only` fixes the problem with `--use-pytest`
Note:
AOTriton 0.7b adds support for nested tensors + SDPA, but it needs more work (and consequently a separate PR) to enable it.
Fixes #133540
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet

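A minimal way to opt into the experimental Navi31 path via the env var quoted above (a sketch; setting the variable in the shell before launching Python is equivalent, and it should be in place before the attention kernels are first dispatched):

```python
import os
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"  # opt into EXPERIMENTAL gfx1100 support

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
```
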
6e13f5eb38 | [FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for KV batch dimension. ## Details Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this diff, we require `Bq == Bkv`. However, for some use cases, we may have Bkv < Bq. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]`, while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximal number of tokens supported by paged attention. This PR relax this requirement to be `Bq == Bkv or (Bq > 1 and Bkv == 0)`. This support covers both flex decoding, flex attention forward and backward. ## Benchmark GPU: H100 We see negligible (1%~2%) performance change from this PR when `Bq == Bkv`. ``` python benchmarks/transformer/score_mod.py --calculate-bwd ``` ### Perf before this PR **FWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|---------------|------------|----------------|------------------------------| | Average | 0.743 | | | | | | Max | 0.955 | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | | Min | 0.548 | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | **BWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|-------------|------------|----------------|-----------------------------| | Average | 0.834 | | | | | | Max | 1.261 | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | | Min | 0.456 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | <details> <summary> Full performance sweep </summary> | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 107.040 | 140.800 | 0.888 | 0.760 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.840 | 19.744 | 112.576 | 140.064 | 0.802 | 0.804 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.232 | 17.344 | 87.744 | 142.496 | 0.878 | 0.616 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 108.192 | 143.328 | 0.888 | 0.755 | | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.904 | 22.400 | 106.432 | 136.512 | 0.889 | 0.780 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.424 | 26.752 | 91.712 | 106.688 | 0.726 | 0.860 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.808 | 22.432 | 89.024 | 101.920 | 0.883 | 0.873 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.840 | 22.272 | 88.896 | 102.592 | 0.891 | 0.867 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.240 | 32.416 | 116.768 | 112.256 | 0.933 | 1.040 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 29.536 | 37.024 | 113.664 | 102.688 | 0.798 | 1.107 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.656 | 32.800 | 116.992 | 127.008 | 0.935 | 0.921 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.592 | 32.480 | 116.928 | 112.160 | 0.942 | 1.043 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.920 | 198.656 | 204.512 | 0.653 | 0.971 | | None | causal | torch.bfloat16 | (2, 16, 
1024, 16, 1024, 128) | 37.760 | 62.528 | 189.536 | 170.624 | 0.604 | 1.111 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.896 | 62.368 | 198.304 | 205.824 | 0.656 | 0.963 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.952 | 198.432 | 203.648 | 0.653 | 0.974 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 318.528 | 355.904 | 947.232 | 1162.496 | 0.895 | 0.815 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 199.776 | 252.128 | 677.792 | 813.184 | 0.792 | 0.834 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 316.512 | 363.328 | 947.712 | 1361.984 | 0.871 | 0.696 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 317.984 | 356.864 | 947.264 | 1165.024 | 0.891 | 0.813 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 446.656 | 734.656 | 1664.288 | 2172.960 | 0.608 | 0.766 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 278.688 | 467.648 | 1182.624 | 1339.296 | 0.596 | 0.883 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 447.872 | 744.096 | 1662.944 | 2196.544 | 0.602 | 0.757 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 448.128 | 732.928 | 1663.072 | 2156.800 | 0.611 | 0.771 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.648 | 16.640 | 107.520 | 143.008 | 0.940 | 0.752 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.776 | 18.240 | 129.056 | 141.920 | 0.865 | 0.909 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.168 | 16.640 | 103.616 | 139.648 | 0.912 | 0.742 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.616 | 16.640 | 128.608 | 164.448 | 0.938 | 0.782 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 21.952 | 125.344 | 170.304 | 0.901 | 0.736 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 23.712 | 104.288 | 196.896 | 0.834 | 0.530 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.072 | 21.952 | 102.080 | 177.056 | 0.869 | 0.577 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.648 | 21.920 | 109.920 | 170.848 | 0.896 | 0.643 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.936 | 127.808 | 228.832 | 0.954 | 0.559 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 29.472 | 33.856 | 113.152 | 215.072 | 0.871 | 0.526 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.496 | 32.160 | 116.576 | 231.744 | 0.948 | 0.503 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.904 | 116.320 | 229.824 | 0.955 | 0.506 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.480 | 61.440 | 176.448 | 345.312 | 0.659 | 0.511 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 38.304 | 59.424 | 169.312 | 371.360 | 0.645 | 0.456 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.960 | 61.760 | 176.512 | 358.912 | 0.663 | 0.492 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.352 | 61.696 | 176.512 | 344.928 | 0.654 | 0.512 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.224 | 357.728 | 905.728 | 1668.448 | 0.884 | 0.543 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 199.904 | 248.416 | 636.544 | 1109.088 | 0.805 | 0.574 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 
314.880 | 363.616 | 906.304 | 1658.176 | 0.866 | 0.547 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.160 | 354.368 | 906.080 | 1649.024 | 0.892 | 0.549 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.912 | 739.840 | 1555.808 | 2521.952 | 0.604 | 0.617 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 279.776 | 463.904 | 1068.928 | 1849.888 | 0.603 | 0.578 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.080 | 748.960 | 1553.504 | 2629.888 | 0.596 | 0.591 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.208 | 740.608 | 1558.880 | 2524.960 | 0.602 | 0.617 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 33.568 | 41.280 | 170.016 | 147.584 | 0.813 | 1.152 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 30.688 | 43.040 | 159.552 | 146.720 | 0.713 | 1.087 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.112 | 41.504 | 170.112 | 152.672 | 0.822 | 1.114 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.240 | 41.152 | 170.272 | 134.976 | 0.832 | 1.261 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.672 | 76.416 | 295.296 | 263.648 | 0.637 | 1.120 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.088 | 72.576 | 281.920 | 237.664 | 0.621 | 1.186 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.032 | 76.672 | 295.520 | 265.248 | 0.626 | 1.114 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.096 | 76.096 | 295.456 | 262.112 | 0.632 | 1.127 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.920 | 111.232 | 401.568 | 382.944 | 0.844 | 1.049 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 68.192 | 95.232 | 338.752 | 326.816 | 0.716 | 1.037 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.984 | 111.840 | 401.856 | 444.224 | 0.840 | 0.905 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 94.176 | 110.496 | 401.600 | 383.136 | 0.852 | 1.048 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.488 | 227.040 | 727.424 | 739.712 | 0.579 | 0.983 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 95.616 | 169.760 | 616.864 | 574.112 | 0.563 | 1.074 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.680 | 228.672 | 727.616 | 746.048 | 0.576 | 0.975 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.104 | 225.696 | 727.904 | 735.392 | 0.581 | 0.990 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1227.296 | 1386.656 | 3720.192 | 4539.904 | 0.885 | 0.819 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 691.360 | 831.712 | 2515.872 | 3067.808 | 0.831 | 0.820 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1228.192 | 1403.136 | 3715.520 | 5309.280 | 0.875 | 0.700 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1229.024 | 1384.992 | 3715.904 | 4550.368 | 0.887 | 0.817 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1784.832 | 2865.888 | 6539.840 | 8460.224 | 0.623 | 0.773 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1017.408 | 1660.480 | 4369.824 | 5056.992 | 0.613 | 0.864 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1792.448 | 2904.864 | 6546.080 | 8537.024 | 0.617 | 0.767 | | head_bias | None | torch.bfloat16 | (8, 16, 
4096, 16, 4096, 128) | 1795.552 | 2856.864 | 6544.672 | 8400.160 | 0.629 | 0.779 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.880 | 148.832 | 179.936 | 0.881 | 0.827 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.168 | 38.080 | 138.528 | 167.552 | 0.818 | 0.827 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 39.168 | 148.512 | 181.248 | 0.874 | 0.819 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.784 | 148.864 | 180.224 | 0.883 | 0.826 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.832 | 76.352 | 253.632 | 295.968 | 0.640 | 0.857 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 45.760 | 65.792 | 239.040 | 290.752 | 0.696 | 0.822 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.576 | 253.312 | 304.032 | 0.637 | 0.833 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.192 | 253.600 | 296.096 | 0.640 | 0.856 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.728 | 109.728 | 357.696 | 498.912 | 0.854 | 0.717 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 68.704 | 92.288 | 295.616 | 386.240 | 0.744 | 0.765 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.632 | 111.392 | 357.408 | 512.448 | 0.841 | 0.697 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.280 | 109.952 | 357.696 | 501.440 | 0.848 | 0.713 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.392 | 230.496 | 612.224 | 807.552 | 0.570 | 0.758 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 96.512 | 165.184 | 502.624 | 672.384 | 0.584 | 0.748 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.360 | 232.608 | 612.064 | 832.320 | 0.565 | 0.735 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.008 | 230.528 | 612.640 | 804.320 | 0.568 | 0.762 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1227.968 | 1377.408 | 3477.920 | 5324.384 | 0.892 | 0.653 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 695.264 | 824.544 | 2268.224 | 3210.208 | 0.843 | 0.707 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.640 | 1404.576 | 3476.832 | 5463.456 | 0.875 | 0.636 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.416 | 1378.752 | 3478.048 | 5367.712 | 0.891 | 0.648 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1788.736 | 2867.712 | 6039.520 | 8616.256 | 0.624 | 0.701 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1021.952 | 1653.824 | 3866.208 | 5306.848 | 0.618 | 0.729 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.752 | 2896.352 | 6044.128 | 8871.360 | 0.617 | 0.681 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.080 | 2868.672 | 6040.160 | 8550.144 | 0.623 | 0.706 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.504 | 71.552 | 312.768 | 255.040 | 0.804 | 1.226 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 49.472 | 71.104 | 285.696 | 243.520 | 0.696 | 1.173 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 58.112 | 72.896 | 312.768 | 288.256 | 0.797 | 1.085 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.952 | 71.680 | 312.768 | 255.552 | 0.808 | 1.224 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 
128) | 82.336 | 144.256 | 580.128 | 500.160 | 0.571 | 1.160 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.160 | 123.712 | 552.544 | 447.648 | 0.616 | 1.234 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.400 | 145.184 | 580.032 | 504.032 | 0.568 | 1.151 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.368 | 143.904 | 580.192 | 499.936 | 0.572 | 1.161 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.216 | 209.568 | 787.872 | 747.712 | 0.846 | 1.054 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 121.984 | 168.256 | 651.968 | 628.256 | 0.725 | 1.038 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.088 | 211.488 | 788.320 | 864.352 | 0.837 | 0.912 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.440 | 208.576 | 787.424 | 749.120 | 0.851 | 1.051 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.472 | 441.376 | 1405.440 | 1431.648 | 0.565 | 0.982 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 172.960 | 312.064 | 1172.064 | 1096.448 | 0.554 | 1.069 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.632 | 446.336 | 1405.408 | 1448.480 | 0.559 | 0.970 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 250.944 | 440.128 | 1406.624 | 1421.952 | 0.570 | 0.989 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2418.720 | 2747.936 | 7330.432 | 9023.712 | 0.880 | 0.812 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1353.696 | 1608.480 | 4941.696 | 6078.752 | 0.842 | 0.813 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2427.456 | 2746.816 | 7329.792 | 10539.968 | 0.884 | 0.695 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2426.688 | 2763.168 | 7336.256 | 9057.536 | 0.878 | 0.810 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3554.240 | 5634.400 | 12919.872 | 16843.489 | 0.631 | 0.767 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2003.648 | 3250.784 | 8610.144 | 10015.424 | 0.616 | 0.860 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3582.080 | 5710.944 | 12923.328 | 17011.871 | 0.627 | 0.760 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3581.920 | 5618.144 | 12934.528 | 16745.888 | 0.638 | 0.772 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.120 | 71.232 | 269.760 | 295.680 | 0.802 | 0.912 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 49.408 | 65.312 | 242.304 | 253.952 | 0.756 | 0.954 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.504 | 72.544 | 269.632 | 298.976 | 0.793 | 0.902 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.760 | 71.040 | 269.600 | 296.640 | 0.813 | 0.909 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 82.336 | 147.168 | 466.080 | 487.456 | 0.559 | 0.956 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.040 | 435.392 | 453.248 | 0.667 | 0.961 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.856 | 147.424 | 465.920 | 499.552 | 0.555 | 0.933 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.760 | 146.656 | 466.176 | 485.984 | 0.557 | 0.959 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 206.976 | 678.080 | 866.976 | 0.853 | 0.782 | | 
None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 121.664 | 164.768 | 538.240 | 636.160 | 0.738 | 0.846 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 209.664 | 677.696 | 883.424 | 0.842 | 0.767 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 177.440 | 207.840 | 677.248 | 868.288 | 0.854 | 0.780 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.272 | 449.536 | 1163.424 | 1420.832 | 0.557 | 0.819 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 173.472 | 305.376 | 929.408 | 1104.544 | 0.568 | 0.841 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 249.376 | 454.976 | 1163.648 | 1455.296 | 0.548 | 0.800 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.368 | 450.144 | 1163.520 | 1409.984 | 0.556 | 0.825 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2416.576 | 2726.208 | 6835.520 | 10442.784 | 0.886 | 0.655 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1357.440 | 1590.752 | 4433.664 | 5975.296 | 0.853 | 0.742 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2427.360 | 2747.040 | 6853.056 | 10670.784 | 0.884 | 0.642 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2441.120 | 2718.944 | 6836.640 | 10433.792 | 0.898 | 0.655 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3555.392 | 5620.960 | 11944.000 | 16504.801 | 0.633 | 0.724 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2010.848 | 3241.152 | 7636.064 | 9870.464 | 0.620 | 0.774 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3557.440 | 5688.352 | 11935.744 | 17090.496 | 0.625 | 0.698 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3562.720 | 5630.432 | 11939.168 | 16392.033 | 0.633 | 0.728 | </details> ### Perf after this PR **FWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|---------------|------------|----------------|----------------------------| | Average | 0.776 | | | | | | Max | 1.006 | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | | Min | 0.566 | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | **BWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|-------------|------------|----------------|-----------------------------| | Average | 0.817 | | | | | | Max | 1.150 | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | | Min | 0.454 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | <details> <summary> Full performance sweep </summary> | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.680 | 17.056 | 64.544 | 73.376 | 0.919 | 0.880 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.712 | 19.872 | 65.408 | 72.864 | 0.791 | 0.898 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.160 | 17.280 | 64.896 | 73.888 | 0.935 | 0.878 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.192 | 17.120 | 64.896 | 75.424 | 0.946 | 0.860 | | None 
| None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.648 | 22.496 | 89.184 | 82.592 | 0.873 | 1.080 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.320 | 26.816 | 91.264 | 82.880 | 0.758 | 1.101 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.096 | 22.528 | 89.184 | 83.776 | 0.892 | 1.065 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.680 | 22.432 | 89.184 | 120.096 | 0.877 | 0.743 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.384 | 32.512 | 119.232 | 128.960 | 0.996 | 0.925 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.176 | 37.248 | 113.664 | 119.520 | 0.810 | 0.951 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.512 | 32.928 | 119.264 | 131.456 | 0.987 | 0.907 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.448 | 32.704 | 119.200 | 128.352 | 0.992 | 0.929 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.952 | 62.176 | 199.040 | 214.304 | 0.675 | 0.929 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 39.744 | 62.880 | 189.504 | 179.968 | 0.632 | 1.053 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.472 | 62.784 | 199.136 | 217.664 | 0.661 | 0.915 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 42.048 | 61.952 | 199.168 | 214.496 | 0.679 | 0.929 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 341.184 | 357.632 | 980.256 | 1328.896 | 0.954 | 0.738 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 212.576 | 252.960 | 673.888 | 824.864 | 0.840 | 0.817 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.000 | 363.296 | 980.768 | 1375.808 | 0.936 | 0.713 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.768 | 356.832 | 980.960 | 1326.272 | 0.955 | 0.740 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 459.392 | 737.120 | 1678.240 | 2205.248 | 0.623 | 0.761 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 292.672 | 468.096 | 1178.016 | 1371.584 | 0.625 | 0.859 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.144 | 745.312 | 1680.000 | 2252.512 | 0.620 | 0.746 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.112 | 736.576 | 1679.008 | 2216.480 | 0.627 | 0.758 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.064 | 16.704 | 105.120 | 120.768 | 0.962 | 0.870 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.552 | 18.144 | 107.136 | 121.696 | 0.857 | 0.880 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.096 | 16.768 | 102.688 | 120.864 | 0.960 | 0.850 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.032 | 16.576 | 104.736 | 124.672 | 0.967 | 0.840 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.392 | 21.952 | 104.736 | 174.656 | 0.883 | 0.600 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 20.128 | 23.712 | 105.216 | 199.008 | 0.849 | 0.529 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.904 | 21.888 | 103.744 | 179.520 | 0.909 | 0.578 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.968 | 21.952 | 104.640 | 177.312 | 0.910 | 0.590 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.096 | 31.904 | 118.720 | 231.968 | 1.006 | 0.512 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 
1024, 64) | 30.528 | 33.952 | 112.480 | 218.304 | 0.899 | 0.515 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.160 | 32.224 | 118.752 | 237.312 | 0.998 | 0.500 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.128 | 32.032 | 118.240 | 233.120 | 1.003 | 0.507 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.280 | 177.408 | 350.688 | 0.674 | 0.506 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 39.552 | 59.360 | 168.832 | 371.488 | 0.666 | 0.454 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.984 | 61.696 | 177.376 | 360.416 | 0.680 | 0.492 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.760 | 177.184 | 355.744 | 0.669 | 0.498 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.744 | 357.888 | 939.712 | 1665.376 | 0.949 | 0.564 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 212.608 | 248.832 | 633.280 | 1122.848 | 0.854 | 0.564 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.712 | 363.232 | 940.448 | 1689.440 | 0.935 | 0.557 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 341.056 | 355.264 | 940.128 | 1641.152 | 0.960 | 0.573 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.736 | 741.024 | 1569.824 | 2559.552 | 0.622 | 0.613 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 293.856 | 464.192 | 1066.240 | 1840.416 | 0.633 | 0.579 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.704 | 753.152 | 1570.112 | 2641.088 | 0.612 | 0.594 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.832 | 745.536 | 1570.144 | 2602.560 | 0.618 | 0.603 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.680 | 41.280 | 171.840 | 158.176 | 0.864 | 1.086 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 31.360 | 42.976 | 158.912 | 139.264 | 0.730 | 1.141 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.168 | 41.600 | 171.648 | 161.344 | 0.845 | 1.064 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.136 | 41.152 | 171.808 | 158.336 | 0.854 | 1.085 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.832 | 76.384 | 295.680 | 277.696 | 0.639 | 1.065 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.632 | 72.512 | 281.760 | 250.752 | 0.629 | 1.124 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 49.504 | 76.608 | 295.584 | 279.712 | 0.646 | 1.057 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.864 | 75.904 | 295.456 | 277.568 | 0.644 | 1.064 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.392 | 111.232 | 408.640 | 442.656 | 0.894 | 0.923 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 71.392 | 95.168 | 338.784 | 341.760 | 0.750 | 0.991 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.808 | 112.256 | 408.608 | 456.160 | 0.889 | 0.896 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 100.032 | 110.816 | 408.512 | 444.192 | 0.903 | 0.920 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.040 | 226.112 | 726.880 | 774.176 | 0.597 | 0.939 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 99.904 | 169.696 | 616.448 | 607.104 | 0.589 | 1.015 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.488 
| 228.384 | 727.776 | 782.368 | 0.593 | 0.930 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.744 | 225.664 | 728.000 | 773.600 | 0.602 | 0.941 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1324.192 | 1387.808 | 3866.944 | 5217.184 | 0.954 | 0.741 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 738.464 | 832.608 | 2507.392 | 3146.688 | 0.887 | 0.797 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.016 | 1404.256 | 3867.872 | 5382.624 | 0.944 | 0.719 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.144 | 1386.688 | 3867.552 | 5203.264 | 0.956 | 0.743 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1847.488 | 2866.336 | 6612.704 | 8597.696 | 0.645 | 0.769 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1066.592 | 1660.640 | 4357.696 | 5174.016 | 0.642 | 0.842 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1850.464 | 2905.408 | 6616.928 | 8793.280 | 0.637 | 0.752 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1848.896 | 2834.720 | 6623.872 | 8637.920 | 0.652 | 0.767 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.384 | 38.656 | 150.336 | 182.624 | 0.941 | 0.823 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.360 | 38.112 | 137.664 | 171.840 | 0.823 | 0.801 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.608 | 39.040 | 150.528 | 183.872 | 0.938 | 0.819 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.064 | 38.656 | 150.560 | 183.520 | 0.933 | 0.820 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.344 | 76.352 | 253.920 | 301.440 | 0.646 | 0.842 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 46.720 | 65.824 | 239.424 | 296.384 | 0.710 | 0.808 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.248 | 76.416 | 253.728 | 307.808 | 0.644 | 0.824 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.376 | 76.288 | 253.728 | 304.736 | 0.647 | 0.833 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.144 | 364.960 | 503.072 | 0.901 | 0.725 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 71.136 | 92.384 | 294.432 | 393.056 | 0.770 | 0.749 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.200 | 111.360 | 365.152 | 512.640 | 0.891 | 0.712 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.240 | 365.088 | 504.224 | 0.900 | 0.724 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.680 | 230.336 | 613.472 | 816.896 | 0.589 | 0.751 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 100.256 | 165.088 | 502.144 | 676.480 | 0.607 | 0.742 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.008 | 232.480 | 613.184 | 836.672 | 0.581 | 0.733 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.232 | 230.624 | 613.536 | 827.136 | 0.586 | 0.742 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1324.064 | 1378.688 | 3631.808 | 5308.384 | 0.960 | 0.684 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 731.776 | 826.688 | 2263.168 | 3241.344 | 0.885 | 0.698 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1316.128 | 1403.200 | 3625.088 | 5550.688 | 0.938 | 0.653 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) 
| 1311.904 | 1378.880 | 3616.320 | 5353.696 | 0.951 | 0.675 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1837.856 | 2887.392 | 6121.632 | 8586.656 | 0.637 | 0.713 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1066.976 | 1654.368 | 3843.136 | 5291.040 | 0.645 | 0.726 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1854.208 | 2896.832 | 6130.112 | 8745.984 | 0.640 | 0.701 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1860.512 | 2889.344 | 6135.648 | 8750.592 | 0.644 | 0.701 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.640 | 71.552 | 315.968 | 296.512 | 0.847 | 1.066 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 50.784 | 71.040 | 284.288 | 258.880 | 0.715 | 1.098 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 61.312 | 72.704 | 315.680 | 302.016 | 0.843 | 1.045 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.800 | 71.776 | 316.320 | 297.152 | 0.847 | 1.065 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.576 | 144.416 | 580.576 | 535.936 | 0.586 | 1.083 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.064 | 123.648 | 553.344 | 481.376 | 0.615 | 1.150 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.160 | 145.248 | 581.024 | 540.000 | 0.579 | 1.076 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.512 | 143.552 | 581.088 | 535.776 | 0.589 | 1.085 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.152 | 209.408 | 798.400 | 868.704 | 0.903 | 0.919 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 127.552 | 168.800 | 650.816 | 663.328 | 0.756 | 0.981 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.376 | 211.360 | 798.080 | 895.552 | 0.896 | 0.891 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.440 | 208.576 | 797.888 | 873.152 | 0.908 | 0.914 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 257.536 | 441.760 | 1408.960 | 1514.720 | 0.583 | 0.930 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 179.328 | 312.096 | 1170.368 | 1177.472 | 0.575 | 0.994 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 259.264 | 446.944 | 1408.768 | 1530.400 | 0.580 | 0.921 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 258.080 | 440.480 | 1408.864 | 1514.144 | 0.586 | 0.930 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.808 | 2771.456 | 7616.704 | 10405.248 | 0.937 | 0.732 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1435.744 | 1610.336 | 4927.520 | 6220.000 | 0.892 | 0.792 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.264 | 2745.056 | 7611.232 | 10631.392 | 0.945 | 0.716 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2576.256 | 2735.456 | 7626.400 | 10346.976 | 0.942 | 0.737 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.744 | 5634.816 | 13077.056 | 17182.528 | 0.653 | 0.761 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2099.360 | 3250.176 | 8589.664 | 10236.672 | 0.646 | 0.839 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3676.800 | 5716.288 | 13073.088 | 17311.071 | 0.643 | 0.755 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.136 | 5570.496 | 13070.720 | 
17192.863 | 0.660 | 0.760 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.600 | 71.008 | 272.320 | 300.000 | 0.868 | 0.908 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 50.176 | 65.344 | 241.568 | 258.912 | 0.768 | 0.933 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.120 | 72.512 | 272.672 | 305.408 | 0.843 | 0.893 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.248 | 71.136 | 272.640 | 301.120 | 0.861 | 0.905 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.872 | 146.784 | 466.912 | 496.832 | 0.571 | 0.940 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.072 | 435.584 | 462.112 | 0.667 | 0.943 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.392 | 147.392 | 466.656 | 504.448 | 0.566 | 0.925 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.360 | 146.688 | 466.656 | 499.040 | 0.568 | 0.935 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.024 | 207.584 | 684.768 | 873.568 | 0.911 | 0.784 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 126.944 | 164.288 | 536.192 | 645.984 | 0.773 | 0.830 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 188.768 | 209.760 | 684.096 | 897.504 | 0.900 | 0.762 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.408 | 207.776 | 685.024 | 876.384 | 0.912 | 0.782 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 259.168 | 449.536 | 1167.936 | 1433.280 | 0.577 | 0.815 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 180.000 | 305.312 | 928.000 | 1113.920 | 0.590 | 0.833 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 258.464 | 455.136 | 1167.808 | 1462.848 | 0.568 | 0.798 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 257.824 | 450.208 | 1167.744 | 1448.000 | 0.573 | 0.806 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2598.368 | 2729.120 | 7134.400 | 10381.632 | 0.952 | 0.687 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1435.456 | 1591.040 | 4424.768 | 6035.808 | 0.902 | 0.733 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2594.752 | 2725.952 | 7128.384 | 10822.496 | 0.952 | 0.659 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2597.888 | 2716.960 | 7101.568 | 10385.440 | 0.956 | 0.684 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3647.648 | 5581.632 | 12089.952 | 16667.233 | 0.654 | 0.725 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2093.952 | 3241.440 | 7579.392 | 9847.936 | 0.646 | 0.770 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3650.528 | 5650.688 | 12105.568 | 16963.680 | 0.646 | 0.714 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3680.064 | 5585.312 | 12117.504 | 16935.040 | 0.659 | 0.716 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505 Approved by: https://github.com/Chillee |
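A minimal sketch of the broadcast case described above, where K/V carry a single shared batch (shapes and device are illustrative; `enable_gqa=True` because Hq != Hkv):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

Bq, Hq, Hkv, Q_LEN, MAX_LEN, D = 8, 16, 2, 128, 4096, 64

q = torch.randn(Bq, Hq, Q_LEN, D, device="cuda", dtype=torch.bfloat16)
# K/V are shared across the batch (batch size 1), as in the paged attention use case.
k = torch.randn(1, Hkv, MAX_LEN, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, Hkv, MAX_LEN, D, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(q, k, v, enable_gqa=True)
print(out.shape)  # (Bq, Hq, Q_LEN, D)
```
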
348d02a983 | Changed masked out rows logsumexp to be -inf and not zero (#134650)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134650
Approved by: https://github.com/yanboliang, https://github.com/BoyuanFeng, https://github.com/drisspg

d966d91e37 | [FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538
Approved by: https://github.com/yanboliang
ghstack dependencies: #134507, #134511

3e10a1eb5a | Revert "[FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)"
This reverts commit

a34320a6f2 | [FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538
Approved by: https://github.com/yanboliang
ghstack dependencies: #134495, #134507, #134511

bf5addb613 | [FlexAttention] Enable different qk and v head-dims (#134043)
# Summary
Adds the option for the head dims to differ between the QK and V tensors. Fixes issue: https://github.com/pytorch/pytorch/issues/133674
V_DIM > QK_DIM is blocked on landing https://github.com/triton-lang/triton/pull/4138 / https://github.com/triton-lang/triton/pull/4540 into PyTorch's triton branch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134043
Approved by: https://github.com/Chillee

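A small sketch of the new degree of freedom (sizes and device are illustrative; per the note above, V_DIM larger than QK_DIM is still gated on the Triton update):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S = 2, 8, 1024
qk_dim, v_dim = 128, 64  # QK and V head dims no longer have to match

q = torch.randn(B, H, S, qk_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, qk_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, v_dim, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(q, k, v)
print(out.shape)  # (B, H, S, v_dim): the output head_dim follows V
```
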
629bd6f718 | Update FlexAttention with masking semantic (#133373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373
Approved by: https://github.com/yanboliang

1a7e8e5780 | Revert "Update FlexAttention with masking semantic (#133373)"
This reverts commit

88c973005d | Revert "[FlexAttention] Enable different qk and v head-dims (#134043)"
This reverts commit

e847b6bb9b | [FlexAttention] Enable different qk and v head-dims (#134043)
# Summary
Adds the option for the head dims to differ between the QK and V tensors. Fixes issue: https://github.com/pytorch/pytorch/issues/133674
V_DIM > QK_DIM is blocked on landing https://github.com/triton-lang/triton/pull/4138 / https://github.com/triton-lang/triton/pull/4540 into PyTorch's triton branch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134043
Approved by: https://github.com/Chillee

5a7b544e5c | Update FlexAttention with masking semantic (#133373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373
Approved by: https://github.com/yanboliang

bc785c2d9a | [Inductor][FlexAttention] Don't trigger dynamic shape on building empty block mask (#133836)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133836
Approved by: https://github.com/Chillee

04f37ed57d | Add support for returning LSE from FlexAttention (and also differentiating through it) (#133159)
This PR changes the "contract" of `flex_attention_hop` to return LSE in base 2. However, we undo that and return LSE in base e from the `flex_attention` frontend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133159
Approved by: https://github.com/yanboliang

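A minimal sketch of the user-facing API (illustrative shapes; as described above, the returned LSE is in base e even though the HOP computes it in base 2 internally):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 4, 256, 64, device="cuda", requires_grad=True)
k = torch.randn(1, 4, 256, 64, device="cuda", requires_grad=True)
v = torch.randn(1, 4, 256, 64, device="cuda", requires_grad=True)

out, lse = flex_attention(q, k, v, return_lse=True)
lse.sum().backward()  # the LSE output participates in autograd as well
```
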
e888f401c5 | Fix autotuning for flex_decoding (#132157)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132157
Approved by: https://github.com/drisspg, https://github.com/yanboliang
ghstack dependencies: #131559

4110cb6ba7 | Add explicit GQA support. (#131559)
### tl;dr
This PR adds GQA support to the higher order op `flex_attention`.
## Details
When `enable_gqa` is set to True, HOP `flex_attention(score_mod, query, key, value, block_mask, enable_gqa)` runs Grouped Query Attention (GQA), where the number of query heads (Hq) is a multiple of the number of key/value heads (Hkv). Each group of query heads (`Hq//Hkv` heads) attends to a shared kv head.
Otherwise, `flex_attention` assumes Multi Head Attention (MHA), where the number of query heads equals the number of kv heads.
The `score_mod` and `mask_mod` APIs are adapted accordingly to take `q_head` as the head index.
```
def score_mod(score: torch.Tensor, batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
def mask_mod(batch: torch.Tensor, q_head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```
## Example
```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask

torch.manual_seed(0)

def query_key_value_clones(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype = None,
):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

# Let's create some input tensors.
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim).
# query and key/value can have different num_heads and seq_len.
# Here every 4 query heads share one KV head.
query = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
key = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
value = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float32, requires_grad=True)
query1, key1, value1 = query_key_value_clones(query, key, value)

# Let's create a score modification. We take alibi_bias as an example.
# score_mod takes batch index, query head index, query index, and key/value index.
def _generate_alibi_bias(num_kv_heads: int, num_q_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        b: torch.Tensor,
        hq: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the kv head from the query head index
        group = num_q_heads // num_kv_heads
        hkv = hq // group
        scale = torch.exp2(-((hkv + 1) * 8.0 / num_kv_heads))
        return score + (token_kv - token_q) * scale
    return _alibi_bias

# Let's apply a causal mask on top of it
def causal_mask(b, h, q, kv):
    return q >= kv

# Generate a block mask for our new mask_mod function.
# The mask is broadcast along the head & batch dimensions.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=2048, KV_LEN=2048)

# Let's call flex_attention with our new score modification and block mask under eager mode.
output = flex_attention(query, key, value, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

# Now let's compile flex_attention and run the flex_attention kernel.
compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query1, key1, value1, score_mod=_generate_alibi_bias(2, 8), block_mask=block_mask, enable_gqa=True)

torch.testing.assert_close(output, out_compiled, atol=5e-2, rtol=2e-2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131559
Approved by: https://github.com/drisspg

bbce517221 | [Inductor][FlexAttention] TestFlexAttention -> TestFlexDecoding (#132547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132547
Approved by: https://github.com/Chillee
ghstack dependencies: #132015

373e9be457 | [Inductor][FlexAttention] Add kwarg to top level for users to specify kernel params (#132015)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132015
Approved by: https://github.com/Chillee

bdd83c4c7f | Add Full block support to flex_decoding (#131404)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131404
Approved by: https://github.com/yanboliang

7f61324268 | Add sparse block to flex_decoding kernel (#130884)
Fix typo; finish flex_decoding block sparse.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130884
Approved by: https://github.com/drisspg

2b83e4f8d7 | [ROCm] Enable flex decoding unit tests (#131048)
Flex decoding tests are passing with upstream pytorch on MI300X/MI2XX. Only flex attention unit tests have issues. [result_mi250.log](https://github.com/user-attachments/files/16286954/result_mi250.log)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131048
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/malfet

6cbb1437c1 | Revert "Add sparse block to flex_decoding kernel (#130884)"
This reverts commit

0bf59db6cc | Add sparse block to flex_decoding kernel (#130884)
Fix typo; finish flex_decoding block sparse.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130884
Approved by: https://github.com/drisspg

dd39dca034 | Removing some cruff and updating signatures for consistency (#130871)
# Summary
- This removes a bunch of example score mods that were primarily used for testing and places them directly in the test file. We should follow up with merging test_flex_decode and test_flash when the velocity slows down a little.
- Fixes a bug with indexing on block mask.
- Adds some doc strings to helper funcs and fixes some misc typing things.
- Forces functions passed to `create_block_mask` to be mask_mods and updates test files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130871
Approved by: https://github.com/joydddd, https://github.com/Chillee

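For reference, a sketch of the mask_mod shape that `create_block_mask` now insists on (the window size is illustrative): a mask_mod takes only the four index tensors and returns a boolean, unlike a score_mod, which also receives the score.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def sliding_window(b, h, q_idx, kv_idx):
    # (batch, head, q_idx, kv_idx) -> bool: causal within a 256-token window
    return (q_idx >= kv_idx) & (q_idx - kv_idx < 256)

mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=4096, KV_LEN=4096)
```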