Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00

71 Commits

cef815dc2c
[ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998)
A few UT failures are caused by `HIPBLASLT_ALLOW_TF32`. Fixes #157094, #157093, #157092, #157091, #157064, #157063, #157062, #157061, #157042, #157041, #157039, #157004. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>

a3d72b09ae
Apply Triton tensor descriptor for flex-decoding for performance (#161643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161643 Approved by: https://github.com/drisspg

f36f285953
[dynamo] change error_on_graph_break/fullgraph semantics (#161747)
This PR implements the semantics change to `torch._dynamo.error_on_graph_break`:
- ~~`torch.compile` now has a new `error_on_graph_break` kwarg that serves as a lower-priority toggle for erroring/continuing on graph breaks~~
- `error_on_graph_break` is a new internal `torch.compile` setting that is lower priority than `fullgraph`. It lets the user toggle between erroring and continuing on graph breaks.
- `error_on_graph_break` does nothing when `fullgraph=True`.
- `error_on_graph_break` does NOT guarantee a single graph.

Followup [DONE]: change the programming model docs to reflect the 3 graph-break modes for compilation:
- `fullgraph=True`: enforces one graph with no graph breaks; cannot be toggled.
- `fullgraph=False, error_on_graph_break=True`: errors on graph breaks; the latter can be toggled at compile time.
- `fullgraph=False, error_on_graph_break=False`: resumes tracing on graph breaks; the latter can be toggled at compile time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161747 Approved by: https://github.com/mlazos ghstack dependencies: #161739
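A minimal sketch of the three modes described above. `torch._dynamo.graph_break()` and `fullgraph` are real APIs; the exact surface of the `error_on_graph_break` toggle is an assumption based on this commit log (shown here as `torch._dynamo.config.error_on_graph_break`).

```python
# Hedged sketch: the error_on_graph_break toggle location is assumed from this log.
import torch
import torch._dynamo as dynamo

def fn(x):
    x = x + 1
    dynamo.graph_break()   # deliberately split the graph
    return x * 2

x = torch.randn(4)

# Mode 1: fullgraph=True -> any graph break raises, cannot be toggled.
try:
    torch.compile(fn, fullgraph=True)(x)
except Exception as e:
    print("fullgraph=True raised:", type(e).__name__)

# Mode 2: fullgraph=False, error_on_graph_break=True -> also raises on the break.
dynamo.reset()
dynamo.config.error_on_graph_break = True      # assumed toggle location
try:
    torch.compile(fn)(x)
except Exception as e:
    print("error_on_graph_break=True raised:", type(e).__name__)

# Mode 3: fullgraph=False, error_on_graph_break=False -> falls back and keeps tracing.
dynamo.reset()
dynamo.config.error_on_graph_break = False
print(torch.compile(fn)(x))
```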

2fed4fb464
[FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861)
Fixes #159247

Issue 1: Accuracy Problem with Non-Divisible KV Sequences
---------------------------------------------------------

### Background
Paged attention in flex decoding produced inaccurate results when the KV sequence length is not divisible by the block size. For example, when `KV_S = 64` and `block_size = 128`, the output didn't match standard attention accuracy.

### Root Cause
The current paged attention does not apply an upper mask mod when converting from the logical to the physical mask mod. Instead, it uses a noop_mask by default, which leaves all values unmasked and leads to an accuracy mismatch. Adding an upper mask mod based on the original actual kv_len (64 in this test case) resolves the issue.

### Solution
* **Applied proper upper bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor with proper shape `[B, KV_S]` to provide the actual batched KV sequence length. The function now correctly applies upper bound checks using the actual KV sequence lengths for each batch.

### Files Modified
* `torch/nn/attention/experimental/_paged_attention.py`: Added the `kv_len` tensor parameter to `get_mask_mod` and applied the upper mask to the new mask mod.
* `test/inductor/test_flex_attention.py`: Fixed all related `kv_len` parameter calls in the tests.
* `test/inductor/test_flex_decoding.py`: Fixed all related `kv_len` parameter calls in the tests.

Issue 2: Invalid Memory Access (IMA) in Triton Kernels
------------------------------------------------------

### Background
The Triton kernel for flex attention was experiencing invalid memory access errors when running under compute sanitizers, particularly with short KV sequences and small batch sizes.

### Root Cause
* The kernel launches CTAs (Cooperative Thread Arrays) proportional to the GPU's multi-processor count (108 via `SPLIT_KV`).
* With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values.
* This caused out-of-bounds memory access violations.

### Solution
Implemented boundary checks with early exit:
1. **Added a `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`
   * Calculate the maximum valid KV index based on the actual `kv_indices` tensor size and pass it to the Triton template.
2. **Added early exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`
   * Boundary checks before accessing `kv_indices` in both normal and full blocks.
   * Idle CTAs with invalid `indices_idx` skip computation entirely.

This prevents invalid memory access while reducing wasted computation on idle thread blocks.

Testing & Validation
--------------------

### Accuracy Tests
* Added comprehensive test cases covering KV sequences not divisible by block sizes.
* Verified the output matches standard attention for various sequence length combinations.

### Sanitizer Results
`========= COMPUTE-SANITIZER Starting standalone test_max_autotune... Running test_max_autotune on device: cuda max_autotune config: True test_max_autotune completed successfully! Test passed! ========= ERROR SUMMARY: 0 errors`

**Before**: More than 13720 invalid memory access errors with sanitizers
**After**: Clean execution with 0 errors

Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861 Approved by: https://github.com/BoyuanFeng
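A minimal sketch of the upper-bound masking idea described above, written as a hypothetical wrapper around an existing `mask_mod` (the helper names are illustrative, not the actual `_paged_attention.py` code):

```python
import torch

# kv_len holds the actual KV sequence length per batch element (illustrative).
# The wrapped mask keeps a position only if the inner mask allows it AND the
# logical kv index is within the real (unpadded) sequence length.
def with_upper_bound(inner_mask_mod, kv_len):
    def mask_mod(b, h, q_idx, kv_idx):
        return inner_mask_mod(b, h, q_idx, kv_idx) & (kv_idx < kv_len[b])
    return mask_mod

# Example: a causal mask for an actual KV length of 64 padded up to a 128-wide block.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

kv_len = torch.tensor([64])
bounded_causal = with_upper_bound(causal, kv_len)
```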

3e459491b5
Enable XPU path for FlexAttention (#143553)
[#RFC153024](https://github.com/pytorch/pytorch/issues/153024)

**Motivation**
1. Attention has been the critical performance bottleneck in current LLM models, and FlexAttention is a good choice to cover the broad variants in the transformers family of models. With FlexAttention, it is easy for us to enable paged attention and fused SDPA in the transformers repo on the XPU device. Besides, it also provides a candidate for processing attention in LLM ecosystem libraries, e.g., vLLM and SGLang, on the XPU device.
2. FlexAttention is a good starting point to push the Intel Triton-based GEMM kernels toward maturity. FlexAttention provides both a flexattention kernel and a flexdecoding kernel to cover compute-bound and memory-bound GEMM computation, and different shapes should also be supported to serve LLM inference, e.g., head_dim = 64, 96, 128, 256.

**What does this PR do?**
1. Enable the device type for the FlexAttention kernel and UTs to ensure all important UTs pass on the XPU device.
2. For E2E model inference, ensure the functionality of LLM model inference with FlexAttention is ready.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143553 Approved by: https://github.com/EikanWang, https://github.com/drisspg Co-authored-by: Mao Yunfei <yunfei.mao@intel.com> Co-authored-by: Xingyuan Li <xingyuan.li@intel.com> Co-authored-by: majing <jing1.ma@intel.com> Co-authored-by: Xiao, Wang <wang.xiao@intel.com>
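A minimal sketch of running flex_attention with a score_mod and block mask on an XPU device, assuming a PyTorch build with XPU support (it falls back to CPU otherwise); shapes and the bias are illustrative:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = "xpu" if torch.xpu.is_available() else "cpu"   # assumes an XPU-enabled build
dtype = torch.float16 if device != "cpu" else torch.float32
B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device=device, dtype=dtype) for _ in range(3))

# A simple relative-position score_mod and a causal block mask.
def rel_bias(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx) * 0.01

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, S, S, device=device)
out = torch.compile(flex_attention)(q, k, v, score_mod=rel_bias, block_mask=block_mask)
```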

8e17709055
FlexDecode not guarding on GQA groups correctly (#160904)
Addressing #151359. Updates the flex_decode dispatch to use flex attention rather than flex decode if the number of groups is not a power of 2. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160904 Approved by: https://github.com/drisspg
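A hedged sketch of the kind of dispatch guard described here (the helper names are illustrative, not the actual inductor code):

```python
def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

def use_flex_decode(num_query_heads: int, num_kv_heads: int) -> bool:
    # Illustrative: fall back to the main flex attention kernel when the
    # GQA group count is not a power of two.
    groups = num_query_heads // num_kv_heads
    return is_power_of_two(groups)

assert use_flex_decode(32, 8)        # 4 groups  -> decode kernel
assert not use_flex_decode(12, 4)    # 3 groups  -> flex attention
```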

21392c0e06
[inductor] disable flex decoding on Windows. (#160072)
Discussed with @jianan-gu and @Valentine233: disable flex decoding on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160072 Approved by: https://github.com/angelayi

19aa8eb4f5
[TF32][Flex Attention] Turn off TF32 for reference computation in test_flex_decoding (#158979)
Seems to avoid threshold (fudge factor) twiddling games, as this causes the checks to go down the "very small ref error" path instead. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158979 Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng, https://github.com/nWEIdia
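For reference, the usual way to force a reference computation off TF32 is via the standard backend flags (these are real PyTorch APIs; the exact change in this PR may differ):

```python
import torch

# Compute the eager/reference path in strict fp32 so the tolerance check
# follows the "very small ref error" branch.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

a = torch.randn(512, 512, device="cuda")
b = torch.randn(512, 512, device="cuda")
ref = a @ b   # fp32 matmul without TF32 downcasting
```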

561193e5f2
[CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)
3 procs were used for sm86, but when we switched to sm89 the check failed, so it went back to 2. sm90 is H100; I'm not sure which unit tests we run there, but I assume those machines also have a lot of memory. The sm89/sm90 jobs use larger runners with more GPU memory, so it's usually OK: roughly 22GB total, about 10GB per proc with 2 procs and about 6GB per proc with 3 (the CUDA context is maybe 1GB each). I've applied skips to the tests that OOMed. Time decreases from ~2.7hr per test job to ~2hr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158691 Approved by: https://github.com/huydhn
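A rough back-of-the-envelope check of the memory budget above (the numbers are the estimates quoted in the commit, not measurements):

```python
total_gb = 22          # approximate GPU memory on the larger runners
cuda_context_gb = 1    # rough per-process CUDA context overhead

for procs, per_proc_gb in [(2, 10), (3, 6)]:
    used = procs * (per_proc_gb + cuda_context_gb)
    print(f"{procs} procs: ~{used} GB of ~{total_gb} GB")
# 2 procs: ~22 GB of ~22 GB
# 3 procs: ~21 GB of ~22 GB
```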

11ea3736dd
Revert "[CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)"
This reverts commit 0c0fcb53ff.

0c0fcb53ff
[CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)
3 procs were used for sm86, but when we switched to sm89 the check failed, so it went back to 2. sm90 is H100; I'm not sure which unit tests we run there, but I assume those machines also have a lot of memory. The sm89/sm90 jobs use larger runners with more GPU memory, so it's usually OK: roughly 22GB total, about 10GB per proc with 2 procs and about 6GB per proc with 3 (the CUDA context is maybe 1GB each). I've applied skips to the tests that OOMed. Time decreases from ~2.7hr per test job to ~2hr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158691 Approved by: https://github.com/huydhn

17687eb792
[BE][4/6] fix typos in test/ (test/inductor/) (#157638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157638 Approved by: https://github.com/yewentao256, https://github.com/jansel

1c3f5e902d
[dynamo] control one_graph behavior additionally through config (#154283)
`torch.compile` now always goes through `torch._dynamo._optimize`. `fullgraph` is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283 Approved by: https://github.com/jansel, https://github.com/anijain2305

53e0b9c393
refine fp32 precision api (#125888)
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop "highest, high, medium" for representing fp32 internal computation data types. Instead, we will directly use the algorithm name to represent it.

### Design Choice: Directly use algorithm names like "TF32", "BF16".
#### Pros
- The names are more informative: 'tf32' is more informative than a simple "high".
- Easier to extend with new algorithms like `tf32x3`.
#### Cons
- "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can add more documentation to discuss them.

### We provide a layered structure for backends/operators ('f32' is short for 'fp32_precision').

### We provide 3 fp32 compute precisions that can be set:
- **"ieee"**: Not allowed to use any other internal computation data types.
- **"tf32"**: Allowed to use tf32 as an internal computation data type.
- **"bf16"**: Allowed to use bf16 as an internal computation data type.
- **"none"**: Precision is not set; it can be overridden by its parent node.

### Overriding Precision Settings
A child node can be overridden by its parent node if it is set to the default. The current default settings are:
```
backend = generic, op = all, precision setting = none
backend = cuda, op = all, precision setting = none
backend = cuda, op = conv, precision setting = tf32
backend = cuda, op = rnn, precision setting = tf32
backend = cuda, op = matmul, precision setting = none
backend = matmul, op = all, precision setting = none
backend = matmul, op = conv, precision setting = none
backend = matmul, op = rnn, precision setting = none
backend = matmul, op = matmul, precision setting = none
```
- If the user sets `torch.backends.mkldnn.fp32_precision="bf16"`, its child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be overridden to "bf16".
- If the user sets `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and its child nodes will also be overridden to "bf16".

### Backward Compatibility
Since the new API allows more fine-grained control, there will be some conflicts. For example, the previous `torch.backends.cudnn.allow_tf32` is not enough to represent the state where `torch.backends.cudnn.rnn.fp32_precision="ieee"` and `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goals for backward compatibility are:
- If the user only uses the previous APIs, they work as before.
- If the user uses the **new** API to change the state to one that is **un-representable** by the old API, and then tries to read the state through the **old** API, we raise a RuntimeError and point the user to the documentation.

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888 Approved by: https://github.com/jgong5, https://github.com/albanD Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
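A hedged sketch of the new API using the attribute names that appear in this commit message (availability and exact behavior depend on the PyTorch build):

```python
import torch

# Per-backend / per-op knobs named in this commit message.
torch.backends.cudnn.conv.fp32_precision = "tf32"     # allow TF32 in cuDNN convolutions
torch.backends.cudnn.rnn.fp32_precision = "ieee"      # keep RNNs in strict fp32
torch.backends.mkldnn.fp32_precision = "bf16"         # parent setting for the mkldnn backend

# Child nodes left at "none" inherit from their parent.
print(torch.backends.mkldnn.matmul.fp32_precision)    # expected to resolve to "bf16"

# Generic top-level setting that unset backends inherit from.
torch.backends.fp32_precision = "ieee"
```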

f5e6e52f25
[BE][PYFMT] migrate PYFMT for test/inductor/ to ruff format (#148186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148186 Approved by: https://github.com/jansel

b5c8b8d09f
Revert "[dynamo] control one_graph behavior additionally through config (#154283)"
This reverts commit b46eb1ccaf.

b46eb1ccaf
[dynamo] control one_graph behavior additionally through config (#154283)
`torch.compile` now always goes through `torch._dynamo._optimize`. `fullgraph` is now implemented in `torch.compile` by looking at `config.error_on_graph_break`. Export still goes through `torch._dynamo._optimize_assert`, which uses `tx.one_graph` instead of `config.error_on_graph_break`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154283 Approved by: https://github.com/jansel, https://github.com/anijain2305

ae53510b9e
Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 and adds a test.

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`. Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) A reason this matters is that it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated `TestFooCUDA` does not derive from `TestFoo`, but instead derives from the empty class described above, this syntax does not work; in fact, there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:
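The original example is cut off in this mirror. Below is a hedged, illustrative sketch of the pattern being discussed, not the original OpInfo example: a generic test class whose custom `setUpClass()` calls `super().setUpClass()`, instantiated per device via `instantiate_device_type_tests`. The class and attribute names are hypothetical.

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestFoo(TestCase):
    @classmethod
    def setUpClass(cls):
        # With this PR, TestFooCPU / TestFooCUDA derive from TestFoo, so this
        # super() call resolves across the generated device-specific variants.
        super().setUpClass()
        cls.shared = torch.ones(4)

    def test_bar(self, device):
        self.assertEqual((self.shared.to(device) * 2).sum().item(), 8.0)

# Generates TestFooCPU.test_bar_cpu / TestFooCUDA.test_bar_cuda, etc.
instantiate_device_type_tests(TestFoo, globals())

if __name__ == "__main__":
    run_tests()
```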

12cb11a268
[Inductor UT] Refactor FlexAttention UT and add CPU tests (#144953)
This PR extends and refines all remaining UTs for CPU and more devices in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py`, as a follow-up to https://github.com/pytorch/pytorch/pull/141453. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144953 Approved by: https://github.com/drisspg

5e9f792479
[ROCm] Unskip flex attention UTs after triton 3.3 bump (#148327)
Enable `test_flex_attention.py::TestLearnableBiases` unit tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148327 Approved by: https://github.com/jeffdaily

98b3f1db9f
[Flex Attention] support num_heads > 1 in block_mask (#148857)
Previously, flex decoding errored when the block mask had num_heads > 1, so users had to use num_heads = 1 or explicitly set `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}`.
This PR fixes this issue. When not using grouped query attention (GQA, i.e., Hq == Hkv), we support block masks with num_heads = 1 and num_heads = num_query_heads (i.e., Hq). This matches the flex attention kernel.
When using GQA (i.e., Hq != Hkv), we support block masks with num_heads = 1. When num_heads = Hq, we fall back to the flex attention kernel, so users don't need to explicitly set `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}` anymore.
Why fall back? In the current flex decoding Triton kernel, grouped query heads for the same kv head are handled by the same thread block. Supporting num_heads = Hq with GQA would require supporting different kv num_blocks for different query heads within the same thread block, leading to a lot of redundant work, so it is better to use the main flex_attention kernel, where each query head is handled by a separate block.
Fixes #148527
Fixes #147267
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148857
Approved by: https://github.com/drisspg
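A hedged sketch of the two block-mask head configurations now accepted by flex decoding (assuming a CUDA device; shapes and the prefix mask are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, Q_LEN, KV_LEN, D = 2, 8, 1, 1024, 64   # Hq == Hkv, decode-style single query token
q = torch.randn(B, H, Q_LEN, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.float16)

def prefix(b, h, q_idx, kv_idx):
    return kv_idx < 512   # attend only to the first 512 KV positions

# Head-broadcast mask (num_heads = 1) and per-head mask (num_heads = Hq) both work now.
mask_h1 = create_block_mask(prefix, B, 1, Q_LEN, KV_LEN, device="cuda")
mask_hq = create_block_mask(prefix, B, H, Q_LEN, KV_LEN, device="cuda")

out = torch.compile(flex_attention)(q, k, v, block_mask=mask_hq)
```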

ba9ed856e0
[FlexAttention] Improve error msg for embedding < 16 (#147765)
flex_attention uses tl.dot, which [does not support embedding < 16](https://github.com/triton-lang/triton/issues/2266) for its input shapes. This PR adds an explicit error message for users who are prototyping with small tensors. Fixes #147701 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147765 Approved by: https://github.com/drisspg

c6707734de
Enable non power of 2 head_dim for FlexAttention (#133495)
# Summary
- Adds support for non-power-of-2 head_dim by launching blocks with head_dim rounded up to the next valid power of 2.
- The other option I considered was building up the final dot products with smaller blocks; that would probably work, but for the sake of code complexity I'm going with this option for now.

### Corollary
We had a bug in our backwards kernel where we were using index_k instead of index_v. This should have shown up for the qk_head_dim != v_head_dim cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133495 Approved by: https://github.com/Chillee
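A hedged sketch of the padding idea (an illustrative helper, not the actual kernel code):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (e.g. 96 -> 128).
    return 1 << (n - 1).bit_length()

# The kernel is launched with the padded head_dim; the extra lanes are masked out.
for head_dim in (40, 96, 100, 128):
    print(head_dim, "->", next_power_of_2(head_dim))
```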

99dbc5b0e2
PEP585 update - test (#145176)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176 Approved by: https://github.com/bobrenjc93

577708e6de
Unskipped multiple inductor tests for ROCm (#143581)
All of them should be fine to run now after the triton fix. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143581 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>

069419569d
[PagedAttention] Support different input position for each batch index (#144693)
In LLM inference, each request usually has a different prefill length, leading to a different input position for each batch index. This PR adds such support for paged attention. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144693 Approved by: https://github.com/drisspg
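A hedged, illustrative sketch of per-batch input positions in a decode-style mask_mod (`input_pos` and the helpers are hypothetical names, not the paged-attention API):

```python
import torch

# Each batch element is at a different position in its own sequence.
input_pos = torch.tensor([7, 123, 42])   # hypothetical per-batch decode positions

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def offset_causal(b, h, q_idx, kv_idx):
    # Shift batch b's (length-1) query to its true position before masking.
    return causal(b, h, q_idx + input_pos[b], kv_idx)
```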

7d9f26de05
Revert "Unskipped multiple inductor tests for ROCm (#143581)"
This reverts commit e05d67790e.

e05d67790e
Unskipped multiple inductor tests for ROCm (#143581)
All of them should be fine to run now after the triton fix. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143581 Approved by: https://github.com/jataylo, https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>

d8c8ba2440
Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD

ed9931e6ee
Add tests for non divisible inputs for flex decoding (#143214)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143214 Approved by: https://github.com/drisspg

424156c26c
[ROCm] Update to AOTriton 0.8b (#140172)
Notable new features for SDPA operators on AMD systems from AOTriton 0.8b:
1. Nestedtensor support;
2. MQA/GQA support;
3. Restore Efficient attention support for causal=True and seqlen_q != seqlen_k cases;
+ The kernel should use top-left alignment; bottom-right alignment will be added later
4. Move gfx1100 (RX7900/W7800/W7900) out of experimental support status.
However, users are strongly recommended to update to ROCM 6.2.4, notably for
its firmware updates.
Related unit tests are enabled as well.
Notable related changes from AOTriton 0.8b:
1. AOTriton 0.8b moves the GPU kernel out of libaotriton.so to a separate directory `aotriton.images`;
2. LZMA replaces ZSTD as the GPU kernel compression algorithm for a better compression ratio: aotriton 0.8b (.so + aotriton.images) takes 350MB, compared to 800MB for the aotriton 0.7b .so;
3. The compression cannot be disabled now, and `liblzma` is a hard run-time dependency.
+ This should not be a problem, since `lzma` is part of the Python standard library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>

af88326250
Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)
Fixes https://github.com/pytorch/pytorch/issues/141435 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625 Approved by: https://github.com/drisspg ghstack dependencies: #138788
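A hedged sketch of the constraint in the commit title: the BlockMask must be built for the exact sequence lengths passed to flex_attention (assuming a CUDA device; shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 4, 1024, 64
q = k = v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

ok_mask = create_block_mask(causal, B, H, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=ok_mask)            # lengths match: fine

stale_mask = create_block_mask(causal, B, H, 2 * S, 2 * S, device="cuda")
try:
    flex_attention(q, k, v, block_mask=stale_mask)           # mismatched lengths now error
except Exception as e:
    print("rejected:", e)
```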

a34a56f69f
Revert "Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)"
This reverts commit 795f28ac55.

795f28ac55
Ensure that BlockMask length must always exactly match the sequence length in flex_attention (#141625)
Fixes https://github.com/pytorch/pytorch/issues/141435 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625 Approved by: https://github.com/drisspg ghstack dependencies: #138788

cb71bcc542
Replace clone.detach with detach.clone (#140264)
Fixes #64532 As stated in the issue, replace `clone.detach` with `detach.clone`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264 Approved by: https://github.com/soulitzer
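A minimal illustration of the substitution: both produce a tensor detached from the autograd graph, but detaching first avoids recording an autograd node for the clone.

```python
import torch

x = torch.randn(3, requires_grad=True)

y_old = x.clone().detach()   # pattern being replaced
y_new = x.detach().clone()   # preferred: detach first, then clone

assert torch.equal(y_old, y_new)
assert not y_new.requires_grad
```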

540f3ef9b1
Fix flex_decode to build offsets off of strides (#139516)
Fixes: https://github.com/pytorch/pytorch/issues/139462 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139516 Approved by: https://github.com/Chillee
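A hedged, generic illustration of building offsets from strides rather than assuming a contiguous layout (not the actual flex_decode template code):

```python
import torch

base = torch.randn(2, 4, 8, 16)          # contiguous storage
x = base.transpose(1, 2)                 # non-contiguous [2, 8, 4, 16] view
b, h, s, d = 1, 5, 3, 7

# Offsets built from x's strides index the shared storage correctly even though
# x is not contiguous; assuming a contiguous layout would read the wrong element.
sb, sh, ss, sd = x.stride()
offset = b * sb + h * sh + s * ss + d * sd
assert base.flatten()[offset] == x[b, h, s, d]
```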

68134a320e
[Flex Attention] Paged Attention (#137164)
This PR adds paged attention for flex attention. Pull Request resolved: https://github.com/pytorch/pytorch/pull/137164 Approved by: https://github.com/drisspg
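A hedged, conceptual sketch of the page-table idea behind paged attention: logical KV blocks are mapped through a page table to physical blocks in a shared cache. This uses plain tensors and hypothetical names, not the actual `_paged_attention.py` API.

```python
import torch

n_physical_blocks, block_size, H, D = 16, 128, 2, 64
kv_cache = torch.zeros(1, H, n_physical_blocks * block_size, D)   # shared physical cache

# page_table[b, logical_block] -> physical block index in the shared cache (-1 = unassigned).
page_table = torch.full((2, 4), -1, dtype=torch.long)
page_table[0, :2] = torch.tensor([3, 7])   # sequence 0 owns physical blocks 3 and 7
page_table[1, :1] = torch.tensor([0])      # sequence 1 owns physical block 0

def physical_kv_index(b, logical_idx):
    block, within = logical_idx // block_size, logical_idx % block_size
    return page_table[b, block] * block_size + within

print(physical_kv_index(0, 130))   # logical position 130 of seq 0 -> 7*128 + 2 = 898
```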

34c18887ad
[FlexAttention] Remove restriction on QK headdim > V headdim (#135884)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135884 Approved by: https://github.com/Chillee

28224329ad
[Flex Attention] fix block size order (#136657)
`create_block_mask` currently gives the wrong BLOCK_SIZE and shape when using a non-default block size (the default is `(128, 128)`). This PR fixes the issue by using the BLOCK_SIZE order `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136657 Approved by: https://github.com/Chillee, https://github.com/drisspg
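A hedged sketch of passing a non-default block size; per this fix the tuple is interpreted as `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)` (assuming a CUDA device):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Non-default block sizes: 64 along the query axis, 128 along the KV axis.
mask = create_block_mask(causal, B=1, H=1, Q_LEN=2048, KV_LEN=2048,
                         device="cuda", BLOCK_SIZE=(64, 128))
print(mask.BLOCK_SIZE)   # expected (64, 128)
```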

f7ab0e9989
Revert "[Flex Attention] fix block size order (#136657)"
This reverts commit b42f1e3641.

b42f1e3641
[Flex Attention] fix block size order (#136657)
`create_block_mask` currently gives the wrong BLOCK_SIZE and shape when using a non-default block size (the default is `(128, 128)`). This PR fixes the issue by using the BLOCK_SIZE order `(Q_BLOCK_SIZE, KV_BLOCK_SIZE)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136657 Approved by: https://github.com/Chillee, https://github.com/drisspg

ccca3de0cd
[ROCm] Enable Flex attention tests on AMD gpus (#136245)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136245 Approved by: https://github.com/malfet

74fd1bf965
[ROCm] Update to AOTriton 0.7b (#134498)
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
Known problems:
1. `test/test_transformers.py` will show massive failures and/or NaN outputs with `--use-pytest`
+ Update: Confirmed that skipping `class TestSDPAPrivateUse1Only` fixes the problem with `--use-pytest`
Note:
AOTriton 0.7b adds support for nested tensors + SDPA but needs more work (and consequently a separate PR) to enable it.
Fixes #133540
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet

6e13f5eb38
[FlexAttention] Add broadcast support for kv batch dimension (#135505)
This PR adds broadcast support for KV batch dimension. ## Details Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this diff, we require `Bq == Bkv`. However, for some use cases, we may have Bkv < Bq. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]`, while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximal number of tokens supported by paged attention. This PR relax this requirement to be `Bq == Bkv or (Bq > 1 and Bkv == 0)`. This support covers both flex decoding, flex attention forward and backward. ## Benchmark GPU: H100 We see negligible (1%~2%) performance change from this PR when `Bq == Bkv`. ``` python benchmarks/transformer/score_mod.py --calculate-bwd ``` ### Perf before this PR **FWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|---------------|------------|----------------|------------------------------| | Average | 0.743 | | | | | | Max | 0.955 | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | | Min | 0.548 | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | **BWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|-------------|------------|----------------|-----------------------------| | Average | 0.834 | | | | | | Max | 1.261 | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | | Min | 0.456 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | <details> <summary> Full performance sweep </summary> | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 107.040 | 140.800 | 0.888 | 0.760 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.840 | 19.744 | 112.576 | 140.064 | 0.802 | 0.804 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.232 | 17.344 | 87.744 | 142.496 | 0.878 | 0.616 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 108.192 | 143.328 | 0.888 | 0.755 | | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.904 | 22.400 | 106.432 | 136.512 | 0.889 | 0.780 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.424 | 26.752 | 91.712 | 106.688 | 0.726 | 0.860 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.808 | 22.432 | 89.024 | 101.920 | 0.883 | 0.873 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.840 | 22.272 | 88.896 | 102.592 | 0.891 | 0.867 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.240 | 32.416 | 116.768 | 112.256 | 0.933 | 1.040 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 29.536 | 37.024 | 113.664 | 102.688 | 0.798 | 1.107 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.656 | 32.800 | 116.992 | 127.008 | 0.935 | 0.921 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.592 | 32.480 | 116.928 | 112.160 | 0.942 | 1.043 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.920 | 198.656 | 204.512 | 0.653 | 0.971 | | None | causal | torch.bfloat16 | (2, 16, 
1024, 16, 1024, 128) | 37.760 | 62.528 | 189.536 | 170.624 | 0.604 | 1.111 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.896 | 62.368 | 198.304 | 205.824 | 0.656 | 0.963 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.952 | 198.432 | 203.648 | 0.653 | 0.974 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 318.528 | 355.904 | 947.232 | 1162.496 | 0.895 | 0.815 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 199.776 | 252.128 | 677.792 | 813.184 | 0.792 | 0.834 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 316.512 | 363.328 | 947.712 | 1361.984 | 0.871 | 0.696 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 317.984 | 356.864 | 947.264 | 1165.024 | 0.891 | 0.813 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 446.656 | 734.656 | 1664.288 | 2172.960 | 0.608 | 0.766 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 278.688 | 467.648 | 1182.624 | 1339.296 | 0.596 | 0.883 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 447.872 | 744.096 | 1662.944 | 2196.544 | 0.602 | 0.757 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 448.128 | 732.928 | 1663.072 | 2156.800 | 0.611 | 0.771 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.648 | 16.640 | 107.520 | 143.008 | 0.940 | 0.752 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.776 | 18.240 | 129.056 | 141.920 | 0.865 | 0.909 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.168 | 16.640 | 103.616 | 139.648 | 0.912 | 0.742 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.616 | 16.640 | 128.608 | 164.448 | 0.938 | 0.782 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 21.952 | 125.344 | 170.304 | 0.901 | 0.736 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 23.712 | 104.288 | 196.896 | 0.834 | 0.530 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.072 | 21.952 | 102.080 | 177.056 | 0.869 | 0.577 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.648 | 21.920 | 109.920 | 170.848 | 0.896 | 0.643 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.936 | 127.808 | 228.832 | 0.954 | 0.559 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 29.472 | 33.856 | 113.152 | 215.072 | 0.871 | 0.526 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.496 | 32.160 | 116.576 | 231.744 | 0.948 | 0.503 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.904 | 116.320 | 229.824 | 0.955 | 0.506 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.480 | 61.440 | 176.448 | 345.312 | 0.659 | 0.511 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 38.304 | 59.424 | 169.312 | 371.360 | 0.645 | 0.456 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.960 | 61.760 | 176.512 | 358.912 | 0.663 | 0.492 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.352 | 61.696 | 176.512 | 344.928 | 0.654 | 0.512 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.224 | 357.728 | 905.728 | 1668.448 | 0.884 | 0.543 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 199.904 | 248.416 | 636.544 | 1109.088 | 0.805 | 0.574 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 
314.880 | 363.616 | 906.304 | 1658.176 | 0.866 | 0.547 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.160 | 354.368 | 906.080 | 1649.024 | 0.892 | 0.549 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.912 | 739.840 | 1555.808 | 2521.952 | 0.604 | 0.617 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 279.776 | 463.904 | 1068.928 | 1849.888 | 0.603 | 0.578 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.080 | 748.960 | 1553.504 | 2629.888 | 0.596 | 0.591 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.208 | 740.608 | 1558.880 | 2524.960 | 0.602 | 0.617 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 33.568 | 41.280 | 170.016 | 147.584 | 0.813 | 1.152 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 30.688 | 43.040 | 159.552 | 146.720 | 0.713 | 1.087 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.112 | 41.504 | 170.112 | 152.672 | 0.822 | 1.114 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.240 | 41.152 | 170.272 | 134.976 | 0.832 | 1.261 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.672 | 76.416 | 295.296 | 263.648 | 0.637 | 1.120 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.088 | 72.576 | 281.920 | 237.664 | 0.621 | 1.186 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.032 | 76.672 | 295.520 | 265.248 | 0.626 | 1.114 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.096 | 76.096 | 295.456 | 262.112 | 0.632 | 1.127 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.920 | 111.232 | 401.568 | 382.944 | 0.844 | 1.049 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 68.192 | 95.232 | 338.752 | 326.816 | 0.716 | 1.037 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.984 | 111.840 | 401.856 | 444.224 | 0.840 | 0.905 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 94.176 | 110.496 | 401.600 | 383.136 | 0.852 | 1.048 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.488 | 227.040 | 727.424 | 739.712 | 0.579 | 0.983 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 95.616 | 169.760 | 616.864 | 574.112 | 0.563 | 1.074 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.680 | 228.672 | 727.616 | 746.048 | 0.576 | 0.975 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.104 | 225.696 | 727.904 | 735.392 | 0.581 | 0.990 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1227.296 | 1386.656 | 3720.192 | 4539.904 | 0.885 | 0.819 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 691.360 | 831.712 | 2515.872 | 3067.808 | 0.831 | 0.820 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1228.192 | 1403.136 | 3715.520 | 5309.280 | 0.875 | 0.700 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1229.024 | 1384.992 | 3715.904 | 4550.368 | 0.887 | 0.817 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1784.832 | 2865.888 | 6539.840 | 8460.224 | 0.623 | 0.773 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1017.408 | 1660.480 | 4369.824 | 5056.992 | 0.613 | 0.864 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1792.448 | 2904.864 | 6546.080 | 8537.024 | 0.617 | 0.767 | | head_bias | None | torch.bfloat16 | (8, 16, 
4096, 16, 4096, 128) | 1795.552 | 2856.864 | 6544.672 | 8400.160 | 0.629 | 0.779 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.880 | 148.832 | 179.936 | 0.881 | 0.827 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.168 | 38.080 | 138.528 | 167.552 | 0.818 | 0.827 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 39.168 | 148.512 | 181.248 | 0.874 | 0.819 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.784 | 148.864 | 180.224 | 0.883 | 0.826 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.832 | 76.352 | 253.632 | 295.968 | 0.640 | 0.857 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 45.760 | 65.792 | 239.040 | 290.752 | 0.696 | 0.822 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.576 | 253.312 | 304.032 | 0.637 | 0.833 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.192 | 253.600 | 296.096 | 0.640 | 0.856 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.728 | 109.728 | 357.696 | 498.912 | 0.854 | 0.717 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 68.704 | 92.288 | 295.616 | 386.240 | 0.744 | 0.765 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.632 | 111.392 | 357.408 | 512.448 | 0.841 | 0.697 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.280 | 109.952 | 357.696 | 501.440 | 0.848 | 0.713 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.392 | 230.496 | 612.224 | 807.552 | 0.570 | 0.758 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 96.512 | 165.184 | 502.624 | 672.384 | 0.584 | 0.748 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.360 | 232.608 | 612.064 | 832.320 | 0.565 | 0.735 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.008 | 230.528 | 612.640 | 804.320 | 0.568 | 0.762 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1227.968 | 1377.408 | 3477.920 | 5324.384 | 0.892 | 0.653 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 695.264 | 824.544 | 2268.224 | 3210.208 | 0.843 | 0.707 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.640 | 1404.576 | 3476.832 | 5463.456 | 0.875 | 0.636 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.416 | 1378.752 | 3478.048 | 5367.712 | 0.891 | 0.648 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1788.736 | 2867.712 | 6039.520 | 8616.256 | 0.624 | 0.701 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1021.952 | 1653.824 | 3866.208 | 5306.848 | 0.618 | 0.729 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.752 | 2896.352 | 6044.128 | 8871.360 | 0.617 | 0.681 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.080 | 2868.672 | 6040.160 | 8550.144 | 0.623 | 0.706 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.504 | 71.552 | 312.768 | 255.040 | 0.804 | 1.226 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 49.472 | 71.104 | 285.696 | 243.520 | 0.696 | 1.173 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 58.112 | 72.896 | 312.768 | 288.256 | 0.797 | 1.085 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.952 | 71.680 | 312.768 | 255.552 | 0.808 | 1.224 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 
128) | 82.336 | 144.256 | 580.128 | 500.160 | 0.571 | 1.160 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.160 | 123.712 | 552.544 | 447.648 | 0.616 | 1.234 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.400 | 145.184 | 580.032 | 504.032 | 0.568 | 1.151 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.368 | 143.904 | 580.192 | 499.936 | 0.572 | 1.161 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.216 | 209.568 | 787.872 | 747.712 | 0.846 | 1.054 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 121.984 | 168.256 | 651.968 | 628.256 | 0.725 | 1.038 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.088 | 211.488 | 788.320 | 864.352 | 0.837 | 0.912 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.440 | 208.576 | 787.424 | 749.120 | 0.851 | 1.051 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.472 | 441.376 | 1405.440 | 1431.648 | 0.565 | 0.982 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 172.960 | 312.064 | 1172.064 | 1096.448 | 0.554 | 1.069 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.632 | 446.336 | 1405.408 | 1448.480 | 0.559 | 0.970 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 250.944 | 440.128 | 1406.624 | 1421.952 | 0.570 | 0.989 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2418.720 | 2747.936 | 7330.432 | 9023.712 | 0.880 | 0.812 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1353.696 | 1608.480 | 4941.696 | 6078.752 | 0.842 | 0.813 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2427.456 | 2746.816 | 7329.792 | 10539.968 | 0.884 | 0.695 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2426.688 | 2763.168 | 7336.256 | 9057.536 | 0.878 | 0.810 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3554.240 | 5634.400 | 12919.872 | 16843.489 | 0.631 | 0.767 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2003.648 | 3250.784 | 8610.144 | 10015.424 | 0.616 | 0.860 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3582.080 | 5710.944 | 12923.328 | 17011.871 | 0.627 | 0.760 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3581.920 | 5618.144 | 12934.528 | 16745.888 | 0.638 | 0.772 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.120 | 71.232 | 269.760 | 295.680 | 0.802 | 0.912 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 49.408 | 65.312 | 242.304 | 253.952 | 0.756 | 0.954 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.504 | 72.544 | 269.632 | 298.976 | 0.793 | 0.902 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.760 | 71.040 | 269.600 | 296.640 | 0.813 | 0.909 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 82.336 | 147.168 | 466.080 | 487.456 | 0.559 | 0.956 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.040 | 435.392 | 453.248 | 0.667 | 0.961 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.856 | 147.424 | 465.920 | 499.552 | 0.555 | 0.933 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.760 | 146.656 | 466.176 | 485.984 | 0.557 | 0.959 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 206.976 | 678.080 | 866.976 | 0.853 | 0.782 | | 
None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 121.664 | 164.768 | 538.240 | 636.160 | 0.738 | 0.846 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 209.664 | 677.696 | 883.424 | 0.842 | 0.767 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 177.440 | 207.840 | 677.248 | 868.288 | 0.854 | 0.780 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.272 | 449.536 | 1163.424 | 1420.832 | 0.557 | 0.819 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 173.472 | 305.376 | 929.408 | 1104.544 | 0.568 | 0.841 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 249.376 | 454.976 | 1163.648 | 1455.296 | 0.548 | 0.800 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.368 | 450.144 | 1163.520 | 1409.984 | 0.556 | 0.825 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2416.576 | 2726.208 | 6835.520 | 10442.784 | 0.886 | 0.655 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1357.440 | 1590.752 | 4433.664 | 5975.296 | 0.853 | 0.742 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2427.360 | 2747.040 | 6853.056 | 10670.784 | 0.884 | 0.642 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2441.120 | 2718.944 | 6836.640 | 10433.792 | 0.898 | 0.655 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3555.392 | 5620.960 | 11944.000 | 16504.801 | 0.633 | 0.724 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2010.848 | 3241.152 | 7636.064 | 9870.464 | 0.620 | 0.774 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3557.440 | 5688.352 | 11935.744 | 17090.496 | 0.625 | 0.698 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3562.720 | 5630.432 | 11939.168 | 16392.033 | 0.633 | 0.728 | </details> ### Perf after this PR **FWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|---------------|------------|----------------|----------------------------| | Average | 0.776 | | | | | | Max | 1.006 | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | | Min | 0.566 | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | **BWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|-------------|------------|----------------|-----------------------------| | Average | 0.817 | | | | | | Max | 1.150 | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | | Min | 0.454 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | <details> <summary> Full performance sweep </summary> | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.680 | 17.056 | 64.544 | 73.376 | 0.919 | 0.880 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.712 | 19.872 | 65.408 | 72.864 | 0.791 | 0.898 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.160 | 17.280 | 64.896 | 73.888 | 0.935 | 0.878 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.192 | 17.120 | 64.896 | 75.424 | 0.946 | 0.860 | | None 
| None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.648 | 22.496 | 89.184 | 82.592 | 0.873 | 1.080 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.320 | 26.816 | 91.264 | 82.880 | 0.758 | 1.101 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.096 | 22.528 | 89.184 | 83.776 | 0.892 | 1.065 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.680 | 22.432 | 89.184 | 120.096 | 0.877 | 0.743 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.384 | 32.512 | 119.232 | 128.960 | 0.996 | 0.925 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.176 | 37.248 | 113.664 | 119.520 | 0.810 | 0.951 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.512 | 32.928 | 119.264 | 131.456 | 0.987 | 0.907 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.448 | 32.704 | 119.200 | 128.352 | 0.992 | 0.929 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.952 | 62.176 | 199.040 | 214.304 | 0.675 | 0.929 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 39.744 | 62.880 | 189.504 | 179.968 | 0.632 | 1.053 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.472 | 62.784 | 199.136 | 217.664 | 0.661 | 0.915 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 42.048 | 61.952 | 199.168 | 214.496 | 0.679 | 0.929 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 341.184 | 357.632 | 980.256 | 1328.896 | 0.954 | 0.738 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 212.576 | 252.960 | 673.888 | 824.864 | 0.840 | 0.817 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.000 | 363.296 | 980.768 | 1375.808 | 0.936 | 0.713 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.768 | 356.832 | 980.960 | 1326.272 | 0.955 | 0.740 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 459.392 | 737.120 | 1678.240 | 2205.248 | 0.623 | 0.761 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 292.672 | 468.096 | 1178.016 | 1371.584 | 0.625 | 0.859 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.144 | 745.312 | 1680.000 | 2252.512 | 0.620 | 0.746 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.112 | 736.576 | 1679.008 | 2216.480 | 0.627 | 0.758 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.064 | 16.704 | 105.120 | 120.768 | 0.962 | 0.870 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.552 | 18.144 | 107.136 | 121.696 | 0.857 | 0.880 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.096 | 16.768 | 102.688 | 120.864 | 0.960 | 0.850 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.032 | 16.576 | 104.736 | 124.672 | 0.967 | 0.840 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.392 | 21.952 | 104.736 | 174.656 | 0.883 | 0.600 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 20.128 | 23.712 | 105.216 | 199.008 | 0.849 | 0.529 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.904 | 21.888 | 103.744 | 179.520 | 0.909 | 0.578 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.968 | 21.952 | 104.640 | 177.312 | 0.910 | 0.590 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.096 | 31.904 | 118.720 | 231.968 | 1.006 | 0.512 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 
1024, 64) | 30.528 | 33.952 | 112.480 | 218.304 | 0.899 | 0.515 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.160 | 32.224 | 118.752 | 237.312 | 0.998 | 0.500 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.128 | 32.032 | 118.240 | 233.120 | 1.003 | 0.507 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.280 | 177.408 | 350.688 | 0.674 | 0.506 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 39.552 | 59.360 | 168.832 | 371.488 | 0.666 | 0.454 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.984 | 61.696 | 177.376 | 360.416 | 0.680 | 0.492 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.760 | 177.184 | 355.744 | 0.669 | 0.498 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.744 | 357.888 | 939.712 | 1665.376 | 0.949 | 0.564 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 212.608 | 248.832 | 633.280 | 1122.848 | 0.854 | 0.564 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.712 | 363.232 | 940.448 | 1689.440 | 0.935 | 0.557 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 341.056 | 355.264 | 940.128 | 1641.152 | 0.960 | 0.573 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.736 | 741.024 | 1569.824 | 2559.552 | 0.622 | 0.613 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 293.856 | 464.192 | 1066.240 | 1840.416 | 0.633 | 0.579 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.704 | 753.152 | 1570.112 | 2641.088 | 0.612 | 0.594 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.832 | 745.536 | 1570.144 | 2602.560 | 0.618 | 0.603 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.680 | 41.280 | 171.840 | 158.176 | 0.864 | 1.086 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 31.360 | 42.976 | 158.912 | 139.264 | 0.730 | 1.141 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.168 | 41.600 | 171.648 | 161.344 | 0.845 | 1.064 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.136 | 41.152 | 171.808 | 158.336 | 0.854 | 1.085 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.832 | 76.384 | 295.680 | 277.696 | 0.639 | 1.065 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.632 | 72.512 | 281.760 | 250.752 | 0.629 | 1.124 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 49.504 | 76.608 | 295.584 | 279.712 | 0.646 | 1.057 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.864 | 75.904 | 295.456 | 277.568 | 0.644 | 1.064 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.392 | 111.232 | 408.640 | 442.656 | 0.894 | 0.923 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 71.392 | 95.168 | 338.784 | 341.760 | 0.750 | 0.991 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.808 | 112.256 | 408.608 | 456.160 | 0.889 | 0.896 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 100.032 | 110.816 | 408.512 | 444.192 | 0.903 | 0.920 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.040 | 226.112 | 726.880 | 774.176 | 0.597 | 0.939 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 99.904 | 169.696 | 616.448 | 607.104 | 0.589 | 1.015 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.488 
| 228.384 | 727.776 | 782.368 | 0.593 | 0.930 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.744 | 225.664 | 728.000 | 773.600 | 0.602 | 0.941 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1324.192 | 1387.808 | 3866.944 | 5217.184 | 0.954 | 0.741 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 738.464 | 832.608 | 2507.392 | 3146.688 | 0.887 | 0.797 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.016 | 1404.256 | 3867.872 | 5382.624 | 0.944 | 0.719 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.144 | 1386.688 | 3867.552 | 5203.264 | 0.956 | 0.743 |
| None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1847.488 | 2866.336 | 6612.704 | 8597.696 | 0.645 | 0.769 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1066.592 | 1660.640 | 4357.696 | 5174.016 | 0.642 | 0.842 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1850.464 | 2905.408 | 6616.928 | 8793.280 | 0.637 | 0.752 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1848.896 | 2834.720 | 6623.872 | 8637.920 | 0.652 | 0.767 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.384 | 38.656 | 150.336 | 182.624 | 0.941 | 0.823 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.360 | 38.112 | 137.664 | 171.840 | 0.823 | 0.801 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.608 | 39.040 | 150.528 | 183.872 | 0.938 | 0.819 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.064 | 38.656 | 150.560 | 183.520 | 0.933 | 0.820 |
| None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.344 | 76.352 | 253.920 | 301.440 | 0.646 | 0.842 |
| None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 46.720 | 65.824 | 239.424 | 296.384 | 0.710 | 0.808 |
| relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.248 | 76.416 | 253.728 | 307.808 | 0.644 | 0.824 |
| head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.376 | 76.288 | 253.728 | 304.736 | 0.647 | 0.833 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.144 | 364.960 | 503.072 | 0.901 | 0.725 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 71.136 | 92.384 | 294.432 | 393.056 | 0.770 | 0.749 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.200 | 111.360 | 365.152 | 512.640 | 0.891 | 0.712 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.240 | 365.088 | 504.224 | 0.900 | 0.724 |
| None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.680 | 230.336 | 613.472 | 816.896 | 0.589 | 0.751 |
| None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 100.256 | 165.088 | 502.144 | 676.480 | 0.607 | 0.742 |
| relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.008 | 232.480 | 613.184 | 836.672 | 0.581 | 0.733 |
| head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.232 | 230.624 | 613.536 | 827.136 | 0.586 | 0.742 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1324.064 | 1378.688 | 3631.808 | 5308.384 | 0.960 | 0.684 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 731.776 | 826.688 | 2263.168 | 3241.344 | 0.885 | 0.698 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1316.128 | 1403.200 | 3625.088 | 5550.688 | 0.938 | 0.653 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1311.904 | 1378.880 | 3616.320 | 5353.696 | 0.951 | 0.675 |
| None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1837.856 | 2887.392 | 6121.632 | 8586.656 | 0.637 | 0.713 |
| None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1066.976 | 1654.368 | 3843.136 | 5291.040 | 0.645 | 0.726 |
| relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1854.208 | 2896.832 | 6130.112 | 8745.984 | 0.640 | 0.701 |
| head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1860.512 | 2889.344 | 6135.648 | 8750.592 | 0.644 | 0.701 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.640 | 71.552 | 315.968 | 296.512 | 0.847 | 1.066 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 50.784 | 71.040 | 284.288 | 258.880 | 0.715 | 1.098 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 61.312 | 72.704 | 315.680 | 302.016 | 0.843 | 1.045 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.800 | 71.776 | 316.320 | 297.152 | 0.847 | 1.065 |
| None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.576 | 144.416 | 580.576 | 535.936 | 0.586 | 1.083 |
| None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.064 | 123.648 | 553.344 | 481.376 | 0.615 | 1.150 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.160 | 145.248 | 581.024 | 540.000 | 0.579 | 1.076 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.512 | 143.552 | 581.088 | 535.776 | 0.589 | 1.085 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.152 | 209.408 | 798.400 | 868.704 | 0.903 | 0.919 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 127.552 | 168.800 | 650.816 | 663.328 | 0.756 | 0.981 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.376 | 211.360 | 798.080 | 895.552 | 0.896 | 0.891 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.440 | 208.576 | 797.888 | 873.152 | 0.908 | 0.914 |
| None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 257.536 | 441.760 | 1408.960 | 1514.720 | 0.583 | 0.930 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 179.328 | 312.096 | 1170.368 | 1177.472 | 0.575 | 0.994 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 259.264 | 446.944 | 1408.768 | 1530.400 | 0.580 | 0.921 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 258.080 | 440.480 | 1408.864 | 1514.144 | 0.586 | 0.930 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.808 | 2771.456 | 7616.704 | 10405.248 | 0.937 | 0.732 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1435.744 | 1610.336 | 4927.520 | 6220.000 | 0.892 | 0.792 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.264 | 2745.056 | 7611.232 | 10631.392 | 0.945 | 0.716 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2576.256 | 2735.456 | 7626.400 | 10346.976 | 0.942 | 0.737 |
| None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.744 | 5634.816 | 13077.056 | 17182.528 | 0.653 | 0.761 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2099.360 | 3250.176 | 8589.664 | 10236.672 | 0.646 | 0.839 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3676.800 | 5716.288 | 13073.088 | 17311.071 | 0.643 | 0.755 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.136 | 5570.496 | 13070.720 | 17192.863 | 0.660 | 0.760 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.600 | 71.008 | 272.320 | 300.000 | 0.868 | 0.908 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 50.176 | 65.344 | 241.568 | 258.912 | 0.768 | 0.933 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.120 | 72.512 | 272.672 | 305.408 | 0.843 | 0.893 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.248 | 71.136 | 272.640 | 301.120 | 0.861 | 0.905 |
| None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.872 | 146.784 | 466.912 | 496.832 | 0.571 | 0.940 |
| None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.072 | 435.584 | 462.112 | 0.667 | 0.943 |
| relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.392 | 147.392 | 466.656 | 504.448 | 0.566 | 0.925 |
| head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.360 | 146.688 | 466.656 | 499.040 | 0.568 | 0.935 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.024 | 207.584 | 684.768 | 873.568 | 0.911 | 0.784 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 126.944 | 164.288 | 536.192 | 645.984 | 0.773 | 0.830 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 188.768 | 209.760 | 684.096 | 897.504 | 0.900 | 0.762 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.408 | 207.776 | 685.024 | 876.384 | 0.912 | 0.782 |
| None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 259.168 | 449.536 | 1167.936 | 1433.280 | 0.577 | 0.815 |
| None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 180.000 | 305.312 | 928.000 | 1113.920 | 0.590 | 0.833 |
| relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 258.464 | 455.136 | 1167.808 | 1462.848 | 0.568 | 0.798 |
| head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 257.824 | 450.208 | 1167.744 | 1448.000 | 0.573 | 0.806 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2598.368 | 2729.120 | 7134.400 | 10381.632 | 0.952 | 0.687 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1435.456 | 1591.040 | 4424.768 | 6035.808 | 0.902 | 0.733 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2594.752 | 2725.952 | 7128.384 | 10822.496 | 0.952 | 0.659 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2597.888 | 2716.960 | 7101.568 | 10385.440 | 0.956 | 0.684 |
| None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3647.648 | 5581.632 | 12089.952 | 16667.233 | 0.654 | 0.725 |
| None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2093.952 | 3241.440 | 7579.392 | 9847.936 | 0.646 | 0.770 |
| relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3650.528 | 5650.688 | 12105.568 | 16963.680 | 0.646 | 0.714 |
| head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3680.064 | 5585.312 | 12117.504 | 16935.040 | 0.659 | 0.716 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505 Approved by: https://github.com/Chillee |
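As a reading aid for the benchmark rows above, here is a minimal sketch (not the benchmark script itself) of how one row's configuration, score_mod = relative_bias with a causal mask at shape (B, Hq, M, Hkv, N, D) = (8, 16, 1024, 16, 1024, 64), maps onto a `flex_attention` call. The exact `relative_bias` and `causal` definitions here are assumptions chosen for illustration; timing/measurement details are omitted.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Shapes follow the table's (B, Hq, M, Hkv, N, D) convention.
B, Hq, M, Hkv, N, D = 8, 16, 1024, 16, 1024, 64
q = torch.randn(B, Hq, M, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Hkv, N, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, Hkv, N, D, device="cuda", dtype=torch.bfloat16)

def relative_bias(score, b, h, q_idx, kv_idx):
    # Assumed bias: shift each score by the query/key distance.
    return score + (q_idx - kv_idx)

def causal(b, h, q_idx, kv_idx):
    # Standard causal mask: a query may only attend to earlier or equal positions.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=M, KV_LEN=N)
out = torch.compile(flex_attention)(q, k, v, score_mod=relative_bias, block_mask=block_mask)
```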
||
|
|
348d02a983 |
Changed masked out rows logsumexp to be -inf and not zero (#134650)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134650 Approved by: https://github.com/yanboliang, https://github.com/BoyuanFeng, https://github.com/drisspg |
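A short illustration of why `-inf` (rather than zero) is the consistent logsumexp for a fully masked-out row; this is not the PR's code, just the underlying arithmetic: with every score masked to `-inf`, log(sum(exp(score))) = log(0) = -inf, whereas reporting 0 would imply a total unnormalized weight of exp(0) = 1.

```python
import torch

# A row whose every attention score has been masked out.
scores = torch.full((4,), float("-inf"))
print(torch.logsumexp(scores, dim=-1))  # tensor(-inf), the mathematically consistent value
```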
||
|
|
d966d91e37 |
[FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538 Approved by: https://github.com/yanboliang ghstack dependencies: #134507, #134511 |
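The fix above concerns how many sparse blocks are needed when the sequence length is not an exact multiple of the block size. A minimal, hypothetical illustration of the floor-div vs. ceil-div difference (names here are illustrative, not the kernel's actual variables):

```python
def ceildiv(a: int, b: int) -> int:
    # Ceiling division via negated floor division.
    return -(-a // b)

seq_len, block_size = 1000, 128
print(seq_len // block_size)         # 7  -> the ragged tail of 104 positions is dropped
print(ceildiv(seq_len, block_size))  # 8  -> every position is covered by a block
```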
||
|
|
3e10a1eb5a |
Revert "[FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)"
This reverts commit
|
||
|
|
a34320a6f2 |
[FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538 Approved by: https://github.com/yanboliang ghstack dependencies: #134495, #134507, #134511 |
||
|
|
bf5addb613 |
[FlexAttention] Enable different qk and v head-dims (#134043)
# Summary
Adds the option for the head dims to differ between the QK and V tensors.
Fixes issue: https://github.com/pytorch/pytorch/issues/133674
V_DIM > QK_DIM is blocked on landing https://github.com/triton-lang/triton/pull/4138 / https://github.com/triton-lang/triton/pull/4540 into PyTorch's Triton branch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134043 Approved by: https://github.com/Chillee |
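A minimal sketch of the capability this entry describes, with made-up sizes: query and key share one head dim while value (and therefore the output) uses another; per the note above, V_DIM larger than QK_DIM was still gated on Triton changes at the time, so the sketch keeps V_DIM smaller.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S = 2, 8, 256
qk_dim, v_dim = 64, 32  # different head dims for QK and V (illustrative sizes)
q = torch.randn(B, H, S, qk_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, qk_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, v_dim, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(q, k, v)
print(out.shape)  # torch.Size([2, 8, 256, 32]) -- output head dim follows V
```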
||
|
|
629bd6f718 |
Update FlexAttention with masking semantic (#133373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133373 Approved by: https://github.com/yanboliang |