Commit Graph

71 Commits

Author SHA1 Message Date
Xinya Zhang
cef815dc2c [ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)
A few UT failures are caused by `HIPBLASLT_ALLOW_TF32`

Fixes #157094, #157093, #157092, #157091, #157064, #157063, #157062, #157061, #157042, #157041, #157039, #157004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-09-16 12:48:45 +00:00
Wang, Eikan
a3d72b09ae Apply Triton tensor descriptor for flex-decoding for performance (#161643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161643
Approved by: https://github.com/drisspg
2025-09-04 20:10:41 +00:00
William Wen
f36f285953 [dynamo] change error_on_graph_break/fullgraph semantics (#161747)
This PR implements the semantics change to `torch._dynamo.error_on_graph_break`:
- ~`torch.compile` now has a new `error_on_graph_break` kwarg that serves as a lower-priority toggle for erroring/continuing on graph breaks~
- `error_on_graph_break` is a new internal `torch.compile `setting that is lower-priority than `fullgraph`. It allows the user to toggle erroring/continuing on graph breaks.
- `error_on_graph_break` does nothing when `fullgraph=True`
- `error_on_graph_break` does NOT guarantee a single graph

Followup [DONE]: need to change the programming model docs to reflect the 3 graph break modes for compilation:
- `fullgraph=True`: enforce one graph, no graph breaks, cannot be toggled
- `fullgraph=False, error_on_graph_break=True`: errors on graph breaks, latter can be toggled during compile time
- `fullgraph=False, error_on_graph_break=False`: resumes tracing on graph breaks, latter can be toggled during compile time

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161747
Approved by: https://github.com/mlazos
ghstack dependencies: #161739
2025-09-04 17:10:17 +00:00
Tianren Gao
2fed4fb464 [FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861)
Fixes #159247
Issue 1: Accuracy Problem with Non-Divisible KV Sequences
---------------------------------------------------------

### Background

Paged attention in flex decoding produced inaccurate results when KV sequence length is not divisible by block size. For example, when `KV_S = 64` and `block_size = 128`, the output didn't match standard attention accuracy.

### Root Cause
The current paged attention does not apply upper mask mod when converting from logical to physical mask mod. Instead, it uses a noop_mask by default which makes all the values unmasked, leading to an accuracy mismatch. Adding a upper mask mod according to the origin actual kv_len (64 in this test case) resolves the issue.

### Solution

*   **Applied proper upper bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor with proper shape `[B, KV_S]` to provide information of actual batched KV sequence length. The function now correctly applies upper bound checks using the actual KV sequence lengths for each batch

### Files Modified
*    `torch/nn/attention/experimental/_paged_attention.py`: Added `kv_len` parameter as a tensor to `get_mask_mod` and applied upper mask to the new mask mod.
*   `test/inductor/test_flex_attention.py`: Fixed all related `kv_len` parameter call in the tests
*   `test/inductor/test_flex_decoding.py`: Fixed all related `kv_len` parameter call in the tests

Issue 2: Invalid Memory Access (IMA) in Triton Kernels
------------------------------------------------------

### Background

The Triton kernel for flex attention was experiencing invalid memory access errors when running with compute sanitizers, particularly with short KV sequences and small batch sizes.

### Root Cause

*   Kernel launches CTAs (Cooperative Thread Arrays) proportional to GPU's multi-processor count (108 via `SPLIT_KV`)
*   With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values
*   This caused out-of-bounds memory access violations

### Solution

Implemented boundary checks with early exit:

1.  **Added `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`

    *   Calculate maximum valid KV index based on actual `kv_indices` tensor size and pass it to Triton template
2.  **Added early exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`

    *   Boundary checks before accessing `kv_indices` in both normal and full blocks
    *   Idle CTAs with invalid `indices_idx` skip computation entirely

This prevents invalid memory access while reducing wasted computation on idle thread blocks.

Testing & Validation
--------------------

### Accuracy Tests

*   Added comprehensive test cases covering KV sequences not divisible by block sizes
*   Verified output matches standard attention for various sequence length combinations

### Sanitizer Results

`========= COMPUTE-SANITIZER Starting standalone test_max_autotune... Running test_max_autotune on device: cuda max_autotune config: True test_max_autotune completed successfully! Test passed! ========= ERROR SUMMARY: 0 errors`

**Before**: More than 13720 invalid memory access errors with sanitizers
**After**: Clean execution with 0 errors

Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861
Approved by: https://github.com/BoyuanFeng
2025-08-30 04:50:23 +00:00
Zhang, Liangang
3e459491b5 Enable XPU path for FlexAttention (#143553)
[#RFC153024](https://github.com/pytorch/pytorch/issues/153024)

**Motivation**

1. The Attention has been the critical performance bottleneck in the current LLM models, and FlexAttention is a good choice to cover the broad variants in the transformers series models. With FlexAttention, it is easy for us to enable the paged attention and fused SDPA  in the transformers repo on XPU device. Besides,  it also provide a candidate to process attention in LLM ecosystem libraries ., e.g., vLLM, SGLang on XPU device.
2. FlexAttention is good start point to push the intel triton based GEMM kernel to be matured. FlexAttention provide both flexattention kernel and flexdecoding kernel to cover both compute bound and memory bound GEMM computation, and  different shapes should also been supported to serve LLM inference., e.g. head_dim=64, 96, 128, 256.

**What does this PR do?**

 1. Enable the device type for Flexattention kernel  and UTs to ensure all important UTs pass on XPU device.
 2. For E2E model inference, ensure the functionality  of LLM models inference with FlexAttention to be ready.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143553
Approved by: https://github.com/EikanWang, https://github.com/drisspg

Co-authored-by: Mao Yunfei <yunfei.mao@intel.com>
Co-authored-by: Xingyuan Li <xingyuan.li@intel.com>
Co-authored-by: majing <jing1.ma@intel.com>
Co-authored-by: Xiao, Wang <wang.xiao@intel.com>
2025-08-29 23:10:58 +00:00
Angel Li
8e17709055 FlexDecode not guarding on GQA groups correctly (#160904)
Addressing #151359

Updates flex_decode dispatch to use flex attention rather than flex decode if number of groups is not a power of 2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160904
Approved by: https://github.com/drisspg
2025-08-20 16:32:16 +00:00
Xu Han
21392c0e06 [inductor] disable flex decoding on Windows. (#160072)
Discussed with @jianan-gu and @Valentine233 , disable flex decoding on Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160072
Approved by: https://github.com/angelayi
2025-08-07 18:07:36 +00:00
Eddie Yan
19aa8eb4f5 [TF32][Flex Attention] Turn off TF32 for reference computation in test_flex_decoding (#158979)
Seems to avoid threshold (fudge factor) twiddling games as this causes the checks to go down the "very small ref error" path instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158979
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng, https://github.com/nWEIdia
2025-07-28 18:38:23 +00:00
Catherine Lee
561193e5f2 [CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)
3 procs were used for sm86, but we switched to sm89 and the check failed so it switched back to 2

sm90 is H100, but idk what unittests we have running there, but I assume they also have a lot of memory

They use larger runners, which have more GPU memory, so its usually ok.  I think it's ~22GB -> 10GB per proc if 2, 6GB per proc if 3 (cuda context maybe 1GB)

I've applied skips to the ones that OOMed

Time decreases from ~2.7hr per test job -> ~2hr

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158691
Approved by: https://github.com/huydhn
2025-07-25 15:26:29 +00:00
PyTorch MergeBot
11ea3736dd Revert "[CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)"
This reverts commit 0c0fcb53ff.

Reverted https://github.com/pytorch/pytorch/pull/158691 on behalf of https://github.com/ZainRizvi due to Sorry but these are causing jobs to fail with out of memory errors on trunk ([comment](https://github.com/pytorch/pytorch/pull/158691#issuecomment-3113922186))
2025-07-24 15:31:53 +00:00
Catherine Lee
0c0fcb53ff [CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)
3 procs were used for sm86, but we switched to sm89 and the check failed so it switched back to 2

sm90 is H100, but idk what unittests we have running there, but I assume they also have a lot of memory

They use larger runners, which have more GPU memory, so its usually ok.  I think it's ~22GB -> 10GB per proc if 2, 6GB per proc if 3 (cuda context maybe 1GB)

I've applied skips to the ones that OOMed

Time decreases from ~2.7hr per test job -> ~2hr

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158691
Approved by: https://github.com/huydhn
2025-07-24 01:51:28 +00:00
Xuehai Pan
17687eb792 [BE][4/6] fix typos in test/ (test/inductor/) (#157638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157638
Approved by: https://github.com/yewentao256, https://github.com/jansel
2025-07-06 06:34:25 +00:00
William Wen
1c3f5e902d [dynamo] control one_graph behavior additionally through config (#154283)
`torch.compile` now always goes through `torch._dynamo._optimize`. fullgraph is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283
Approved by: https://github.com/jansel, https://github.com/anijain2305
2025-06-26 21:40:38 +00:00
haozhe.zhu
53e0b9c393 refine fp32 precision api (#125888)
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop the "highest, high, medium" to represent fp32  internal computation data types . Instead, we will directly use the algorithm to represent it.

### Design Choice: Directly use algorithms name like "TF32", "BF16".
#### Pros
 - The names are more informative. 'tf32' is more informative than a simple "high".
 - Easier to extend new algorithm like `tf32x3`
#### Cons
 - "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can have more documents to discuss them.

### We provide a layered structure for backends/operators.
('f32' is short for 'fp32_precision')
![image](https://github.com/user-attachments/assets/f89143e5-d6a1-4865-9351-9a50439f5067)

### We provide 3 fp32 compute precision can be set:
 - **"ieee"**: Not allowed to use any other internal computation data types .
 - **"tf32"**: Allowed to use tf32 as internal computation data types.
 - **"bf16"**: Allowed to use bf16 as internal computation data types.
 - **"none"**:  Precision's are not set. Can be override by its father node.

### Overriding Precision Settings
Child node can be override by its father node if it is set to default.
For current default settings:
```
backend = generic, op = all, precision setting = none
    backend = cuda, op = all, precision setting = none
        backend = cuda, op = conv, precision setting = tf32
        backend = cuda, op = rnn, precision setting = tf32
        backend = cuda, op = matmul, precision setting = none
    backend = matmul, op = all, precision setting = none
        backend = matmul, op = conv, precision setting = none
        backend = matmul, op = rnn, precision setting = none
        backend = matmul, op = matmul, precision setting = none
```
 - If the user set `torch.backends.mkldnn.fp32_precision="bf16"`, his child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be override to "bf16".
 - If the user set `torch.backends.fp32_precision="bf16"`,  `torch.backends.mkldnn.fp32_precision` and his child nodes will also we override to "bf16".

### Backward Compatible
Since new API allow user to have more fine-grained control. There will be some conflict. For example, previous `torch.backends.cudnn.allow_tf32` are not enough to represent the status for `torch.backends.cudnn.rnn.fp32_precision="ieee"` and `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goal for backward compatible is
 - If the user only uses previous APIs, it will work as previous expectations.
 - If the user use **new** API to change the status to an **un-representable** status for old API, and try to access the status by **old** API. We will raise Runtime Error and point the document for user.

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
2025-06-26 10:32:20 +00:00
Xuehai Pan
f5e6e52f25 [BE][PYFMT] migrate PYFMT for test/inductor/ to ruff format (#148186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148186
Approved by: https://github.com/jansel
2025-06-24 11:12:11 +00:00
PyTorch MergeBot
b5c8b8d09f Revert "[dynamo] control one_graph behavior additionally through config (#154283)"
This reverts commit b46eb1ccaf.

Reverted https://github.com/pytorch/pytorch/pull/154283 on behalf of https://github.com/ezyang due to All of this is responsible for regression, see https://github.com/pytorch/pytorch/pull/156561 ([comment](https://github.com/pytorch/pytorch/pull/154283#issuecomment-2994242583))
2025-06-22 14:22:07 +00:00
William Wen
b46eb1ccaf [dynamo] control one_graph behavior additionally through config (#154283)
`torch.compile` now always goes through `torch._dynamo._optimize`. fullgraph is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283
Approved by: https://github.com/jansel, https://github.com/anijain2305
2025-06-20 07:02:57 +00:00
Joel Schlosser
ae53510b9e Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 + adds test

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.

Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) A reason this matters is because it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated TestFooCUDA does not derive from TestFoo, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

070f389745/test/test_ops.py (L218-L221)

(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:

070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)

070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)

To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have TestFooCPU / TestFooCUDA **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been setup with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)

To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
2025-04-16 02:18:42 +00:00
jianan-gu
12cb11a268 [Inductor UT] Refactor FlexAttention UT and add CPU tests (#144953)
This PR extends and refines all rest UTs for CPU and more devices in `test/inductor/test_flex_attention.py`  and `test/inductor/test_flex_decoding.py`, as a follow-up to https://github.com/pytorch/pytorch/pull/141453

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144953
Approved by: https://github.com/drisspg
2025-04-15 12:44:49 +00:00
Sampsa
5e9f792479 [ROCm] Unskip flex attention UTs after triton 3.3 bump (#148327)
Enable `test_flex_attention.py::TestLearnableBiases` unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148327
Approved by: https://github.com/jeffdaily
2025-03-17 20:15:14 +00:00
Boyuan Feng
98b3f1db9f [Flex Attention] support num_heads > 1 in block_mask (#148857)
Previously flex decoding errors when block mask has num_heads > 1. So users have to use num_heads=1, or explicitly mark `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}`.

This PR fixes this issue. When not using grouped query attention (GQA, i.e., Hq == Hkv), we support block mask with num_heads = 1 and num_heads = num_query_heads (i.e., Hq). This is the same setting as flex attention kernel.

When using GQA (i.e., Hq != Hkv), we support block mask with num_heads = 1. When num_heads = Hq, we fall back to flex attention kernel so user don't need to explicitly mark `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}` anymore.

Why fallback? In the current flex decoding triton kernel, grouped query heads for the same kv head are handled by the same thread block. Supporting num_heads = Hq with GQA requires support different kv num blocks for different query heads in the same thread block, leading to lots of redundant workload. So we should better use the main flex_attention kernel where each query head is handled by a separate block.

Fixes #148527
Fixes #147267

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148857
Approved by: https://github.com/drisspg
2025-03-10 22:02:50 +00:00
Boyuan Feng
ba9ed856e0 [FlexAttention] Improve error msg for embedding < 16 (#147765)
flex_attention uses tl.dot, which [does not support embedding < 16](https://github.com/triton-lang/triton/issues/2266) on input shapes. This PR adds explicit error message for users who are prototyping with small tensors.

Fixes #147701

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147765
Approved by: https://github.com/drisspg
2025-02-26 17:06:35 +00:00
drisspg
c6707734de Enable non power of 2 head_dim for FlexAttention (#133495)
# Summary
- Adds support for non-power of 2 headdim by launching blocks w/ head_dim rounded to the next valid power.
- Other option I considered was building up the final dot_products with smaller blocks (this would probably work but for sake of code complexity going with this option for now)

### Corollary
We had a bug in our backwards kernel where we were using index_k instead of index_v. This should have shown up for the qk_head_dim != v_head_dim cases..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133495
Approved by: https://github.com/Chillee
2025-01-23 17:05:38 +00:00
Aaron Orenstein
99dbc5b0e2 PEP585 update - test (#145176)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93
2025-01-22 04:48:28 +00:00
iupaikov-amd
577708e6de Unskipped multiple inductor tests for ROCm (#143581)
All of them should be fine to run now after the triton fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143581
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-01-16 20:46:06 +00:00
Boyuan Feng
069419569d [PagedAttention] Support different input position for each batch index (#144693)
In LLM inference, each request usually has different prefill length, leading to different input position for each batch index. This PR adds such support for paged attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144693
Approved by: https://github.com/drisspg
2025-01-15 18:03:52 +00:00
PyTorch MergeBot
7d9f26de05 Revert "Unskipped multiple inductor tests for ROCm (#143581)"
This reverts commit e05d67790e.

Reverted https://github.com/pytorch/pytorch/pull/143581 on behalf of https://github.com/huydhn due to There is some tests failing on ROCm jobs in trunk ([comment](https://github.com/pytorch/pytorch/pull/143581#issuecomment-2577163274))
2025-01-08 09:15:14 +00:00
iupaikov-amd
e05d67790e Unskipped multiple inductor tests for ROCm (#143581)
All of them should be fine to run now after the triton fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143581
Approved by: https://github.com/jataylo, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-01-08 03:55:33 +00:00
Tom Ritchford
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
Joy Dong
ed9931e6ee Add tests for non divisible inputs for flex decoding (#143214)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143214
Approved by: https://github.com/drisspg
2024-12-18 16:32:45 +00:00
Xinya Zhang
424156c26c [ROCm] Update to AOTriton 0.8b (#140172)
Notable new features for SDPA operators on AMD systems from AOTriton 0.8b:

1. Nestedtensor support;
2. MQA/GQA support;
3. Restore Efficient attention support for causal=True and seqlen_q != seqlen_k cases;
    + The kernel should use top-left alignment, bottom right alignment will be added later
4. Move gfx1100 (RX7900/W7800/W7900) out of experimental support status.
   However, users are strongly recommended to update to ROCM 6.2.4, notably for
   its firmware updates.

Related unit tests are enabled as well.

Notable related changes from AOTriton 0.8b:

1. AOTriton 0.8b moves the GPU kernel out of libaotriton.so to a separate directory `aotriton.images`;
2. LZMA replaces ZSTD as GPU kernel compression algorithm for better compression ratio: aotriton0.8b (.so + aotriton.images take 350MB) compared to aotriton0.7b .so: 800MB
3. The compression cannot be disabled now, and `liblzma` is hard run-time dependency.
    + Should not be a problem, since `lzma` is part of Python Standard Library

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2024-12-06 21:45:18 +00:00
chilli
af88326250 Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)
Fixes https://github.com/pytorch/pytorch/issues/141435

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625
Approved by: https://github.com/drisspg
ghstack dependencies: #138788
2024-12-03 04:45:05 +00:00
PyTorch MergeBot
a34a56f69f Revert "Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)"
This reverts commit 795f28ac55.

Reverted https://github.com/pytorch/pytorch/pull/141625 on behalf of https://github.com/albanD due to Broken main ([comment](https://github.com/pytorch/pytorch/pull/141625#issuecomment-2511639687))
2024-12-02 14:10:38 +00:00
chilli
795f28ac55 Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)
Fixes https://github.com/pytorch/pytorch/issues/141435

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625
Approved by: https://github.com/drisspg
ghstack dependencies: #138788
2024-12-02 00:35:29 +00:00
zeshengzong
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As state in issue, replace `clone.detach` by `detach.clone`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
drisspg
540f3ef9b1 Fix flex_decode to build offsets off of strides (#139516)
Fixes PR: https://github.com/pytorch/pytorch/issues/139462

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139516
Approved by: https://github.com/Chillee
2024-11-02 03:17:46 +00:00
Boyuan Feng
68134a320e [Flex Attention] Paged Attention (#137164)
This PR adds paged attention for flex attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137164
Approved by: https://github.com/drisspg
2024-10-29 17:05:22 +00:00
drisspg
34c18887ad [FlexAttention] Remove restriction on QK headdim > V headdim (#135884)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135884
Approved by: https://github.com/Chillee
2024-10-01 21:17:54 +00:00
Boyuan Feng
28224329ad [Flex Attention] fix block size order (#136657)
`create_block_mask` currently gives wrong BLOCK_SIZE and shape when using non-default block size `(128,128)`.
This PR fixes the issue by using BLOCK_SIZE order `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136657
Approved by: https://github.com/Chillee, https://github.com/drisspg
2024-09-28 19:56:53 +00:00
PyTorch MergeBot
f7ab0e9989 Revert "[Flex Attention] fix block size order (#136657)"
This reverts commit b42f1e3641.

Reverted https://github.com/pytorch/pytorch/pull/136657 on behalf of https://github.com/ZainRizvi due to Sorry, this seems to break ROCm builds. inductor/test_flex_attention.py::TestFlexAttention::test_builtin_score_mods_seqlen_lt_custom_sparse_block_size_float16_score_mod1 [GH job link](https://github.com/pytorch/pytorch/actions/runs/11069782242/job/30759299713) [HUD commit link](b42f1e3641) ([comment](https://github.com/pytorch/pytorch/pull/136657#issuecomment-2380031525))
2024-09-27 20:47:54 +00:00
Boyuan Feng
b42f1e3641 [Flex Attention] fix block size order (#136657)
`create_block_mask` currently gives wrong BLOCK_SIZE and shape when using non-default block size `(128,128)`.
This PR fixes the issue by using BLOCK_SIZE order `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136657
Approved by: https://github.com/Chillee, https://github.com/drisspg
2024-09-27 11:26:47 +00:00
Jerry Mannil
ccca3de0cd [ROCm] Enable Flex attention tests on AMD gpus (#136245)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136245
Approved by: https://github.com/malfet
2024-09-19 18:02:41 +00:00
Xinya Zhang
74fd1bf965 [ROCm] Update to AOTriton 0.7b (#134498)
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`

Know Problem:
1. `test/test_transformers.py` will massive failures and/or NaN outputs with `--use-pytest`
    + Update: Confirmed skip `class TestSDPAPrivateUse1Only` can fix the problem with `--use-pytest`

Note:
AOTriton 0.7b adds support to nestedtenosrs+SDPA but need more work (and consequently a separate PR) to enable it.

Fixes #133540

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet
2024-09-11 20:34:01 +00:00
Boyuan Feng
6e13f5eb38 [FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for KV batch dimension.

## Details
Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this diff, we require `Bq == Bkv`. However, for some use cases, we may have Bkv < Bq. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]`, while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximal number of tokens supported by paged attention.

This PR relax this requirement to be `Bq == Bkv or (Bq > 1 and Bkv == 0)`. This support covers both flex decoding, flex attention forward and backward.

## Benchmark
GPU: H100

We see negligible (1%~2%) performance change from this PR when `Bq == Bkv`.

```
python benchmarks/transformer/score_mod.py --calculate-bwd
```
### Perf before this PR

**FWD**

| Type    |   Speedup | score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)        |
|---------|-----------|---------------|------------|----------------|------------------------------|
| Average |     0.743 |               |            |                |                              |
| Max     |     0.955 | head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)   |
| Min     |     0.548 | relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) |

**BWD**

| Type    |   Speedup | score_mod   | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)       |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average |     0.834 |             |            |                |                             |
| Max     |     1.261 | head_bias   | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)   |
| Min     |     0.456 | None        | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |

<details>
<summary> Full performance sweep </summary>

| score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)         |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.264 |              17.184 |          107.040 |             140.800 |         0.888 |         0.760 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.840 |              19.744 |          112.576 |             140.064 |         0.802 |         0.804 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.232 |              17.344 |           87.744 |             142.496 |         0.878 |         0.616 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.264 |              17.184 |          108.192 |             143.328 |         0.888 |         0.755 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.904 |              22.400 |          106.432 |             136.512 |         0.889 |         0.780 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.424 |              26.752 |           91.712 |             106.688 |         0.726 |         0.860 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.808 |              22.432 |           89.024 |             101.920 |         0.883 |         0.873 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.840 |              22.272 |           88.896 |             102.592 |         0.891 |         0.867 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.240 |              32.416 |          116.768 |             112.256 |         0.933 |         1.040 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           29.536 |              37.024 |          113.664 |             102.688 |         0.798 |         1.107 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.656 |              32.800 |          116.992 |             127.008 |         0.935 |         0.921 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.592 |              32.480 |          116.928 |             112.160 |         0.942 |         1.043 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           40.448 |              61.920 |          198.656 |             204.512 |         0.653 |         0.971 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           37.760 |              62.528 |          189.536 |             170.624 |         0.604 |         1.111 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           40.896 |              62.368 |          198.304 |             205.824 |         0.656 |         0.963 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           40.448 |              61.952 |          198.432 |             203.648 |         0.653 |         0.974 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          318.528 |             355.904 |          947.232 |            1162.496 |         0.895 |         0.815 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          199.776 |             252.128 |          677.792 |             813.184 |         0.792 |         0.834 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          316.512 |             363.328 |          947.712 |            1361.984 |         0.871 |         0.696 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          317.984 |             356.864 |          947.264 |            1165.024 |         0.891 |         0.813 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          446.656 |             734.656 |         1664.288 |            2172.960 |         0.608 |         0.766 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          278.688 |             467.648 |         1182.624 |            1339.296 |         0.596 |         0.883 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          447.872 |             744.096 |         1662.944 |            2196.544 |         0.602 |         0.757 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          448.128 |             732.928 |         1663.072 |            2156.800 |         0.611 |         0.771 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.648 |              16.640 |          107.520 |             143.008 |         0.940 |         0.752 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.776 |              18.240 |          129.056 |             141.920 |         0.865 |         0.909 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.168 |              16.640 |          103.616 |             139.648 |         0.912 |         0.742 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.616 |              16.640 |          128.608 |             164.448 |         0.938 |         0.782 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.776 |              21.952 |          125.344 |             170.304 |         0.901 |         0.736 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.776 |              23.712 |          104.288 |             196.896 |         0.834 |         0.530 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.072 |              21.952 |          102.080 |             177.056 |         0.869 |         0.577 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.648 |              21.920 |          109.920 |             170.848 |         0.896 |         0.643 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.464 |              31.936 |          127.808 |             228.832 |         0.954 |         0.559 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           29.472 |              33.856 |          113.152 |             215.072 |         0.871 |         0.526 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.496 |              32.160 |          116.576 |             231.744 |         0.948 |         0.503 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.464 |              31.904 |          116.320 |             229.824 |         0.955 |         0.506 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           40.480 |              61.440 |          176.448 |             345.312 |         0.659 |         0.511 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           38.304 |              59.424 |          169.312 |             371.360 |         0.645 |         0.456 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           40.960 |              61.760 |          176.512 |             358.912 |         0.663 |         0.492 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           40.352 |              61.696 |          176.512 |             344.928 |         0.654 |         0.512 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          316.224 |             357.728 |          905.728 |            1668.448 |         0.884 |         0.543 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          199.904 |             248.416 |          636.544 |            1109.088 |         0.805 |         0.574 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          314.880 |             363.616 |          906.304 |            1658.176 |         0.866 |         0.547 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          316.160 |             354.368 |          906.080 |            1649.024 |         0.892 |         0.549 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          446.912 |             739.840 |         1555.808 |            2521.952 |         0.604 |         0.617 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          279.776 |             463.904 |         1068.928 |            1849.888 |         0.603 |         0.578 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          446.080 |             748.960 |         1553.504 |            2629.888 |         0.596 |         0.591 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          446.208 |             740.608 |         1558.880 |            2524.960 |         0.602 |         0.617 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           33.568 |              41.280 |          170.016 |             147.584 |         0.813 |         1.152 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           30.688 |              43.040 |          159.552 |             146.720 |         0.713 |         1.087 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           34.112 |              41.504 |          170.112 |             152.672 |         0.822 |         1.114 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           34.240 |              41.152 |          170.272 |             134.976 |         0.832 |         1.261 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.672 |              76.416 |          295.296 |             263.648 |         0.637 |         1.120 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           45.088 |              72.576 |          281.920 |             237.664 |         0.621 |         1.186 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.032 |              76.672 |          295.520 |             265.248 |         0.626 |         1.114 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.096 |              76.096 |          295.456 |             262.112 |         0.632 |         1.127 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           93.920 |             111.232 |          401.568 |             382.944 |         0.844 |         1.049 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           68.192 |              95.232 |          338.752 |             326.816 |         0.716 |         1.037 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           93.984 |             111.840 |          401.856 |             444.224 |         0.840 |         0.905 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           94.176 |             110.496 |          401.600 |             383.136 |         0.852 |         1.048 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          131.488 |             227.040 |          727.424 |             739.712 |         0.579 |         0.983 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |           95.616 |             169.760 |          616.864 |             574.112 |         0.563 |         1.074 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          131.680 |             228.672 |          727.616 |             746.048 |         0.576 |         0.975 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          131.104 |             225.696 |          727.904 |             735.392 |         0.581 |         0.990 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1227.296 |            1386.656 |         3720.192 |            4539.904 |         0.885 |         0.819 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |          691.360 |             831.712 |         2515.872 |            3067.808 |         0.831 |         0.820 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1228.192 |            1403.136 |         3715.520 |            5309.280 |         0.875 |         0.700 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1229.024 |            1384.992 |         3715.904 |            4550.368 |         0.887 |         0.817 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1784.832 |            2865.888 |         6539.840 |            8460.224 |         0.623 |         0.773 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1017.408 |            1660.480 |         4369.824 |            5056.992 |         0.613 |         0.864 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1792.448 |            2904.864 |         6546.080 |            8537.024 |         0.617 |         0.767 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1795.552 |            2856.864 |         6544.672 |            8400.160 |         0.629 |         0.779 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           34.240 |              38.880 |          148.832 |             179.936 |         0.881 |         0.827 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           31.168 |              38.080 |          138.528 |             167.552 |         0.818 |         0.827 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           34.240 |              39.168 |          148.512 |             181.248 |         0.874 |         0.819 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           34.240 |              38.784 |          148.864 |             180.224 |         0.883 |         0.826 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           48.832 |              76.352 |          253.632 |             295.968 |         0.640 |         0.857 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           45.760 |              65.792 |          239.040 |             290.752 |         0.696 |         0.822 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           48.768 |              76.576 |          253.312 |             304.032 |         0.637 |         0.833 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           48.768 |              76.192 |          253.600 |             296.096 |         0.640 |         0.856 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           93.728 |             109.728 |          357.696 |             498.912 |         0.854 |         0.717 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           68.704 |              92.288 |          295.616 |             386.240 |         0.744 |         0.765 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           93.632 |             111.392 |          357.408 |             512.448 |         0.841 |         0.697 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           93.280 |             109.952 |          357.696 |             501.440 |         0.848 |         0.713 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          131.392 |             230.496 |          612.224 |             807.552 |         0.570 |         0.758 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |           96.512 |             165.184 |          502.624 |             672.384 |         0.584 |         0.748 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          131.360 |             232.608 |          612.064 |             832.320 |         0.565 |         0.735 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          131.008 |             230.528 |          612.640 |             804.320 |         0.568 |         0.762 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1227.968 |            1377.408 |         3477.920 |            5324.384 |         0.892 |         0.653 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |          695.264 |             824.544 |         2268.224 |            3210.208 |         0.843 |         0.707 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1228.640 |            1404.576 |         3476.832 |            5463.456 |         0.875 |         0.636 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1228.416 |            1378.752 |         3478.048 |            5367.712 |         0.891 |         0.648 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1788.736 |            2867.712 |         6039.520 |            8616.256 |         0.624 |         0.701 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1021.952 |            1653.824 |         3866.208 |            5306.848 |         0.618 |         0.729 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1786.752 |            2896.352 |         6044.128 |            8871.360 |         0.617 |         0.681 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1786.080 |            2868.672 |         6040.160 |            8550.144 |         0.623 |         0.706 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           57.504 |              71.552 |          312.768 |             255.040 |         0.804 |         1.226 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           49.472 |              71.104 |          285.696 |             243.520 |         0.696 |         1.173 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           58.112 |              72.896 |          312.768 |             288.256 |         0.797 |         1.085 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           57.952 |              71.680 |          312.768 |             255.552 |         0.808 |         1.224 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           82.336 |             144.256 |          580.128 |             500.160 |         0.571 |         1.160 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           76.160 |             123.712 |          552.544 |             447.648 |         0.616 |         1.234 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           82.400 |             145.184 |          580.032 |             504.032 |         0.568 |         1.151 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           82.368 |             143.904 |          580.192 |             499.936 |         0.572 |         1.161 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          177.216 |             209.568 |          787.872 |             747.712 |         0.846 |         1.054 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          121.984 |             168.256 |          651.968 |             628.256 |         0.725 |         1.038 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          177.088 |             211.488 |          788.320 |             864.352 |         0.837 |         0.912 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          177.440 |             208.576 |          787.424 |             749.120 |         0.851 |         1.051 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          249.472 |             441.376 |         1405.440 |            1431.648 |         0.565 |         0.982 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          172.960 |             312.064 |         1172.064 |            1096.448 |         0.554 |         1.069 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          249.632 |             446.336 |         1405.408 |            1448.480 |         0.559 |         0.970 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          250.944 |             440.128 |         1406.624 |            1421.952 |         0.570 |         0.989 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2418.720 |            2747.936 |         7330.432 |            9023.712 |         0.880 |         0.812 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         1353.696 |            1608.480 |         4941.696 |            6078.752 |         0.842 |         0.813 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2427.456 |            2746.816 |         7329.792 |           10539.968 |         0.884 |         0.695 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2426.688 |            2763.168 |         7336.256 |            9057.536 |         0.878 |         0.810 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3554.240 |            5634.400 |        12919.872 |           16843.489 |         0.631 |         0.767 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         2003.648 |            3250.784 |         8610.144 |           10015.424 |         0.616 |         0.860 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3582.080 |            5710.944 |        12923.328 |           17011.871 |         0.627 |         0.760 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3581.920 |            5618.144 |        12934.528 |           16745.888 |         0.638 |         0.772 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           57.120 |              71.232 |          269.760 |             295.680 |         0.802 |         0.912 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           49.408 |              65.312 |          242.304 |             253.952 |         0.756 |         0.954 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           57.504 |              72.544 |          269.632 |             298.976 |         0.793 |         0.902 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           57.760 |              71.040 |          269.600 |             296.640 |         0.813 |         0.909 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           82.336 |             147.168 |          466.080 |             487.456 |         0.559 |         0.956 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           76.704 |             115.040 |          435.392 |             453.248 |         0.667 |         0.961 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           81.856 |             147.424 |          465.920 |             499.552 |         0.555 |         0.933 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           81.760 |             146.656 |          466.176 |             485.984 |         0.557 |         0.959 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          176.608 |             206.976 |          678.080 |             866.976 |         0.853 |         0.782 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          121.664 |             164.768 |          538.240 |             636.160 |         0.738 |         0.846 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          176.608 |             209.664 |          677.696 |             883.424 |         0.842 |         0.767 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          177.440 |             207.840 |          677.248 |             868.288 |         0.854 |         0.780 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          250.272 |             449.536 |         1163.424 |            1420.832 |         0.557 |         0.819 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          173.472 |             305.376 |          929.408 |            1104.544 |         0.568 |         0.841 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          249.376 |             454.976 |         1163.648 |            1455.296 |         0.548 |         0.800 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          250.368 |             450.144 |         1163.520 |            1409.984 |         0.556 |         0.825 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2416.576 |            2726.208 |         6835.520 |           10442.784 |         0.886 |         0.655 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         1357.440 |            1590.752 |         4433.664 |            5975.296 |         0.853 |         0.742 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2427.360 |            2747.040 |         6853.056 |           10670.784 |         0.884 |         0.642 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2441.120 |            2718.944 |         6836.640 |           10433.792 |         0.898 |         0.655 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3555.392 |            5620.960 |        11944.000 |           16504.801 |         0.633 |         0.724 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         2010.848 |            3241.152 |         7636.064 |            9870.464 |         0.620 |         0.774 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3557.440 |            5688.352 |        11935.744 |           17090.496 |         0.625 |         0.698 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3562.720 |            5630.432 |        11939.168 |           16392.033 |         0.633 |         0.728 |

</details>

### Perf after this PR

**FWD**

| Type    |   Speedup | score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)      |
|---------|-----------|---------------|------------|----------------|----------------------------|
| Average |     0.776 |               |            |                |                            |
| Max     |     1.006 | None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) |
| Min     |     0.566 | relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128) |

**BWD**

| Type    |   Speedup | score_mod   | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)       |
|---------|-----------|-------------|------------|----------------|-----------------------------|
| Average |     0.817 |             |            |                |                             |
| Max     |     1.150 | None        | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 128) |
| Min     |     0.454 | None        | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |

<details>
<summary> Full performance sweep </summary>

| score_mod     | mask_mod   | dtype          | shape(B,Hq,M,Hkv,N,D)         |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.680 |              17.056 |           64.544 |              73.376 |         0.919 |         0.880 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           15.712 |              19.872 |           65.408 |              72.864 |         0.791 |         0.898 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           16.160 |              17.280 |           64.896 |              73.888 |         0.935 |         0.878 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 64)     |           16.192 |              17.120 |           64.896 |              75.424 |         0.946 |         0.860 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.648 |              22.496 |           89.184 |              82.592 |         0.873 |         1.080 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           20.320 |              26.816 |           91.264 |              82.880 |         0.758 |         1.101 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           20.096 |              22.528 |           89.184 |              83.776 |         0.892 |         1.065 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 16, 512, 128)    |           19.680 |              22.432 |           89.184 |             120.096 |         0.877 |         0.743 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           32.384 |              32.512 |          119.232 |             128.960 |         0.996 |         0.925 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           30.176 |              37.248 |          113.664 |             119.520 |         0.810 |         0.951 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           32.512 |              32.928 |          119.264 |             131.456 |         0.987 |         0.907 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   |           32.448 |              32.704 |          119.200 |             128.352 |         0.992 |         0.929 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           41.952 |              62.176 |          199.040 |             214.304 |         0.675 |         0.929 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           39.744 |              62.880 |          189.504 |             179.968 |         0.632 |         1.053 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           41.472 |              62.784 |          199.136 |             217.664 |         0.661 |         0.915 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  |           42.048 |              61.952 |          199.168 |             214.496 |         0.679 |         0.929 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          341.184 |             357.632 |          980.256 |            1328.896 |         0.954 |         0.738 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          212.576 |             252.960 |          673.888 |             824.864 |         0.840 |         0.817 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          340.000 |             363.296 |          980.768 |            1375.808 |         0.936 |         0.713 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64)   |          340.768 |             356.832 |          980.960 |            1326.272 |         0.955 |         0.740 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          459.392 |             737.120 |         1678.240 |            2205.248 |         0.623 |         0.761 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          292.672 |             468.096 |         1178.016 |            1371.584 |         0.625 |         0.859 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          462.144 |             745.312 |         1680.000 |            2252.512 |         0.620 |         0.746 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128)  |          462.112 |             736.576 |         1679.008 |            2216.480 |         0.627 |         0.758 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           16.064 |              16.704 |          105.120 |             120.768 |         0.962 |         0.870 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           15.552 |              18.144 |          107.136 |             121.696 |         0.857 |         0.880 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           16.096 |              16.768 |          102.688 |             120.864 |         0.960 |         0.850 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 64)      |           16.032 |              16.576 |          104.736 |             124.672 |         0.967 |         0.840 |
| None          | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.392 |              21.952 |          104.736 |             174.656 |         0.883 |         0.600 |
| None          | causal     | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           20.128 |              23.712 |          105.216 |             199.008 |         0.849 |         0.529 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.904 |              21.888 |          103.744 |             179.520 |         0.909 |         0.578 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 512, 2, 512, 128)     |           19.968 |              21.952 |          104.640 |             177.312 |         0.910 |         0.590 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           32.096 |              31.904 |          118.720 |             231.968 |         1.006 |         0.512 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           30.528 |              33.952 |          112.480 |             218.304 |         0.899 |         0.515 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           32.160 |              32.224 |          118.752 |             237.312 |         0.998 |         0.500 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64)    |           32.128 |              32.032 |          118.240 |             233.120 |         1.003 |         0.507 |
| None          | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           41.312 |              61.280 |          177.408 |             350.688 |         0.674 |         0.506 |
| None          | causal     | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           39.552 |              59.360 |          168.832 |             371.488 |         0.666 |         0.454 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           41.984 |              61.696 |          177.376 |             360.416 |         0.680 |         0.492 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128)   |           41.312 |              61.760 |          177.184 |             355.744 |         0.669 |         0.498 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          339.744 |             357.888 |          939.712 |            1665.376 |         0.949 |         0.564 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          212.608 |             248.832 |          633.280 |            1122.848 |         0.854 |         0.564 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          339.712 |             363.232 |          940.448 |            1689.440 |         0.935 |         0.557 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64)    |          341.056 |             355.264 |          940.128 |            1641.152 |         0.960 |         0.573 |
| None          | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          460.736 |             741.024 |         1569.824 |            2559.552 |         0.622 |         0.613 |
| None          | causal     | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          293.856 |             464.192 |         1066.240 |            1840.416 |         0.633 |         0.579 |
| relative_bias | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          460.704 |             753.152 |         1570.112 |            2641.088 |         0.612 |         0.594 |
| head_bias     | None       | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128)   |          460.832 |             745.536 |         1570.144 |            2602.560 |         0.618 |         0.603 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           35.680 |              41.280 |          171.840 |             158.176 |         0.864 |         1.086 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           31.360 |              42.976 |          158.912 |             139.264 |         0.730 |         1.141 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           35.168 |              41.600 |          171.648 |             161.344 |         0.845 |         1.064 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 64)     |           35.136 |              41.152 |          171.808 |             158.336 |         0.854 |         1.085 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.832 |              76.384 |          295.680 |             277.696 |         0.639 |         1.065 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           45.632 |              72.512 |          281.760 |             250.752 |         0.629 |         1.124 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           49.504 |              76.608 |          295.584 |             279.712 |         0.646 |         1.057 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 16, 512, 128)    |           48.864 |              75.904 |          295.456 |             277.568 |         0.644 |         1.064 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           99.392 |             111.232 |          408.640 |             442.656 |         0.894 |         0.923 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           71.392 |              95.168 |          338.784 |             341.760 |         0.750 |         0.991 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |           99.808 |             112.256 |          408.608 |             456.160 |         0.889 |         0.896 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64)   |          100.032 |             110.816 |          408.512 |             444.192 |         0.903 |         0.920 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          135.040 |             226.112 |          726.880 |             774.176 |         0.597 |         0.939 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |           99.904 |             169.696 |          616.448 |             607.104 |         0.589 |         1.015 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          135.488 |             228.384 |          727.776 |             782.368 |         0.593 |         0.930 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128)  |          135.744 |             225.664 |          728.000 |             773.600 |         0.602 |         0.941 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1324.192 |            1387.808 |         3866.944 |            5217.184 |         0.954 |         0.741 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |          738.464 |             832.608 |         2507.392 |            3146.688 |         0.887 |         0.797 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1326.016 |            1404.256 |         3867.872 |            5382.624 |         0.944 |         0.719 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64)   |         1326.144 |            1386.688 |         3867.552 |            5203.264 |         0.956 |         0.743 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1847.488 |            2866.336 |         6612.704 |            8597.696 |         0.645 |         0.769 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1066.592 |            1660.640 |         4357.696 |            5174.016 |         0.642 |         0.842 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1850.464 |            2905.408 |         6616.928 |            8793.280 |         0.637 |         0.752 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128)  |         1848.896 |            2834.720 |         6623.872 |            8637.920 |         0.652 |         0.767 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           36.384 |              38.656 |          150.336 |             182.624 |         0.941 |         0.823 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           31.360 |              38.112 |          137.664 |             171.840 |         0.823 |         0.801 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           36.608 |              39.040 |          150.528 |             183.872 |         0.938 |         0.819 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 64)      |           36.064 |              38.656 |          150.560 |             183.520 |         0.933 |         0.820 |
| None          | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           49.344 |              76.352 |          253.920 |             301.440 |         0.646 |         0.842 |
| None          | causal     | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           46.720 |              65.824 |          239.424 |             296.384 |         0.710 |         0.808 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           49.248 |              76.416 |          253.728 |             307.808 |         0.644 |         0.824 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 512, 2, 512, 128)     |           49.376 |              76.288 |          253.728 |             304.736 |         0.647 |         0.833 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           99.264 |             110.144 |          364.960 |             503.072 |         0.901 |         0.725 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           71.136 |              92.384 |          294.432 |             393.056 |         0.770 |         0.749 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           99.200 |             111.360 |          365.152 |             512.640 |         0.891 |         0.712 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64)    |           99.264 |             110.240 |          365.088 |             504.224 |         0.900 |         0.724 |
| None          | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          135.680 |             230.336 |          613.472 |             816.896 |         0.589 |         0.751 |
| None          | causal     | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          100.256 |             165.088 |          502.144 |             676.480 |         0.607 |         0.742 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          135.008 |             232.480 |          613.184 |             836.672 |         0.581 |         0.733 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128)   |          135.232 |             230.624 |          613.536 |             827.136 |         0.586 |         0.742 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1324.064 |            1378.688 |         3631.808 |            5308.384 |         0.960 |         0.684 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |          731.776 |             826.688 |         2263.168 |            3241.344 |         0.885 |         0.698 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1316.128 |            1403.200 |         3625.088 |            5550.688 |         0.938 |         0.653 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64)    |         1311.904 |            1378.880 |         3616.320 |            5353.696 |         0.951 |         0.675 |
| None          | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1837.856 |            2887.392 |         6121.632 |            8586.656 |         0.637 |         0.713 |
| None          | causal     | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1066.976 |            1654.368 |         3843.136 |            5291.040 |         0.645 |         0.726 |
| relative_bias | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1854.208 |            2896.832 |         6130.112 |            8745.984 |         0.640 |         0.701 |
| head_bias     | None       | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128)   |         1860.512 |            2889.344 |         6135.648 |            8750.592 |         0.644 |         0.701 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           60.640 |              71.552 |          315.968 |             296.512 |         0.847 |         1.066 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           50.784 |              71.040 |          284.288 |             258.880 |         0.715 |         1.098 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           61.312 |              72.704 |          315.680 |             302.016 |         0.843 |         1.045 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 64)    |           60.800 |              71.776 |          316.320 |             297.152 |         0.847 |         1.065 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           84.576 |             144.416 |          580.576 |             535.936 |         0.586 |         1.083 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           76.064 |             123.648 |          553.344 |             481.376 |         0.615 |         1.150 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           84.160 |             145.248 |          581.024 |             540.000 |         0.579 |         1.076 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 16, 512, 128)   |           84.512 |             143.552 |          581.088 |             535.776 |         0.589 |         1.085 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          189.152 |             209.408 |          798.400 |             868.704 |         0.903 |         0.919 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          127.552 |             168.800 |          650.816 |             663.328 |         0.756 |         0.981 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          189.376 |             211.360 |          798.080 |             895.552 |         0.896 |         0.891 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64)  |          189.440 |             208.576 |          797.888 |             873.152 |         0.908 |         0.914 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          257.536 |             441.760 |         1408.960 |            1514.720 |         0.583 |         0.930 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          179.328 |             312.096 |         1170.368 |            1177.472 |         0.575 |         0.994 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          259.264 |             446.944 |         1408.768 |            1530.400 |         0.580 |         0.921 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) |          258.080 |             440.480 |         1408.864 |            1514.144 |         0.586 |         0.930 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2595.808 |            2771.456 |         7616.704 |           10405.248 |         0.937 |         0.732 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         1435.744 |            1610.336 |         4927.520 |            6220.000 |         0.892 |         0.792 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2595.264 |            2745.056 |         7611.232 |           10631.392 |         0.945 |         0.716 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64)  |         2576.256 |            2735.456 |         7626.400 |           10346.976 |         0.942 |         0.737 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3679.744 |            5634.816 |        13077.056 |           17182.528 |         0.653 |         0.761 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         2099.360 |            3250.176 |         8589.664 |           10236.672 |         0.646 |         0.839 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3676.800 |            5716.288 |        13073.088 |           17311.071 |         0.643 |         0.755 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) |         3679.136 |            5570.496 |        13070.720 |           17192.863 |         0.660 |         0.760 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           61.600 |              71.008 |          272.320 |             300.000 |         0.868 |         0.908 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           50.176 |              65.344 |          241.568 |             258.912 |         0.768 |         0.933 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           61.120 |              72.512 |          272.672 |             305.408 |         0.843 |         0.893 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 64)     |           61.248 |              71.136 |          272.640 |             301.120 |         0.861 |         0.905 |
| None          | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           83.872 |             146.784 |          466.912 |             496.832 |         0.571 |         0.940 |
| None          | causal     | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           76.704 |             115.072 |          435.584 |             462.112 |         0.667 |         0.943 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           83.392 |             147.392 |          466.656 |             504.448 |         0.566 |         0.925 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 512, 2, 512, 128)    |           83.360 |             146.688 |          466.656 |             499.040 |         0.568 |         0.935 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          189.024 |             207.584 |          684.768 |             873.568 |         0.911 |         0.784 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          126.944 |             164.288 |          536.192 |             645.984 |         0.773 |         0.830 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          188.768 |             209.760 |          684.096 |             897.504 |         0.900 |         0.762 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64)   |          189.408 |             207.776 |          685.024 |             876.384 |         0.912 |         0.782 |
| None          | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          259.168 |             449.536 |         1167.936 |            1433.280 |         0.577 |         0.815 |
| None          | causal     | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          180.000 |             305.312 |          928.000 |            1113.920 |         0.590 |         0.833 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          258.464 |             455.136 |         1167.808 |            1462.848 |         0.568 |         0.798 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128)  |          257.824 |             450.208 |         1167.744 |            1448.000 |         0.573 |         0.806 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2598.368 |            2729.120 |         7134.400 |           10381.632 |         0.952 |         0.687 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         1435.456 |            1591.040 |         4424.768 |            6035.808 |         0.902 |         0.733 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2594.752 |            2725.952 |         7128.384 |           10822.496 |         0.952 |         0.659 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64)   |         2597.888 |            2716.960 |         7101.568 |           10385.440 |         0.956 |         0.684 |
| None          | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3647.648 |            5581.632 |        12089.952 |           16667.233 |         0.654 |         0.725 |
| None          | causal     | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         2093.952 |            3241.440 |         7579.392 |            9847.936 |         0.646 |         0.770 |
| relative_bias | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3650.528 |            5650.688 |        12105.568 |           16963.680 |         0.646 |         0.714 |
| head_bias     | None       | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128)  |         3680.064 |            5585.312 |        12117.504 |           16935.040 |         0.659 |         0.716 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505
Approved by: https://github.com/Chillee
2024-09-10 09:30:02 +00:00
chilli
348d02a983 Changed masked out rows logsumexp to be -inf and not zero (#134650)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134650
Approved by: https://github.com/yanboliang, https://github.com/BoyuanFeng, https://github.com/drisspg
2024-08-29 17:22:52 +00:00
drisspg
d966d91e37 [FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538
Approved by: https://github.com/yanboliang
ghstack dependencies: #134507, #134511
2024-08-27 22:04:57 +00:00
PyTorch MergeBot
3e10a1eb5a Revert "[FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)"
This reverts commit a34320a6f2.

Reverted https://github.com/pytorch/pytorch/pull/134538 on behalf of https://github.com/albanD due to Broke lint due to too long line ([comment](https://github.com/pytorch/pytorch/pull/134507#issuecomment-2312505955))
2024-08-27 13:05:27 +00:00
drisspg
a34320a6f2 [FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538
Approved by: https://github.com/yanboliang
ghstack dependencies: #134495, #134507, #134511
2024-08-27 09:53:19 +00:00
drisspg
bf5addb613 [FlexAttention] Enable different qk and v head-dims (#134043)
# Summary
Adds the option for the head dims to be different between QK and V tensors.

Fixes issue: https://github.com/pytorch/pytorch/issues/133674

V_DIM > QK_DIM is blocked by landing: https://github.com/triton-lang/triton/pull/4138 / https://github.com/triton-lang/triton/pull/4540

Into PyTorch's triton branch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134043
Approved by: https://github.com/Chillee
2024-08-23 01:06:57 +00:00
drisspg
629bd6f718 Update FlexAttention with masking semantic (#133373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373
Approved by: https://github.com/yanboliang
2024-08-22 22:50:33 +00:00