Mirror of https://github.com/zebrajr/pytorch.git
Synced 2025-12-07 00:21:07 +01:00
Branch: master, 282 commits

695cb0d342
[2/N][Fix] Fix typo in test folder (#166374)
Fix typo in test folder. _typos.toml:
```toml
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166374 Approved by: https://github.com/cyyever, https://github.com/ezyang

5016e7b2eb
[FlexAttention] Add mechanism to get optimal autotune decision (#165817)
Script: https://github.com/meta-pytorch/attention-gym/pull/169 Feels directionally okay, but there is some bikeshedding: with a simple cache key this could be quite prone to key collisions depending on mask_mod and score_mod changes. Use case: https://github.com/meta-pytorch/attention-gym/pull/169 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165817 Approved by: https://github.com/Chillee

904abfc2ca
Export flex attention with kwargs and DTensor (#166045)
Fixes #165948 Adding registration of the MaskBlock makes flex attention with kwargs exportable. Also modified unit tests to accept kwargs:
```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export
python test/inductor/test_flex_attention.py -k test_pytree_
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166045 Approved by: https://github.com/drisspg, https://github.com/SherlockNoMad Co-authored-by: fduwjj <fduwjj@gmail.com>
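A minimal sketch, not taken from the PR, of the export path this change enables: a module whose forward calls flex_attention and receives its BlockMask through kwargs, exported with torch.export. The module name, shapes, causal mask, and CUDA device are illustrative assumptions.

```python
# Assumed shapes/mask; illustrates exporting flex_attention with a BlockMask kwarg.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

class Attn(torch.nn.Module):
    def forward(self, q, k, v, *, block_mask=None):
        return flex_attention(q, k, v, block_mask=block_mask)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 4, 256, 64
q = k = v = torch.randn(B, H, S, D, device="cuda")
mask = create_block_mask(causal, B=B, H=H, Q_LEN=S, KV_LEN=S, device="cuda")

# Export with the BlockMask supplied via kwargs, as exercised by the tests above.
ep = torch.export.export(Attn(), (q, k, v), kwargs={"block_mask": mask})
print(ep)
```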

e4c01011c2
Mark FlexAttentionBackward as cacheable (#165996)
This probably should have been marked cacheable a long time ago; there is no reason that it isn't. Test Plan: New regional inductor tests for test_flex_attention are now serializable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165996 Approved by: https://github.com/oulgen, https://github.com/zou3519, https://github.com/drisspg

516e58965a
Revert "Export flex attention with kwargs and DTensor (#166045)"
This reverts commit

de7fdfe41a
Export flex attention with kwargs and DTensor (#166045)
Fixes #165948 Adding registration of the MaskBlock makes flex attention with kwargs exportable. Also modified unit tests to accept kwargs:
```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export
python test/inductor/test_flex_attention.py -k test_pytree_
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166045 Approved by: https://github.com/drisspg

767199fd9b
[flex_attention] replace sliced BlockMask noop with helpful error (#164702)
Fixes part of #163314 After slicing BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silent incorrect results when users applied transformations to `sliced_mask.mask_mod`. Replace the noop with `_sliced_mask_mod_error`, which raises a RuntimeError with guidance to use `base_mask.mask_mod` instead. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702 Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
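The sketch below is an illustration against the public create_block_mask/BlockMask API, not code from the PR, of the behavior change: calling the mask_mod of a sliced BlockMask now raises instead of silently acting as a no-op. Shapes, the causal mask, and the CPU device are assumptions.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

base_mask = create_block_mask(causal, B=4, H=8, Q_LEN=1024, KV_LEN=1024, device="cpu")
sliced_mask = base_mask[0]  # slice out one batch element

try:
    # Previously this silently used noop_mask; now it raises with guidance.
    sliced_mask.mask_mod(
        torch.tensor(0), torch.tensor(0), torch.tensor(0), torch.tensor(0)
    )
except RuntimeError as err:
    print(err)

mask_mod = base_mask.mask_mod  # the supported way to keep transforming the mask_mod
```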

6b80c94901
[FlexAttention] Fix dynamic shaped heads flex_flash check (#165866)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866 Approved by: https://github.com/BoyuanFeng ghstack dependencies: #165729

e925dfcc6b
Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhancing code readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645 Approved by: https://github.com/ezyang, https://github.com/mlazos

ac529df244
Native matmul (#157743)
### Implementation of #151705
This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly, without relying on predefined templates. To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:
1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)
### Summary of This PR
This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.
### Notable Changes
1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tiling suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I've copied the configuration from the existing matmul templates. However, this may not be optimal; it currently takes a long time to tune, and I think there must be a better way to tackle this. @eellison @jansel @PaulZhang12 @shunting314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743 Approved by: https://github.com/jansel
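A minimal sketch of how the new lowering would be exercised. The config knob name comes from the PR text above; the CUDA device, shapes, and dtype are assumptions, and the flag is only available on builds that include this change.

```python
# Opt in to the native tl.dot lowering described in the PR, then compile a matmul.
import torch
import torch._inductor.config as inductor_config

inductor_config.triton.enable_native_matmul = True

@torch.compile
def matmul(a, b):
    return a @ b  # aten.mm should now lower to the new ops.dot IR node

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
out = matmul(a, b)
```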

a7fa1a91e3
fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317 Approved by: https://github.com/drisspg ghstack dependencies: #163986

20082d7136
Revert "fix flex attention eager bwd: more rounding (#164317)"
This reverts commit

41808b2ba9
fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317 Approved by: https://github.com/drisspg ghstack dependencies: #163986

5d7360bb03
Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit

321e602692
Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhancing code readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645 Approved by: https://github.com/ezyang

91c4db76cb
fix flex attention eager: dont round down scores to low-precision (closes #163588) (#163986)
Fixes: https://github.com/pytorch/pytorch/issues/163588 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163986 Approved by: https://github.com/drisspg, https://github.com/mlazos
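A minimal sketch, with assumed shapes, score_mod, and tolerances rather than the issue's actual repro, of the eager-versus-compiled check this fix is about: after the change, eager flex_attention no longer rounds scores to low precision before applying score_mod, so the two paths should agree within normal bf16 tolerances.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    return score + 0.1 * (q_idx - kv_idx)

q = k = v = torch.randn(2, 4, 256, 64, device="cuda", dtype=torch.bfloat16)

eager_out = flex_attention(q, k, v, score_mod=rel_bias)
compiled_out = torch.compile(flex_attention)(q, k, v, score_mod=rel_bias)

torch.testing.assert_close(eager_out, compiled_out, atol=2e-2, rtol=2e-2)
```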

cfd46d13e6
Fix SAC + Flex issue (#164421)
# Summary This happens when flex_attention is not tagged with the `CheckpointPolicy.MUST_SAVE` policy, which causes the lse to be unrealized. I think in general this is probably not the best policy, but we shouldn't error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164421 Approved by: https://github.com/Skylion007
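For context, a minimal sketch (not from the PR) of the selective-activation-checkpointing setup the summary refers to, where tagging the flex_attention op with MUST_SAVE keeps its logsumexp realized. The policy wiring uses the public torch.utils.checkpoint API; matching on torch.ops.higher_order.flex_attention and the shapes are assumptions.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    if op is torch.ops.higher_order.flex_attention:  # assumed op identity
        return CheckpointPolicy.MUST_SAVE  # keep the attention output and lse
    return CheckpointPolicy.PREFER_RECOMPUTE

def attn_block(q, k, v):
    return flex_attention(q, k, v)

q = k = v = torch.randn(2, 4, 256, 64, device="cuda", requires_grad=True)
out = checkpoint(
    attn_block,
    q, k, v,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.sum().backward()
```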

a8c528c105
[1/N] Apply UP035 rule in tests (#163947)
Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947 Approved by: https://github.com/ezyang

e2ce79e4cc
[Flex] Fix silent correctness w/ backpropping grads (#163677)
Fixes https://github.com/pytorch/pytorch/issues/162228 # Summary The majority of our tests only compile flex-attention in isolation. This means that, for fake tensor propagation, the input primals and all captured buffers don't do any intermediate computation below autograd. As a result, they happen by chance to match the `requires_grad`-ness of the eager implementation and this check will pass. However, if score_mod is the result of some other intermediate fake tensor prop, then it is not guaranteed to have accurate requires_grad-ness, which was happening here. TLDR: this was a belt-and-suspenders check that was actually harmful, and we should just let the joint graph handle creating the correct joint graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163677 Approved by: https://github.com/ydwu4
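The sketch below, an illustration with assumed shapes rather than a repro from the issue, shows the pattern the summary describes: a score_mod that closes over a buffer produced by intermediate computation inside the compiled region, so its requires_grad-ness is only known after fake tensor propagation.

```python
# Assumed shapes; the captured bias comes from intermediate computation,
# not directly from a graph input, which is the situation described above.
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64
proj = torch.nn.Linear(D, H, device="cuda")

@torch.compile
def attn(q, k, v, x):
    bias = proj(x).permute(0, 3, 1, 2)  # [B, H, S, S], requires_grad inherited from proj/x

    def score_mod(score, b, h, q_idx, kv_idx):
        return score + bias[b, h, q_idx, kv_idx]

    return flex_attention(q, k, v, score_mod=score_mod)

q = k = v = torch.randn(B, H, S, D, device="cuda", requires_grad=True)
x = torch.randn(B, S, S, D, device="cuda")
attn(q, k, v, x).sum().backward()
```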

ed84e808f0
[inductor] Freeze layouts in FlexAttention (#163434)
Fixes #163300 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163434 Approved by: https://github.com/drisspg ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419

1a42656d6c
[Flex attention] Fix flex attention head broadcast (#163426)
Fixes part of #163314, in particular **Bug 1: H=None Broadcasting Produces Incorrect Results**.
This fixes a shape bug when slicing BlockMask on the Q-tile axis with an int (**mask[:, :, i]**). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Because they lose that shape, even though the mask_mod remains "interpretable", the kernel's stride math then reads wrong offsets. This gives silent numerical mismatches compared to regular SDPA, especially during single-position decoding with H broadcasting.
The B=None, H=None case works only by accident: with a singleton batch/head the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn't move the pointer and it happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with wrong strides, which causes silent errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163426 Approved by: https://github.com/drisspg

ff6870d134
[BE][flex attention] compute RMSE in float64 (#162088)
I saw a failure where the reference error was 0.0 and the compiled error was 0.035. Although the failure still occurs with or without this change, it was confusing to see an RMSE of 0.0. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162088 Approved by: https://github.com/drisspg
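A minimal sketch (an assumed helper, not the test suite's actual code) of computing RMSE in float64 so that low-precision inputs cannot report a misleading 0.0:

```python
import torch

def rmse_fp64(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
    # Upcast before the reduction so bf16/fp16 rounding cannot hide the error.
    diff = ref.to(torch.float64) - res.to(torch.float64)
    return diff.pow(2).mean().sqrt()

ref = torch.randn(64, 64, dtype=torch.bfloat16)
res = ref + 1e-3 * torch.randn_like(ref)
print(rmse_fp64(ref, res))
```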

864ffe12d7
Fix some edge cases (#162295)
Summary of backward-pass TFLOPs, base vs. no_peel (all rows torch.bfloat16):

🔝 Top 5 performance differences (by absolute %):

| attn_type | shape(B,Hq,M,Hkv,N,D) | TFlops BWD (base) | TFlops BWD (no_peel) | no_peel_speedup_over_base | pct_delta |
|---|---|---|---|---|---|
| sliding_window | (2, 16, 1024, 4, 1024, 64) | 56.937931 | 58.960459 | 1.035522 | 3.552163 |
| noop | (2, 16, 1024, 4, 1024, 128) | 89.221306 | 86.295642 | 0.967209 | -3.27911 |
| causal | (2, 16, 4096, 4, 4096, 128) | 111.552594 | 114.380841 | 1.025353 | 2.535349 |
| alibi | (2, 16, 1024, 16, 1024, 64) | 74.830149 | 76.685445 | 1.024793 | 2.479344 |
| alibi | (2, 16, 1024, 4, 1024, 64) | 55.279932 | 56.369312 | 1.019707 | 1.97066 |

🔺 Top 5 cases where no_peel (change) is faster than base (baseline):

| attn_type | shape(B,Hq,M,Hkv,N,D) | TFlops BWD (base) | TFlops BWD (no_peel) | no_peel_speedup_over_base | pct_delta |
|---|---|---|---|---|---|
| sliding_window | (2, 16, 1024, 4, 1024, 64) | 56.937931 | 58.960459 | 1.035522 | 3.552163 |
| causal | (2, 16, 4096, 4, 4096, 128) | 111.552594 | 114.380841 | 1.025353 | 2.535349 |
| alibi | (2, 16, 1024, 16, 1024, 64) | 74.830149 | 76.685445 | 1.024793 | 2.479344 |
| alibi | (2, 16, 1024, 4, 1024, 64) | 55.279932 | 56.369312 | 1.019707 | 1.97066 |
| causal | (4, 16, 4096, 4, 4096, 64) | 111.08814 | 112.447047 | 1.012233 | 1.22327 |

🔻 Top 5 cases where no_peel (change) is slower than base (baseline):

| attn_type | shape(B,Hq,M,Hkv,N,D) | TFlops BWD (base) | TFlops BWD (no_peel) | no_peel_speedup_over_base | pct_delta |
|---|---|---|---|---|---|
| noop | (2, 16, 1024, 4, 1024, 128) | 89.221306 | 86.295642 | 0.967209 | -3.27911 |
| causal | (4, 16, 1024, 4, 1024, 64) | 78.23082 | 76.693169 | 0.980345 | -1.965531 |
| sliding_window | (2, 16, 2048, 4, 2048, 128) | 96.95663 | 95.573333 | 0.985733 | -1.426717 |
| alibi | (4, 16, 2048, 4, 2048, 64) | 93.373473 | 92.294147 | 0.988441 | -1.155924 |
| alibi | (2, 16, 2048, 4, 2048, 128) | 96.95147 | 96.105389 | 0.991273 | -0.872685 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162295 Approved by: https://github.com/mlazos, https://github.com/v0i0

833997a6fd
[Inductor][UT] Fix flex attention related inductor cases (#162450)
## Motivation
Fixes #162435, Fixes #162436
UT failures:
* https://github.com/pytorch/pytorch/actions/runs/17523991468/job/49772651636
* https://github.com/pytorch/pytorch/actions/runs/17523991468/job/49772651637
To fix flex attention related cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162450 Approved by: https://github.com/drisspg

ac9ccd0dc2
Add return-max-scores to flex-attention (#161667)
# Summary
### Update
API
```Py
class AuxRequest(NamedTuple):
"""Request which auxiliary outputs to compute from flex_attention.
Each field is a boolean indicating whether that auxiliary output should be computed.
"""
lse: bool = False
max_scores: bool = False
class AuxOutput(NamedTuple):
"""Auxiliary outputs from flex_attention operation.
Fields will be None if not requested, or contain the tensor if requested.
"""
lse: Optional[Tensor] = None
max_scores: Optional[Tensor] = None
out_only = flex_attention(query, key, value, score_mod)
out_max, aux_max = flex_attention(
query,
key,
value,
score_mod,
return_aux=AuxRequest(max_scores=True),
)
out_both, aux_both = flex_attention(
query,
key,
value,
score_mod,
return_aux=AuxRequest(lse=True, max_scores=True),
)
```
Returns the max post mod scores from flex attention.
Not being able to break BC is kind of annoying here, since we end up with a combinatorial problem: if we need to add any more return values, we need new kwargs that gate whether they get returned by the function, and we need to support the 2**N possible return groups for the additional args.
Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. I added a keyword-only kwarg for now.
Maybe we make an `ExtraReturns`-type kwarg that can grow, so we don't need to keep adding new top-level args.
We could also return a struct that holds all the extra tensors and start a deprecation cycle for logsumexp, eventually returning just one `ExtraReturns`-like struct with the tensors.
### Req Grad
I currently don't return a max_scores that supports backpropagating grads. I think this might be feasible, but since max is essentially one-hot on the inputs plus a reduction, we would either need to save another `max_location` from the forward, or find the max_score but only apply the gradient to the first occurrence if there are multiple equivalent scores (need to check whether that is what we define for the vanilla max op in torch).
For now no grad, we can re-visit if needed.
## Perf
I am going to disable this for flex_decode, since at least initially the motivation is training. It is also harder than it should be to have ops return None or optional tensors; if return max is set to false, we should probably just create a tensor of size zero so that we don't slow down the hot path.
```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ causal ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 161.031318 ┆ 158.597808 ┆ 2.43351 ┆ 1.534391 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘
🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, ┆ 175.546923 ┆ 177.81205 ┆ -2.265127 ┆ -1.273888 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4, ┆ 156.282597 ┆ 158.209134 ┆ -1.926537 ┆ -1.217715 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16, ┆ 232.542929 ┆ 235.140136 ┆ -2.597207 ┆ -1.104536 │
│ ┆ ┆ 2048, 128) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 169.652791 ┆ 171.475986 ┆ -1.823195 ┆ -1.063236 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng

104f2680e0
Revert "Add return-max-scores to flex-attention (#161667)"
This reverts commit

a3e5466002
Revert "Resize to 0 if not going to be used (#161730)"
This reverts commit

081cab0454
Resize to 0 if not going to be used (#161730)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #161730
* #161667
```Py
with torch.cuda._DeviceGuard(0):
    torch.cuda.set_device(0)
    buf0 = empty_strided_cuda((2, 32, 1024), (32768, 1024, 1), torch.float32)
    buf1 = empty_strided_cuda((2, 32, 1024), (32768, 1024, 1), torch.float32)
    buf2 = empty_strided_cuda((2, 32, 1024, 64), (
```

486b20b73c
Add return-max-scores to flex-attention (#161667)
# Summary
### Update
API
```Py
class AuxRequest(NamedTuple):
"""Request which auxiliary outputs to compute from flex_attention.
Each field is a boolean indicating whether that auxiliary output should be computed.
"""
lse: bool = False
max_scores: bool = False
class AuxOutput(NamedTuple):
"""Auxiliary outputs from flex_attention operation.
Fields will be None if not requested, or contain the tensor if requested.
"""
lse: Optional[Tensor] = None
max_scores: Optional[Tensor] = None
out_only = flex_attention(query, key, value, score_mod)
out_max, aux_max = flex_attention(
query,
key,
value,
score_mod,
return_aux=AuxRequest(max_scores=True),
)
out_both, aux_both = flex_attention(
query,
key,
value,
score_mod,
return_aux=AuxRequest(lse=True, max_scores=True),
)
```
Returns the max post mod scores from flex attention.
Not being able to break BC is kind of annoying here, since we end up with a combinatorial problem: if we need to add any more return values, we need new kwargs that gate whether they get returned by the function, and we need to support the 2**N possible return groups for the additional args.
Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. I added a keyword-only kwarg for now.
Maybe we make an `ExtraReturns`-type kwarg that can grow, so we don't need to keep adding new top-level args.
We could also return a struct that holds all the extra tensors and start a deprecation cycle for logsumexp, eventually returning just one `ExtraReturns`-like struct with the tensors.
### Req Grad
I currently don't return a max_scores that supports backpropagating grads. I think this might be feasible, but since max is essentially one-hot on the inputs plus a reduction, we would either need to save another `max_location` from the forward, or find the max_score but only apply the gradient to the first occurrence if there are multiple equivalent scores (need to check whether that is what we define for the vanilla max op in torch).
For now no grad, we can re-visit if needed.
## Perf
I am going to disable this for flex_decode, since at least initially the motivation is training. It is also harder than it should be to have ops return None or optional tensors; if return max is set to false, we should probably just create a tensor of size zero so that we don't slow down the hot path.
```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ causal ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 161.031318 ┆ 158.597808 ┆ 2.43351 ┆ 1.534391 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘
🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, ┆ 175.546923 ┆ 177.81205 ┆ -2.265127 ┆ -1.273888 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4, ┆ 156.282597 ┆ 158.209134 ┆ -1.926537 ┆ -1.217715 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16, ┆ 232.542929 ┆ 235.140136 ┆ -2.597207 ┆ -1.104536 │
│ ┆ ┆ 2048, 128) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 169.652791 ┆ 171.475986 ┆ -1.823195 ┆ -1.063236 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng

2fed4fb464
[FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861)
Fixes #159247

## Issue 1: Accuracy Problem with Non-Divisible KV Sequences
### Background
Paged attention in flex decoding produced inaccurate results when the KV sequence length is not divisible by the block size. For example, when `KV_S = 64` and `block_size = 128`, the output didn't match standard attention accuracy.
### Root Cause
The current paged attention does not apply an upper mask mod when converting from the logical to the physical mask mod. Instead, it uses a noop_mask by default, which leaves all values unmasked, leading to an accuracy mismatch. Adding an upper mask mod according to the original actual kv_len (64 in this test case) resolves the issue.
### Solution
* **Applied proper upper bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor with proper shape `[B, KV_S]` to provide the actual batched KV sequence length. The function now correctly applies upper bound checks using the actual KV sequence lengths for each batch.
### Files Modified
* `torch/nn/attention/experimental/_paged_attention.py`: Added a `kv_len` tensor parameter to `get_mask_mod` and applied the upper mask to the new mask mod.
* `test/inductor/test_flex_attention.py`: Fixed all related `kv_len` parameter calls in the tests.
* `test/inductor/test_flex_decoding.py`: Fixed all related `kv_len` parameter calls in the tests.

## Issue 2: Invalid Memory Access (IMA) in Triton Kernels
### Background
The Triton kernel for flex attention was experiencing invalid memory access errors when running with compute sanitizers, particularly with short KV sequences and small batch sizes.
### Root Cause
* The kernel launches CTAs (Cooperative Thread Arrays) proportional to the GPU's multi-processor count (108 via `SPLIT_KV`).
* With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values.
* This caused out-of-bounds memory access violations.
### Solution
Implemented boundary checks with early exit:
1. **Added a `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`: calculate the maximum valid KV index based on the actual `kv_indices` tensor size and pass it to the Triton template.
2. **Added early exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`: boundary checks before accessing `kv_indices` in both normal and full blocks; idle CTAs with invalid `indices_idx` skip computation entirely.
This prevents invalid memory access while reducing wasted computation on idle thread blocks.

## Testing & Validation
### Accuracy Tests
* Added comprehensive test cases covering KV sequences not divisible by block sizes.
* Verified output matches standard attention for various sequence length combinations.
### Sanitizer Results
`========= COMPUTE-SANITIZER Starting standalone test_max_autotune... Running test_max_autotune on device: cuda max_autotune config: True test_max_autotune completed successfully! Test passed! ========= ERROR SUMMARY: 0 errors`
**Before**: More than 13720 invalid memory access errors with sanitizers. **After**: Clean execution with 0 errors.
Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861 Approved by: https://github.com/BoyuanFeng

3e459491b5
Enable XPU path for FlexAttention (#143553)
[#RFC153024](https://github.com/pytorch/pytorch/issues/153024)
**Motivation**
1. Attention has been the critical performance bottleneck in current LLM models, and FlexAttention is a good choice to cover the broad attention variants in the transformers family of models. With FlexAttention, it is easy for us to enable paged attention and fused SDPA in the transformers repo on the XPU device. Besides, it also provides a candidate for processing attention in LLM ecosystem libraries, e.g., vLLM and SGLang, on the XPU device.
2. FlexAttention is a good starting point for maturing the Intel Triton-based GEMM kernels. FlexAttention provides both a flexattention kernel and a flexdecoding kernel to cover compute-bound and memory-bound GEMM computation, and different shapes should also be supported to serve LLM inference, e.g. head_dim = 64, 96, 128, 256.
**What does this PR do?**
1. Enable the device type for the FlexAttention kernel and UTs to ensure all important UTs pass on the XPU device.
2. For E2E model inference, ensure the functionality of LLM model inference with FlexAttention is ready.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143553 Approved by: https://github.com/EikanWang, https://github.com/drisspg Co-authored-by: Mao Yunfei <yunfei.mao@intel.com> Co-authored-by: Xingyuan Li <xingyuan.li@intel.com> Co-authored-by: majing <jing1.ma@intel.com> Co-authored-by: Xiao, Wang <wang.xiao@intel.com>
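A minimal sketch (assuming a PyTorch build with XPU support; the shapes, dtype, and score_mod are illustrative) of the path this PR enables: compiling flex_attention with tensors on the "xpu" device.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    # A simple relative-position bias; any score_mod would exercise the same path.
    return score - 0.05 * (q_idx - kv_idx).abs()

q = k = v = torch.randn(2, 8, 1024, 64, device="xpu", dtype=torch.float16)
out = torch.compile(flex_attention)(q, k, v, score_mod=rel_bias)
print(out.shape)
```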

ef0483d74c
Revert "Ensure large tensor int32 -> int64 indexing is enabled (#157767)"
This reverts commit

5432966253
Revert "Remove test since it ooms on CI (#161644)"
This reverts commit

443452ca2f
Remove test since it ooms on CI (#161644)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161644 Approved by: https://github.com/BoyuanFeng

b36a20d368
Ensure large tensor int32 -> int64 indexing is enabled (#157767)
Fixes: #https://github.com/pytorch/pytorch/issues/157446 I think that this delta is worth the switch form block-ptrs especially since they are deprecated ## Perf Summary A is nightly B is this diff, so `negative` means this diff improves perf TOP 5 differences <img width="805" height="754" alt="Screenshot 2025-08-24 at 5 49 49 PM" src="https://github.com/user-attachments/assets/aa359cdf-ee9a-427d-be72-1b9aef6f3115" /> <details> <summary><strong>Full perf table (click to expand)</strong></summary> | attn_type | dtype | shape(B,Hq,M,Hkv,N,D) | TFlops Version A | TFlops Version B | | --- | --- | --- | --- | --- | | noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 258.38834144791923 | 258.6353685004612 | | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.2192450677751 | 140.12393320464972 | | alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 122.32683823617003 | 118.51603755647925 | | sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.48556906165314 | 137.24259849208627 | | document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 86.59814488695922 | 84.59431398586257 | | noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 288.52679758135764 | 292.9174195871856 | | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 172.25541683643277 | 172.94326459828508 | | alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 164.40864610599826 | 165.035129576335 | | sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 176.54876886433945 | 175.08057670028145 | | document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 125.22491679812626 | 121.06201152859151 | | noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 339.11952481874283 | 339.0132835601695 | | causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 227.58583240284406 | 228.21824999409597 | | alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 185.98569659868966 | 182.32850843255093 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 188.9495725191772 | 180.31385312481657 | | document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 106.25789530994302 | 106.55084959448476 | | noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 357.6430536888533 | 363.30843452247274 | | causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 262.3241154406613 | 265.73250045488 | | alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 249.30498953911416 | 249.35928192833785 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 224.74126243851808 | 223.71776504077988 | | document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 168.26977014013707 | 165.47991483333809 | | noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 382.8178701785897 | 384.34752965862685 | | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 308.1449710013853 | 311.0653716044644 | | alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 251.96365252505072 | 243.92283557225903 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 226.69316232745368 | 215.22769268913356 | | document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 153.34142545296405 | 151.9312673939401 | | noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 396.0998000753126 | 398.35036286102473 | | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 333.5198415274966 | 344.6354466169716 | | alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 310.5955933379696 | 305.66347819546 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 260.4012412689896 | 
259.758666997307 | | document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 234.13034252182635 | 227.61676497283614 | | noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 396.17615538477196 | 401.1419104525502 | | causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 359.98648311998414 | 360.8285563463094 | | alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 291.97720707257736 | 281.41694809965253 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 250.1703628419691 | 238.556760291579 | | document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 199.50782826294306 | 191.52327358439223 | | noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 411.0632004785396 | 413.6362648405517 | | causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 382.9404387613185 | 397.74886235657607 | | alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 357.0998545146633 | 350.5115200772392 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 281.8033924428203 | 281.98601309215843 | | document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 282.56595134222135 | 277.4565795466672 | | noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 408.89838018149516 | 405.14531386840076 | | causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 396.07662058160264 | 393.4598228299578 | | alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 317.8822887267849 | 304.754931401036 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 265.8801304948243 | 254.22961974295112 | | document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 227.87390579965614 | 222.19481980110393 | | noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 427.36821778477025 | 431.3766620314935 | | causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 410.67994346825 | 423.4666944003808 | | alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 381.1968748374038 | 381.77668006420424 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 292.5540046358546 | 296.5439130720502 | | document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 321.04573768858114 | 310.7423616656888 | | noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 427.46148866769903 | 426.162091037068 | | causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 419.75580537687347 | 421.88640120274334 | | alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 337.3208051798903 | 327.4912454675092 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 276.5638854539581 | 262.988360558083 | | document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 250.82791326036886 | 245.07367032501736 | | noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 435.8055824506086 | 441.8803729460534 | | causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 432.02638235921006 | 450.33161016596273 | | alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 402.25525939224883 | 393.8564689669916 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 297.5337286675904 | 297.0131881135074 | | document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 343.8697037899545 | 329.8194073407783 | | noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 267.58912366821056 | 256.91606054118375 | | causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 150.81723692609629 | 146.32172267858743 | | alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 129.51029293209245 | 122.72144394093334 | | sliding_window | torch.bfloat16 | (2, 16, 
1024, 4, 1024, 64) | 147.627656359087 | 141.68956350566188 | | document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 87.55100546003591 | 84.91293287692788 | | noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 299.5931492743986 | 305.884253766691 | | causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 179.39026367843837 | 181.64741311605096 | | alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 173.93547669282367 | 173.23972950980564 | | sliding_window | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 185.90234171599252 | 182.80844545446686 | | document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 128.08176696266082 | 123.27722685662111 | | noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 340.50674552770664 | 338.9071088484576 | | causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 225.4438318650432 | 230.22899884832975 | | alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 194.15123248528312 | 185.02793973094865 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 200.74289714108176 | 191.76606719670647 | | document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 107.03564946728423 | 106.82432377861258 | | noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 371.31799283918406 | 379.7555394732925 | | causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 275.97762744310455 | 276.71106853992995 | | alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 261.6648679783462 | 259.4127232060398 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 237.03108223577615 | 233.92710216149527 | | document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 172.13926800371152 | 168.74390922407585 | | noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 381.50199487767276 | 383.9043681999597 | | causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 307.9748883093411 | 312.2403515462001 | | alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 251.11319684705438 | 243.17870127827277 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 236.3253127246763 | 223.81250201769552 | | document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 154.55693991756874 | 153.11360584987685 | | noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 407.11400078586615 | 413.53709886086557 | | causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 348.1705797722622 | 360.09771155957367 | | alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 321.8593280850388 | 318.2882327401255 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 270.089032013835 | 268.767323026064 | | document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 238.07324557907788 | 228.09842078362692 | | noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 399.8172853171901 | 401.0954526332136 | | causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 363.4387330438581 | 364.13111024232677 | | alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 294.1752429133857 | 283.7235663368415 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 256.8389394007649 | 246.91771015606483 | | document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 199.3378564292656 | 192.40439590901758 | | noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 425.5150965556111 | 430.8190098707553 | | causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 396.00437184073013 | 411.3873625655787 | | alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 369.92803661607815 | 361.43244467343663 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) 
| 293.4277354412933 | 295.2529537595746 | | document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 288.0208673072841 | 281.51896404878863 | | noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 408.3005367220567 | 408.96116482298913 | | causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 396.90095962766304 | 396.87385456176486 | | alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 319.0534576137999 | 302.50950358107764 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 270.3334977708081 | 258.8506349486557 | | document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 227.46824134365394 | 222.23759438128766 | | noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 438.24247309479694 | 437.7975163205371 | | causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 428.34012029699227 | 433.3215899950434 | | alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 386.52672049728875 | 388.26216893354984 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 302.71976814728083 | 302.3574867306459 | | document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 327.39760662780986 | 308.6348428844912 | | noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 423.31308678262695 | 426.6306972137279 | | causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 412.6983690923106 | 419.4961977664297 | | alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 337.41003544742273 | 324.2155049126126 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 278.7755890910794 | 265.9194286636502 | | document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 251.55678254755364 | 244.8843180141462 | | noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 452.5930781172308 | 457.7117122300742 | | causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 445.05676260348116 | 463.9304535499636 | | alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 415.78302138389415 | 406.29229555271456 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 308.0311067300895 | 304.91354721414314 | | document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 351.43943626809335 | 329.4476923070317 | | noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 295.1801525813241 | 291.36521287398904 | | causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 183.23250549178067 | 182.35421238887605 | | alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 151.56832453117747 | 151.3422139154794 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 171.02111935180432 | 160.72516856727913 | | document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 74.05765122783826 | 74.5885345035243 | | noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 314.3587394591763 | 319.2938677773619 | | causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 224.57002084153177 | 225.48868542008177 | | alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.00964804143052 | 215.39576159953486 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.1174237618258 | 214.28437413525663 | | document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 121.08920423648368 | 119.55813661872644 | | noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 362.2193857281911 | 360.05005804275936 | | causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 279.8840217430121 | 279.5437918286659 | | alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 227.76617121021982 | 222.8655938229316 | | sliding_window | torch.bfloat16 
| (4, 16, 2048, 16, 2048, 64) | 215.43141176970562 | 207.71852284994702 | | document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 121.35588364218539 | 121.20636565046884 | | noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 365.1545280898012 | 373.37585444987326 | | causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 304.360119952975 | 309.1247297936263 | | alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 287.2603904544586 | 289.25547903162595 | | sliding_window | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 257.9852675272418 | 257.59069234098115 | | document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 188.35158496670232 | 184.24683960154857 | | noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 389.9744911369211 | 388.43466897254166 | | causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 345.9228295166513 | 342.63034895210126 | | alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 279.56334658247437 | 271.2724375402088 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 245.66477202810066 | 233.49688207371258 | | document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 170.3270720653187 | 166.23863845657382 | | noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 400.0041140827554 | 402.11182445396497 | | causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 363.64641830327434 | 375.9288663364792 | | alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 341.5776139573363 | 335.1160003213424 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 281.1811770268521 | 280.21438270014005 | | document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 247.78716118997716 | 245.3269825179633 | | noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 403.794126680488 | 405.2353919019577 | | causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 387.079178426863 | 385.1461762057035 | | alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 309.7847188173431 | 298.0443968374749 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 262.4721750159666 | 250.81679725428586 | | document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 205.70866004479979 | 202.9620839129557 | | noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 413.380982988662 | 418.40270594263103 | | causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 398.450064800682 | 409.6794973994029 | | alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 372.26297458194466 | 364.44415106552196 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 293.0818569905912 | 292.85172400643984 | | document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 296.46717085592087 | 285.76362010612763 | | noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 419.3186786037592 | 426.08801580934437 | | causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 408.1648467766632 | 409.4122254207817 | | alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 329.24396020457345 | 313.5200995121138 | | sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 274.61257504571876 | 255.7801815432177 | | document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 232.63806001220684 | 230.03020843492314 | | noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 435.0785891054788 | 440.39101804225345 | | causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 424.86925312752817 | 435.18898057396825 | | alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 393.000417896268 | 395.11543361225256 | | 
sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 297.7755459218185 | 300.7208114715287 | | document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 331.71570861760534 | 318.07127352552885 | | noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 424.58602747137405 | 425.84897078470715 | | causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 422.66607285025725 | 423.5524945535485 | | alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 344.8625760048626 | 331.6793888458635 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 282.0787281511649 | 263.7895634445868 | | document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 252.7301927385177 | 245.41844170037427 | | noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 437.0658069164588 | 442.9101960063628 | | causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 433.13788271434646 | 452.3873572709863 | | alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 404.0959191546953 | 396.7077863894884 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 300.45502211883206 | 301.3439134717943 | | document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 344.11003202413934 | 330.8897663350314 | | noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 298.4364205341705 | 291.6793556507056 | | causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 187.6382133139633 | 191.05409897308772 | | alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 156.55822078636112 | 154.178925976516 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 173.47765221825162 | 169.30862508068464 | | document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 74.5885345035243 | 74.52689061607104 | | noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 323.12233826013045 | 328.53889207933514 | | causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 236.75872140126316 | 235.8378325547398 | | alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 227.17836523816675 | 226.75357076139966 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 224.07209453308036 | 224.07209453308036 | | document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 122.85572156047981 | 121.11642183704716 | | noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 361.3123326658092 | 360.71014086458337 | | causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 281.5287983927017 | 281.94301754758345 | | alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 232.7456696285686 | 226.50976826432776 | | sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 221.5612361744038 | 214.96188822837055 | | document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 121.38311528944315 | 120.85441868178513 | | noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 380.2579019244734 | 389.2520157863988 | | causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 316.95230660496924 | 317.87597790618906 | | alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 301.07968126657323 | 298.02424098422983 | | sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 267.2240756921594 | 267.16353549228154 | | document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 189.82761622494257 | 186.736450261963 | | noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 389.88665375406805 | 387.9125133037077 | | causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 348.70619958684887 | 346.6750499749774 | | alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 280.5472989906087 | 
271.22300822012187 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 250.02397620165968 | 241.22532776331445 | | document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 171.67817496107645 | 166.95679280483972 | | noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 412.626880230807 | 417.60238657950777 | | causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 374.8829313933945 | 389.4448546468815 | | alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 353.20410434172436 | 345.7072490717473 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 292.51045924209586 | 291.66621022138287 | | document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 251.6264062063495 | 248.45110052911542 | | noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 404.0155784550126 | 401.90546837237514 | | causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 384.4389015599863 | 386.9684324594344 | | alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 313.3731284132225 | 298.17074251037894 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 264.19199737284265 | 252.8982463999916 | | document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 207.03696315185684 | 202.86697323136772 | | noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 428.2436763312506 | 433.45005568619536 | | causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 411.8516531869893 | 428.2753623461049 | | alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 384.9095037182509 | 372.90888743000744 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 303.2438915629836 | 302.05095952914337 | | document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 301.8689122735564 | 285.0363190513223 | | noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 423.13592231504805 | 420.3991500185611 | | causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 407.44527331585493 | 408.5064370765247 | | alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 330.50050996167414 | 316.8763979925965 | | sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 274.6833786307413 | 259.86098862141324 | | document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 232.24019584158367 | 226.52040268160232 | | noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 444.4596314237808 | 455.99558915752266 | | causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 437.4245561244369 | 455.98275147271966 | | alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 397.3350686877605 | 397.88875599028063 | | sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 308.53809114394545 | 307.1359822042007 | | document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 331.32379843423774 | 316.85293191675646 | | noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 422.4622274366379 | 425.0407156418684 | | causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 420.9547052783101 | 430.33779243510276 | | alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 345.50265346504085 | 332.094855328957 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 280.81715528243365 | 264.6543640282054 | | document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 252.25635200421783 | 245.46235499490305 | | noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 452.5524207341139 | 461.7512032176736 | | causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 445.2316469907137 | 464.4523799578466 | | alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 
416.87264016717023 | 409.17124592157046 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 309.42579489389846 | 307.9734464665731 | | document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 350.50782004300623 | 330.98959545427294 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/157767 Approved by: https://github.com/Skylion007 |
||
|
|
262640fd22 |
[ROCm][CI] restore test_flex_attention tests (#161519)
Reverts #161450 and targets specific subtests to skip on MI200. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161519 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com> |
||
|
|
818ba434c7 |
Revert "Ensure large tensor int32 -> int64 indexing is enabled (#157767)"
This reverts commit
|
||
|
|
fc69c2bc67 |
Ensure large tensor int32 -> int64 indexing is enabled (#157767)
Fixes: #https://github.com/pytorch/pytorch/issues/157446 I think that this delta is worth the switch form block-ptrs especially since they are deprecated ## Perf Summary A is nightly B is this diff, so `negative` means this diff improves perf TOP 5 differences <img width="805" height="754" alt="Screenshot 2025-08-24 at 5 49 49 PM" src="https://github.com/user-attachments/assets/aa359cdf-ee9a-427d-be72-1b9aef6f3115" /> <details> <summary><strong>Full perf table (click to expand)</strong></summary> | attn_type | dtype | shape(B,Hq,M,Hkv,N,D) | TFlops Version A | TFlops Version B | | --- | --- | --- | --- | --- | | noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 258.38834144791923 | 258.6353685004612 | | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.2192450677751 | 140.12393320464972 | | alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 122.32683823617003 | 118.51603755647925 | | sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.48556906165314 | 137.24259849208627 | | document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 86.59814488695922 | 84.59431398586257 | | noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 288.52679758135764 | 292.9174195871856 | | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 172.25541683643277 | 172.94326459828508 | | alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 164.40864610599826 | 165.035129576335 | | sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 176.54876886433945 | 175.08057670028145 | | document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 125.22491679812626 | 121.06201152859151 | | noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 339.11952481874283 | 339.0132835601695 | | causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 227.58583240284406 | 228.21824999409597 | | alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 185.98569659868966 | 182.32850843255093 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 188.9495725191772 | 180.31385312481657 | | document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 106.25789530994302 | 106.55084959448476 | | noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 357.6430536888533 | 363.30843452247274 | | causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 262.3241154406613 | 265.73250045488 | | alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 249.30498953911416 | 249.35928192833785 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 224.74126243851808 | 223.71776504077988 | | document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 168.26977014013707 | 165.47991483333809 | | noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 382.8178701785897 | 384.34752965862685 | | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 308.1449710013853 | 311.0653716044644 | | alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 251.96365252505072 | 243.92283557225903 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 226.69316232745368 | 215.22769268913356 | | document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 153.34142545296405 | 151.9312673939401 | | noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 396.0998000753126 | 398.35036286102473 | | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 333.5198415274966 | 344.6354466169716 | | alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 310.5955933379696 | 305.66347819546 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 260.4012412689896 | 
259.758666997307 | | document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 234.13034252182635 | 227.61676497283614 | | noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 396.17615538477196 | 401.1419104525502 | | causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 359.98648311998414 | 360.8285563463094 | | alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 291.97720707257736 | 281.41694809965253 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 250.1703628419691 | 238.556760291579 | | document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 199.50782826294306 | 191.52327358439223 | | noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 411.0632004785396 | 413.6362648405517 | | causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 382.9404387613185 | 397.74886235657607 | | alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 357.0998545146633 | 350.5115200772392 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 281.8033924428203 | 281.98601309215843 | | document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 282.56595134222135 | 277.4565795466672 | | noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 408.89838018149516 | 405.14531386840076 | | causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 396.07662058160264 | 393.4598228299578 | | alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 317.8822887267849 | 304.754931401036 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 265.8801304948243 | 254.22961974295112 | | document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 227.87390579965614 | 222.19481980110393 | | noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 427.36821778477025 | 431.3766620314935 | | causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 410.67994346825 | 423.4666944003808 | | alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 381.1968748374038 | 381.77668006420424 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 292.5540046358546 | 296.5439130720502 | | document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 321.04573768858114 | 310.7423616656888 | | noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 427.46148866769903 | 426.162091037068 | | causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 419.75580537687347 | 421.88640120274334 | | alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 337.3208051798903 | 327.4912454675092 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 276.5638854539581 | 262.988360558083 | | document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 250.82791326036886 | 245.07367032501736 | | noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 435.8055824506086 | 441.8803729460534 | | causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 432.02638235921006 | 450.33161016596273 | | alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 402.25525939224883 | 393.8564689669916 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 297.5337286675904 | 297.0131881135074 | | document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 343.8697037899545 | 329.8194073407783 | | noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 267.58912366821056 | 256.91606054118375 | | causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 150.81723692609629 | 146.32172267858743 | | alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 129.51029293209245 | 122.72144394093334 | | sliding_window | torch.bfloat16 | (2, 16, 
1024, 4, 1024, 64) | 147.627656359087 | 141.68956350566188 | | document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 87.55100546003591 | 84.91293287692788 | | noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 299.5931492743986 | 305.884253766691 | | causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 179.39026367843837 | 181.64741311605096 | | alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 173.93547669282367 | 173.23972950980564 | | sliding_window | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 185.90234171599252 | 182.80844545446686 | | document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 128.08176696266082 | 123.27722685662111 | | noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 340.50674552770664 | 338.9071088484576 | | causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 225.4438318650432 | 230.22899884832975 | | alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 194.15123248528312 | 185.02793973094865 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 200.74289714108176 | 191.76606719670647 | | document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 107.03564946728423 | 106.82432377861258 | | noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 371.31799283918406 | 379.7555394732925 | | causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 275.97762744310455 | 276.71106853992995 | | alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 261.6648679783462 | 259.4127232060398 | | sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 237.03108223577615 | 233.92710216149527 | | document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 172.13926800371152 | 168.74390922407585 | | noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 381.50199487767276 | 383.9043681999597 | | causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 307.9748883093411 | 312.2403515462001 | | alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 251.11319684705438 | 243.17870127827277 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 236.3253127246763 | 223.81250201769552 | | document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 154.55693991756874 | 153.11360584987685 | | noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 407.11400078586615 | 413.53709886086557 | | causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 348.1705797722622 | 360.09771155957367 | | alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 321.8593280850388 | 318.2882327401255 | | sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 270.089032013835 | 268.767323026064 | | document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 238.07324557907788 | 228.09842078362692 | | noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 399.8172853171901 | 401.0954526332136 | | causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 363.4387330438581 | 364.13111024232677 | | alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 294.1752429133857 | 283.7235663368415 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 256.8389394007649 | 246.91771015606483 | | document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 199.3378564292656 | 192.40439590901758 | | noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 425.5150965556111 | 430.8190098707553 | | causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 396.00437184073013 | 411.3873625655787 | | alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 369.92803661607815 | 361.43244467343663 | | sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) 
| 293.4277354412933 | 295.2529537595746 | | document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 288.0208673072841 | 281.51896404878863 | | noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 408.3005367220567 | 408.96116482298913 | | causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 396.90095962766304 | 396.87385456176486 | | alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 319.0534576137999 | 302.50950358107764 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 270.3334977708081 | 258.8506349486557 | | document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 227.46824134365394 | 222.23759438128766 | | noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 438.24247309479694 | 437.7975163205371 | | causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 428.34012029699227 | 433.3215899950434 | | alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 386.52672049728875 | 388.26216893354984 | | sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 302.71976814728083 | 302.3574867306459 | | document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 327.39760662780986 | 308.6348428844912 | | noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 423.31308678262695 | 426.6306972137279 | | causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 412.6983690923106 | 419.4961977664297 | | alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 337.41003544742273 | 324.2155049126126 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 278.7755890910794 | 265.9194286636502 | | document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 251.55678254755364 | 244.8843180141462 | | noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 452.5930781172308 | 457.7117122300742 | | causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 445.05676260348116 | 463.9304535499636 | | alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 415.78302138389415 | 406.29229555271456 | | sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 308.0311067300895 | 304.91354721414314 | | document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 351.43943626809335 | 329.4476923070317 | | noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 295.1801525813241 | 291.36521287398904 | | causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 183.23250549178067 | 182.35421238887605 | | alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 151.56832453117747 | 151.3422139154794 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 171.02111935180432 | 160.72516856727913 | | document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 74.05765122783826 | 74.5885345035243 | | noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 314.3587394591763 | 319.2938677773619 | | causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 224.57002084153177 | 225.48868542008177 | | alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.00964804143052 | 215.39576159953486 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.1174237618258 | 214.28437413525663 | | document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 121.08920423648368 | 119.55813661872644 | | noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 362.2193857281911 | 360.05005804275936 | | causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 279.8840217430121 | 279.5437918286659 | | alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 227.76617121021982 | 222.8655938229316 | | sliding_window | torch.bfloat16 
| (4, 16, 2048, 16, 2048, 64) | 215.43141176970562 | 207.71852284994702 | | document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 121.35588364218539 | 121.20636565046884 | | noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 365.1545280898012 | 373.37585444987326 | | causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 304.360119952975 | 309.1247297936263 | | alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 287.2603904544586 | 289.25547903162595 | | sliding_window | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 257.9852675272418 | 257.59069234098115 | | document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 188.35158496670232 | 184.24683960154857 | | noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 389.9744911369211 | 388.43466897254166 | | causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 345.9228295166513 | 342.63034895210126 | | alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 279.56334658247437 | 271.2724375402088 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 245.66477202810066 | 233.49688207371258 | | document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 170.3270720653187 | 166.23863845657382 | | noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 400.0041140827554 | 402.11182445396497 | | causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 363.64641830327434 | 375.9288663364792 | | alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 341.5776139573363 | 335.1160003213424 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 281.1811770268521 | 280.21438270014005 | | document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 247.78716118997716 | 245.3269825179633 | | noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 403.794126680488 | 405.2353919019577 | | causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 387.079178426863 | 385.1461762057035 | | alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 309.7847188173431 | 298.0443968374749 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 262.4721750159666 | 250.81679725428586 | | document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 205.70866004479979 | 202.9620839129557 | | noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 413.380982988662 | 418.40270594263103 | | causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 398.450064800682 | 409.6794973994029 | | alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 372.26297458194466 | 364.44415106552196 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 293.0818569905912 | 292.85172400643984 | | document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 296.46717085592087 | 285.76362010612763 | | noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 419.3186786037592 | 426.08801580934437 | | causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 408.1648467766632 | 409.4122254207817 | | alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 329.24396020457345 | 313.5200995121138 | | sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 274.61257504571876 | 255.7801815432177 | | document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 232.63806001220684 | 230.03020843492314 | | noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 435.0785891054788 | 440.39101804225345 | | causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 424.86925312752817 | 435.18898057396825 | | alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 393.000417896268 | 395.11543361225256 | | 
sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 297.7755459218185 | 300.7208114715287 | | document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 331.71570861760534 | 318.07127352552885 | | noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 424.58602747137405 | 425.84897078470715 | | causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 422.66607285025725 | 423.5524945535485 | | alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 344.8625760048626 | 331.6793888458635 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 282.0787281511649 | 263.7895634445868 | | document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 252.7301927385177 | 245.41844170037427 | | noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 437.0658069164588 | 442.9101960063628 | | causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 433.13788271434646 | 452.3873572709863 | | alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 404.0959191546953 | 396.7077863894884 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 300.45502211883206 | 301.3439134717943 | | document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 344.11003202413934 | 330.8897663350314 | | noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 298.4364205341705 | 291.6793556507056 | | causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 187.6382133139633 | 191.05409897308772 | | alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 156.55822078636112 | 154.178925976516 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 173.47765221825162 | 169.30862508068464 | | document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 74.5885345035243 | 74.52689061607104 | | noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 323.12233826013045 | 328.53889207933514 | | causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 236.75872140126316 | 235.8378325547398 | | alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 227.17836523816675 | 226.75357076139966 | | sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 224.07209453308036 | 224.07209453308036 | | document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 122.85572156047981 | 121.11642183704716 | | noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 361.3123326658092 | 360.71014086458337 | | causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 281.5287983927017 | 281.94301754758345 | | alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 232.7456696285686 | 226.50976826432776 | | sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 221.5612361744038 | 214.96188822837055 | | document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 121.38311528944315 | 120.85441868178513 | | noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 380.2579019244734 | 389.2520157863988 | | causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 316.95230660496924 | 317.87597790618906 | | alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 301.07968126657323 | 298.02424098422983 | | sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 267.2240756921594 | 267.16353549228154 | | document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 189.82761622494257 | 186.736450261963 | | noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 389.88665375406805 | 387.9125133037077 | | causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 348.70619958684887 | 346.6750499749774 | | alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 280.5472989906087 | 
271.22300822012187 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 250.02397620165968 | 241.22532776331445 | | document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 171.67817496107645 | 166.95679280483972 | | noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 412.626880230807 | 417.60238657950777 | | causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 374.8829313933945 | 389.4448546468815 | | alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 353.20410434172436 | 345.7072490717473 | | sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 292.51045924209586 | 291.66621022138287 | | document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 251.6264062063495 | 248.45110052911542 | | noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 404.0155784550126 | 401.90546837237514 | | causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 384.4389015599863 | 386.9684324594344 | | alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 313.3731284132225 | 298.17074251037894 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 264.19199737284265 | 252.8982463999916 | | document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 207.03696315185684 | 202.86697323136772 | | noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 428.2436763312506 | 433.45005568619536 | | causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 411.8516531869893 | 428.2753623461049 | | alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 384.9095037182509 | 372.90888743000744 | | sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 303.2438915629836 | 302.05095952914337 | | document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 301.8689122735564 | 285.0363190513223 | | noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 423.13592231504805 | 420.3991500185611 | | causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 407.44527331585493 | 408.5064370765247 | | alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 330.50050996167414 | 316.8763979925965 | | sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 274.6833786307413 | 259.86098862141324 | | document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 232.24019584158367 | 226.52040268160232 | | noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 444.4596314237808 | 455.99558915752266 | | causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 437.4245561244369 | 455.98275147271966 | | alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 397.3350686877605 | 397.88875599028063 | | sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 308.53809114394545 | 307.1359822042007 | | document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 331.32379843423774 | 316.85293191675646 | | noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 422.4622274366379 | 425.0407156418684 | | causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 420.9547052783101 | 430.33779243510276 | | alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 345.50265346504085 | 332.094855328957 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 280.81715528243365 | 264.6543640282054 | | document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 252.25635200421783 | 245.46235499490305 | | noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 452.5524207341139 | 461.7512032176736 | | causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 445.2316469907137 | 464.4523799578466 | | alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 
416.87264016717023 | 409.17124592157046 | | sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 309.42579489389846 | 307.9734464665731 | | document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 350.50782004300623 | 330.98959545427294 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/157767 Approved by: https://github.com/Skylion007 |
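As a rough illustration of why the change above matters (not part of the PR, and assuming the kernel forms flattened offsets on the order of the B*H*M*N score space), the arithmetic below shows that the larger benchmarked shapes overflow int32 offsets even though each individual Q/K/V tensor still fits:

```python
# Illustrative arithmetic only; the "score space" offset model is an assumption,
# not the kernel's actual index computation.
INT32_MAX = 2**31 - 1

B, H, M, N, D = 4, 16, 32768, 32768, 128   # one of the benchmarked shapes above
qkv_elems = B * H * M * D                  # 268_435_456  -> still fits in int32
score_elems = B * H * M * N                # 68_719_476_736 -> far beyond int32

print(qkv_elems <= INT32_MAX)    # True
print(score_elems <= INT32_MAX)  # False: offsets of this magnitude need int64
```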
||
|
|
3a4140bf8e |
[FlexAttention] fixing learnable bias assertion error in inductor (#161170)
Users encountered unexpected behaviour when using FlexAttention with learnable biases, including assertion errors (#157677). We traced the root cause to the registration of subgraph buffers: it caused inconsistencies in buffer naming and ultimately incorrect retrieval later on. The problem only arose when the model was compiled as a whole (i.e., with @torch.compile), since only then were there naming conflicts. In this PR, we register the buffers with the base graph to resolve the issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170 Approved by: https://github.com/drisspg |
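For context, a minimal sketch of the failing configuration described above (hypothetical module and shapes, not the reporter's code; a CUDA device is assumed): the learnable bias is captured inside `score_mod`, and the whole module, not just `flex_attention`, is wrapped in `torch.compile`.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

class BiasedAttention(torch.nn.Module):
    """Hypothetical module: a learnable additive bias read inside score_mod."""
    def __init__(self, heads: int, seq: int):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(heads, seq, seq))

    def forward(self, q, k, v):
        def score_mod(score, b, h, q_idx, kv_idx):
            # The parameter is captured here; compiling the whole module is
            # what used to trip the subgraph-buffer naming conflict.
            return score + self.bias[h, q_idx, kv_idx]
        return flex_attention(q, k, v, score_mod=score_mod)

model = BiasedAttention(heads=2, seq=128).cuda()
model = torch.compile(model)  # whole-model compile, the case that used to assert
q = k = v = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.bfloat16)
out = model(q, k, v)
```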
||
|
|
a6bc296207 |
[FlexAttention] Update the guard semantics for divisibility (#159884)
We don't add a divisibility guard unless we already know (because another guard has ensured it) that the optimization is safe. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159884 Approved by: https://github.com/Chillee |
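A hypothetical illustration of the divisibility fact in question (assumed block size and helper name, not the inductor lowering code): the specialization is only safe when the sequence length is an exact multiple of the tile size, which is what the guard must establish.

```python
# Hypothetical sketch, not FlexAttention's lowering code.
BLOCK_M = 128  # assumed tile size

def can_skip_boundary_mask(seq_len: int) -> bool:
    # Safe to drop the out-of-bounds mask only when tiles divide seq_len exactly.
    return seq_len % BLOCK_M == 0

print(can_skip_boundary_mask(4096))  # True: divisible, specialization is safe
print(can_skip_boundary_mask(4100))  # False: keep the mask (or guard and recompile)
```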
||
|
|
3e5e094615 |
Revert "Fix large_tensor_test skipping cpu (#158617)"
This reverts commit |
||
|
|
debc0591b8 |
Fix large_tensor_test skipping cpu (#158617)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158617 Approved by: https://github.com/BoyuanFeng |
||
|
|
21a95bdf7c |
[Inductor] [Triton] Enabling TMA for flex-attention for supported device types (#157822)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157822 Approved by: https://github.com/drisspg ghstack dependencies: #159123 |
||
|
|
a00cd8cf25 |
Add a way to disable compile for debugging flex-attention (#158534)
Finally got around to doing this; the flag lets us do:
```Python
#!/usr/bin/env python3
"""
FlexAttention Debug: Using breakpoints and unwrap
"""
import torch
import torch.nn.attention.flex_attention as fa
unwrap = torch._C._functorch.get_unwrapped
def score_mod(score, batch, head, q_idx, kv_idx):
    # Set breakpoint here to debug
    breakpoint()
    # In debugger, unwrap to see actual tensor values:
    # >>> actual_score = unwrap(unwrap(unwrap(unwrap(score))))
    # >>> actual_batch = unwrap(batch)
    # >>> actual_head = unwrap(head)
    # >>> actual_q_idx = unwrap(q_idx)
    # >>> actual_kv_idx = unwrap(kv_idx)
    # >>> print(actual_score)
    # >>> print(f"q_idx: {actual_q_idx}, kv_idx: {actual_kv_idx}")
    return torch.where(q_idx >= kv_idx, score, torch.tensor(float('-inf')))


def main():
    # Enable debug mode
    fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True

    # Small example
    B, H, S, D = 1, 2, 4, 8
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)

    # Run - will hit breakpoint
    output = fa.flex_attention(q, k, v, score_mod=score_mod)

    # Disable debug mode
    fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False


if __name__ == "__main__":
    main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158534
Approved by: https://github.com/Chillee, https://github.com/zou3519
|
||
|
|
8c928372b3 |
Make Q Indices optional (#157997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157997 Approved by: https://github.com/BoyuanFeng, https://github.com/Chillee |
||
|
|
987314aa96 |
Split batch-num-heads grid dim between y and z (#157745)
For #157018. This doesn't totally fix the problem, but it should help a lot. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157745 Approved by: https://github.com/Chillee |
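A rough sketch of the idea (an assumed helper, not the kernel's actual launch code): CUDA caps gridDim.y and gridDim.z at 65535 each, so a combined batch*num_heads count that no longer fits in y alone gets factored across y and z.

```python
# Assumed helper for illustration; the real split lives in the flex-attention lowering.
GRID_DIM_LIMIT = 65535  # CUDA cap on gridDim.y and gridDim.z

def split_batch_heads(batch_times_heads: int) -> tuple[int, int]:
    z = (batch_times_heads + GRID_DIM_LIMIT - 1) // GRID_DIM_LIMIT
    y = (batch_times_heads + z - 1) // z
    return y, z  # y * z covers every (batch, head) pair with y <= the limit

print(split_batch_heads(4 * 16))    # (64, 1): small case still fits in y alone
print(split_batch_heads(100_000))   # (50000, 2): large case spills into z
```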
||
|
|
17687eb792 |
[BE][4/6] fix typos in test/ (test/inductor/) (#157638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157638 Approved by: https://github.com/yewentao256, https://github.com/jansel |
||
|
|
f5e6e52f25 |
[BE][PYFMT] migrate PYFMT for test/inductor/ to ruff format (#148186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148186 Approved by: https://github.com/jansel |
||
|
|
ccc6279b40 |
flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo (#151719)
This is enough to get @XilunWu's stack in a state where his flex_attention DTensor implementations worked E2E for me. It also required these changes on the DTensor side, to properly add a DTensor rule for flex backward: P1789852198. There are two problems: (1) in the normal dispatcher, we have a precedence ordering between modes and subclasses. Modes are dispatched to first, but modes are allowed to return NotImplemented, giving subclasses a chance to run. This normally happens automatically in `FakeTensorMode.__torch_dispatch__` and `FunctionalTensorMode.__torch_dispatch__`. However, since HOPs implement these two modes themselves, HOPs do not get this benefit. For now, I ended up hardcoding this `NotImplemented` logic directly into the functional/fake rules for flex attention. Having to do this for every HOP seems a bit painful. If we could plumb every HOP through `Fake[|Functional]TensorMode.__torch_dispatch__` then we would get this support. Another option could be to just assume that most HOP <> mode implementations want the same treatment by default, and hardcode this `NotImplemented` logic into `torch/_ops.py`. I'm not sure if we'd need a way for the HOP to opt out of this, though. (2) We were hardcoding a call to flex attention's fake implementation in dynamo to run fake prop. This is technically wrong for subclasses, because it doesn't give subclasses the chance to interpose on the op and desugar it before fake prop runs. I tweaked dynamo's logic to call the op, and let the dispatcher handle invoking the fake implementation. **Testing** Xilun is adding some DTensor tests in his PR that will end up testing this logic. If folks would prefer, though, I can try to add a more basic test that uses another subclass instead. This is the tlparse that his DTensor test generated for me: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/0196c1d3-a9a2-46ea-a46d-aa21618aa060/custom/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151719 Approved by: https://github.com/ydwu4 Co-authored-by: drisspg <drisspguessous@gmail.com> |
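A minimal sketch of the precedence mechanism described in (1), using a stand-in dispatch mode rather than the HOP rules themselves: the mode sees the call first and may return NotImplemented to hand the op over to a tensor subclass's __torch_dispatch__.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class DeferToSubclassesMode(TorchDispatchMode):
    """Stand-in mode illustrating the mode-before-subclass ordering only."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if any(issubclass(t, torch.Tensor) and t is not torch.Tensor for t in types):
            # A tensor subclass is involved: bow out so its __torch_dispatch__
            # runs -- the behavior the fake/functional flex rules now mimic.
            return NotImplemented
        return func(*args, **kwargs)

with DeferToSubclassesMode():
    x = torch.randn(4)
    y = x + 1  # plain tensors: the mode simply forwards to the op
```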
||
|
|
80703ca332 |
[FlexAttention] Allow dispatch to SAC for flex (#150080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150080 Approved by: https://github.com/zou3519 |