Commit Graph

282 Commits

Author SHA1 Message Date
linhaifeng
695cb0d342 [2/N][Fix] Fix typo in test folder (#166374)
Fix typo in test folder.

_typos.toml
```bash
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166374
Approved by: https://github.com/cyyever, https://github.com/ezyang
2025-10-29 03:02:07 +00:00
drisspg
5016e7b2eb [FlexAttention] Add mechanism to get optimal autotune decision (#165817)
Script: https://github.com/meta-pytorch/attention-gym/pull/169

Feels directionally okay but there is some bike shedding / this could be quite prone to collision of keys depending on mask mod and score mod changes and simple cache key.

Usecase: https://github.com/meta-pytorch/attention-gym/pull/169

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165817
Approved by: https://github.com/Chillee
2025-10-28 15:50:12 +00:00
fduwjj
904abfc2ca Export flex attention with kwargs and DTensor (#166045)
Fixes #165948

Adding registration of the MaskBlock makes flex attention with kwargs exportable.

Also modified unittests to accept kwargs

```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export

python test/inductor/test_flex_attention.py -k test_pytree_
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166045
Approved by: https://github.com/drisspg, https://github.com/SherlockNoMad

Co-authored-by: fduwjj <fduwjj@gmail.com>
2025-10-27 21:40:40 +00:00
James Wu
e4c01011c2 Mark FlexAttentionBackward as cacheable (#165996)
This probably should have been marked cacheable a long time ago, no reason that it isn't.

Test Plan:
New regional inductor tests for test_flex_attention now are serializable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165996
Approved by: https://github.com/oulgen, https://github.com/zou3519, https://github.com/drisspg
2025-10-26 14:39:17 +00:00
PyTorch MergeBot
516e58965a Revert "Export flex attention with kwargs and DTensor (#166045)"
This reverts commit de7fdfe41a.

Reverted https://github.com/pytorch/pytorch/pull/166045 on behalf of https://github.com/malfet due to Broke distributed tests, see b55b779ad3/1 ([comment](https://github.com/pytorch/pytorch/pull/166045#issuecomment-3446850955))
2025-10-25 15:47:32 +00:00
Yiming Zhou
de7fdfe41a Export flex attention with kwargs and DTensor (#166045)
Fixes #165948

Adding registration of the MaskBlock makes flex attention with kwargs exportable.

Also modified unittests to accept kwargs

```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export

python test/inductor/test_flex_attention.py -k test_pytree_
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166045
Approved by: https://github.com/drisspg
2025-10-25 03:17:22 +00:00
Amin Sedaghat
767199fd9b [flex_attention] replace sliced BlockMask noop with helpful error (#164702)
Fixes part of #163314

After slicing BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silent incorrect results when users applied transformations to `sliced_mask.mask_mod`.

Replace noop with `_sliced_mask_mod_error` that raises RuntimeError with guidance to use `base_mask.mask_mod` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
2025-10-20 03:46:16 +00:00
drisspg
6b80c94901 [FlexAttention] Fix dynamic shaped heads flex_flash check (#165866)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #165729
2025-10-19 23:10:16 +00:00
Yuanyuan Chen
e925dfcc6b Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-17 07:27:11 +00:00
nullplay
ac529df244 Native matmul (#157743)
### Implementation of #151705

This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.

To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)

### Summary of This PR

This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.

### Notable Changes

1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tililng suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.

@eellison @jansel @PaulZhang12 @shunting314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
2025-10-14 04:22:30 +00:00
Markus Hoehnerbach
a7fa1a91e3 fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317
Approved by: https://github.com/drisspg
ghstack dependencies: #163986
2025-10-09 15:40:49 +00:00
PyTorch MergeBot
20082d7136 Revert "fix flex attention eager bwd: more rounding (#164317)"
This reverts commit 41808b2ba9.

Reverted https://github.com/pytorch/pytorch/pull/164317 on behalf of https://github.com/jeffdaily due to inductor/test_flex_attention.py::TestFlexAttentionCUDA::test_builtin_score_mods_seqlen_lt_custom_sparse_block_size_score_mod4_cuda_float16 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18330774537/job/52207370954) [HUD commit link](41808b2ba9) ([comment](https://github.com/pytorch/pytorch/pull/164317#issuecomment-3381812090))
2025-10-08 14:29:10 +00:00
Markus Hoehnerbach
41808b2ba9 fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317
Approved by: https://github.com/drisspg
ghstack dependencies: #163986
2025-10-08 01:17:45 +00:00
PyTorch MergeBot
5d7360bb03 Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e602692.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
2025-10-05 19:32:21 +00:00
Yuanyuan Chen
321e602692 Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang
2025-10-05 07:38:25 +00:00
Markus Hoehnerbach
91c4db76cb fix flex attention eager: dont round down scores to low-precision (closes #163588) (#163986)
Fixes: https://github.com/pytorch/pytorch/issues/163588

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163986
Approved by: https://github.com/drisspg, https://github.com/mlazos
2025-10-03 01:09:59 +00:00
drisspg
cfd46d13e6 Fix SAC + Flex issue (#164421)
# Summary

This happends when flex_attention is not tagged with the ` CheckpointPolicy.MUST_SAVE` policy. This causes the lse to be unrealized. I think in general this probably not the best policy but we shoudn't error

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164421
Approved by: https://github.com/Skylion007
2025-10-02 09:02:17 +00:00
Yuanyuan Chen
a8c528c105 [1/N] Apply UP035 rule in tests (#163947)
Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947
Approved by: https://github.com/ezyang
2025-09-29 01:42:01 +00:00
drisspg
e2ce79e4cc [Flex] Fix silent correctness w/ backpropping grads (#163677)
Fixes #https://github.com/pytorch/pytorch/issues/162228

# Summary

Majority of our tests are only compiling flex-attention in isolation. This means that for fake tensor propagation the input primals and all captured buffers dont do any intermediate computation below autograd.  As a result result the by happen chance match the `require_grad`ness of the eager implementation and this check  will pass. However if score_mod is a the result of some other intermediate fake tensor prop then it is not guaranteed to have accurate req_gradness, which was happening here.

TLDR is that this was a boot and suspenders that was actually harmful and we should just let the joint graph handle creating the correct joint graph

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163677
Approved by: https://github.com/ydwu4
2025-09-24 02:12:19 +00:00
Jason Ansel
ed84e808f0 [inductor] Freeze layouts in FlexAttention (#163434)
Fixes #163300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163434
Approved by: https://github.com/drisspg
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419
2025-09-23 15:37:29 +00:00
Isalia20
1a42656d6c [Flex attention] Fix flex attention head broadcast (#163426)
Fixes part of #163314

In particular bug: **Bug 1: H=None Broadcasting Produces Incorrect Results**

This fixes a shape bug when slicing BlockMask on the Q-tile axis with an int (**mask[:, :, i]**). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Due to them losing shape, even though the mask_mod remains "interpretable", the kernel’s stride math then reads wrong offsets. Due to this we get silent numerical mismatches compared to regular SDPA, especially when single position decoding/H broadcasting.

The B=None, H=None works case is accidental: with singleton batch/head the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1` and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn’t move the pointer and it happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with wrong strides which causes silent error

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163426
Approved by: https://github.com/drisspg
2025-09-23 13:01:51 +00:00
David Berard
ff6870d134 [BE][flex attention] compute RMSE in float64 (#162088)
I saw a failure where the reference error was 0.0, and the compiled error was 0.035. Although the failure still occurs with or without this change, it was confusing to see RMSE of 0.0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162088
Approved by: https://github.com/drisspg
2025-09-11 23:53:31 +00:00
drisspg
864ffe12d7 Fix some edge cases (#162295)
``` Summary
🔝 Top 5 Performance Differences (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬─────────────────────────────┬───────────────────┬──────────────────────┬───────────────────────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)       ┆ TFlops BWD (base) ┆ TFlops BWD (no_peel) ┆ no_peel_speedup_over_base ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                         ┆ ---               ┆ ---                  ┆ ---                       ┆ ---       │
│ str            ┆ str            ┆ str                         ┆ f64               ┆ f64                  ┆ f64                       ┆ f64       │
╞════════════════╪════════════════╪═════════════════════════════╪═══════════════════╪══════════════════════╪═══════════════════════════╪═══════════╡
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, 1024, 64)  ┆ 56.937931         ┆ 58.960459            ┆ 1.035522                  ┆ 3.552163  │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, 1024, 128) ┆ 89.221306         ┆ 86.295642            ┆ 0.967209                  ┆ -3.27911  │
│ causal         ┆ torch.bfloat16 ┆ (2, 16, 4096, 4, 4096, 128) ┆ 111.552594        ┆ 114.380841           ┆ 1.025353                  ┆ 2.535349  │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, 1024, 64) ┆ 74.830149         ┆ 76.685445            ┆ 1.024793                  ┆ 2.479344  │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, 1024, 64)  ┆ 55.279932         ┆ 56.369312            ┆ 1.019707                  ┆ 1.97066   │
└────────────────┴────────────────┴─────────────────────────────┴───────────────────┴──────────────────────┴───────────────────────────┴───────────┘

🔺 Top 5 Cases Where no_peel (change) is Faster than base (baseline):
shape: (5, 7)
┌────────────────┬────────────────┬─────────────────────────────┬───────────────────┬──────────────────────┬───────────────────────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)       ┆ TFlops BWD (base) ┆ TFlops BWD (no_peel) ┆ no_peel_speedup_over_base ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                         ┆ ---               ┆ ---                  ┆ ---                       ┆ ---       │
│ str            ┆ str            ┆ str                         ┆ f64               ┆ f64                  ┆ f64                       ┆ f64       │
╞════════════════╪════════════════╪═════════════════════════════╪═══════════════════╪══════════════════════╪═══════════════════════════╪═══════════╡
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, 1024, 64)  ┆ 56.937931         ┆ 58.960459            ┆ 1.035522                  ┆ 3.552163  │
│ causal         ┆ torch.bfloat16 ┆ (2, 16, 4096, 4, 4096, 128) ┆ 111.552594        ┆ 114.380841           ┆ 1.025353                  ┆ 2.535349  │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, 1024, 64) ┆ 74.830149         ┆ 76.685445            ┆ 1.024793                  ┆ 2.479344  │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, 1024, 64)  ┆ 55.279932         ┆ 56.369312            ┆ 1.019707                  ┆ 1.97066   │
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 4096, 4, 4096, 64)  ┆ 111.08814         ┆ 112.447047           ┆ 1.012233                  ┆ 1.22327   │
└────────────────┴────────────────┴─────────────────────────────┴───────────────────┴──────────────────────┴───────────────────────────┴───────────┘

🔻 Top 5 Cases Where no_peel (change) is Slower than base (baseline):
shape: (5, 7)
┌────────────────┬────────────────┬─────────────────────────────┬───────────────────┬──────────────────────┬───────────────────────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)       ┆ TFlops BWD (base) ┆ TFlops BWD (no_peel) ┆ no_peel_speedup_over_base ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                         ┆ ---               ┆ ---                  ┆ ---                       ┆ ---       │
│ str            ┆ str            ┆ str                         ┆ f64               ┆ f64                  ┆ f64                       ┆ f64       │
╞════════════════╪════════════════╪═════════════════════════════╪═══════════════════╪══════════════════════╪═══════════════════════════╪═══════════╡
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, 1024, 128) ┆ 89.221306         ┆ 86.295642            ┆ 0.967209                  ┆ -3.27911  │
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 1024, 4, 1024, 64)  ┆ 78.23082          ┆ 76.693169            ┆ 0.980345                  ┆ -1.965531 │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 4, 2048, 128) ┆ 96.95663          ┆ 95.573333            ┆ 0.985733                  ┆ -1.426717 │
│ alibi          ┆ torch.bfloat16 ┆ (4, 16, 2048, 4, 2048, 64)  ┆ 93.373473         ┆ 92.294147            ┆ 0.988441                  ┆ -1.155924 │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 2048, 4, 2048, 128) ┆ 96.95147          ┆ 96.105389            ┆ 0.991273                  ┆ -0.872685 │
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162295
Approved by: https://github.com/mlazos, https://github.com/v0i0
2025-09-10 21:33:45 +00:00
Xingyuan Li
833997a6fd [Inductor][UT] Fix flex attention related inductor cases (#162450)
## Motivation
Fixes #162435, Fixes #162436

UT failures:
* https://github.com/pytorch/pytorch/actions/runs/17523991468/job/49772651636
* https://github.com/pytorch/pytorch/actions/runs/17523991468/job/49772651637

To fix flex attention related cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162450
Approved by: https://github.com/drisspg
2025-09-10 06:48:00 +00:00
drisspg
ac9ccd0dc2 Add return-max-scores to flex-attention (#161667)
# Summary

### Update

API

```Py
class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False

class AuxOutput(NamedTuple):
    """Auxiliary outputs from flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None

  out_only = flex_attention(query, key, value, score_mod)
  out_max, aux_max = flex_attention(
      query,
      key,
      value,
      score_mod,
      return_aux=FlexAttentionAuxRequest(max_scores=True),
  )
  out_both, aux_both = flex_attention(
      query,
      key,
      value,
      score_mod,
      return_aux=FlexAttentionAuxRequest(lse=True, max_scores=True),
        )
```

Returns the max post mod scores from flex attention.

Not being able to break BC is kinda of annoying here since we end up with a combinatorial problem where if we need to add any more return vals we need to new kwargs that gate if they get returned by the function and need to support the 2**N additional args possible return groups.

Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. I added kwarg only now

Maybe we make a `ExtraReturns` type kwarg that can grow and we don't need to keep adding new top level args.

We could also return a Struct that holds all the extra tensors and start deprecation cycle for logsumexp eventually returning just 1 `ExtraReturns` like struct with the tensors.

### Req Grad
I currently dont return a max_scores that supports backproping grads. I think this might be feasible  but since max is essentially 1 hot 	on the inputs and a reduction we would either need to save another `max_location` from the forward or find the max_score but also only apply to first occurence if there is multiple equivalent scores (need to check if thats we define for vanilla max op in torch).

For now no grad, we can re-visit if needed.

## Perf
I am going to disable for flex_decode. Since at least initially the motivation is for training. I also more hard than it should be to have ops return nuns or optional tensors, If return max is at the false, we should probably just create a tensor of size zero so that we don't slow down the hot path.

```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,     ┆ 249.514658    ┆ 243.078974   ┆ 6.435684  ┆ 2.647569  │
│                ┆                ┆ 2048, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 57.971274     ┆ 56.633641    ┆ 1.337633  ┆ 2.361905  │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 280.71254     ┆ 275.686991   ┆ 5.025549  ┆ 1.822918  │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,    ┆ 152.970031    ┆ 150.489109   ┆ 2.480923  ┆ 1.648573  │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘

🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)  ┆ TFlops (base) ┆ TFlops (max) ┆ delta    ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                    ┆ ---           ┆ ---          ┆ ---      ┆ ---       │
│ str            ┆ str            ┆ str                    ┆ f64           ┆ f64          ┆ f64      ┆ f64       │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,      ┆ 249.514658    ┆ 243.078974   ┆ 6.435684 ┆ 2.647569  │
│                ┆                ┆ 2048, 64)              ┆               ┆              ┆          ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 57.971274     ┆ 56.633641    ┆ 1.337633 ┆ 2.361905  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 280.71254     ┆ 275.686991   ┆ 5.025549 ┆ 1.822918  │
│                ┆                ┆ 1024, 128)             ┆               ┆              ┆          ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,     ┆ 152.970031    ┆ 150.489109   ┆ 2.480923 ┆ 1.648573  │
│                ┆                ┆ 16384, 64)             ┆               ┆              ┆          ┆           │
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,      ┆ 161.031318    ┆ 158.597808   ┆ 2.43351  ┆ 1.534391  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘

🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 4,      ┆ 175.546923    ┆ 177.81205    ┆ -2.265127 ┆ -1.273888 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4,     ┆ 156.282597    ┆ 158.209134   ┆ -1.926537 ┆ -1.217715 │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16,     ┆ 232.542929    ┆ 235.140136   ┆ -2.597207 ┆ -1.104536 │
│                ┆                ┆ 2048, 128)            ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 169.652791    ┆ 171.475986   ┆ -1.823195 ┆ -1.063236 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
2025-09-08 22:44:48 +00:00
PyTorch MergeBot
104f2680e0 Revert "Add return-max-scores to flex-attention (#161667)"
This reverts commit 486b20b73c.

Reverted https://github.com/pytorch/pytorch/pull/161667 on behalf of https://github.com/huydhn due to Sorry for reverting your change but reverting https://github.com/pytorch/pytorch/pull/161730 does not seem to fix all trunk failures ([comment](https://github.com/pytorch/pytorch/pull/161667#issuecomment-3263512642))
2025-09-07 06:00:55 +00:00
PyTorch MergeBot
a3e5466002 Revert "Resize to 0 if not going to be used (#161730)"
This reverts commit 081cab0454.

Reverted https://github.com/pytorch/pytorch/pull/161730 on behalf of https://github.com/davidberard98 due to functorch/test_aotdispatch.py::TestAOTModuleSimplified::test_flex_attn_noncontiguous_tangents [GH job link](https://github.com/pytorch/pytorch/actions/runs/17506617662/job/49731934012) [HUD commit link](081cab0454) ([comment](https://github.com/pytorch/pytorch/pull/161730#issuecomment-3260492575))
2025-09-06 04:17:08 +00:00
drisspg
081cab0454 Resize to 0 if not going to be used (#161730)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #161730
*  #161667

```Py
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            buf0 = empty_strided_cuda((2, 32, 1024), (32768, 1024, 1), torch.float32)
            buf1 = empty_strided_cuda((2, 32, 1024), (32768, 1024, 1), torch.float32)
            buf2 = empty_strided_cuda((2, 32, 1024, 64), (2097152, 65536, 64, 1), torch.float32)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream0 = get_raw_stream(0)
            triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg4_1, arg3_1, arg5_1, arg6_1, buf2, 8, 2, 32, stream=stream0)
            del arg0_1
            del arg1_1
            del arg2_1
            del arg3_1
            del arg4_1
            del arg5_1
            del arg6_1
            del buf0
            del buf1
        return (buf2, )
```

Vs

```Py
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            buf0 = empty_strided_cuda((2, 32, 1024), (32768, 1024, 1), torch.float32)
            buf1 = empty_strided_cuda((0, ), (1, ), torch.float32)
            buf2 = empty_strided_cuda((2, 32, 1024, 64), (2097152, 65536, 64, 1), torch.float32)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream0 = get_raw_stream(0)
            triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg4_1, arg3_1, arg5_1, arg6_1, buf2, 8, 2, 32, stream=stream0)
            del arg0_1
            del arg1_1
            del arg2_1
            del arg3_1
            del arg4_1
            del arg5_1
            del arg6_1
            del buf0
            del buf1
        return (buf2, )
```
<img width="428" height="145" alt="Screenshot 2025-08-28 at 12 37 11 PM" src="https://github.com/user-attachments/assets/240a7bca-97e1-40c4-bf93-f075fdc1a40d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161730
Approved by: https://github.com/Skylion007, https://github.com/BoyuanFeng
ghstack dependencies: #161667
2025-09-05 23:21:46 +00:00
drisspg
486b20b73c Add return-max-scores to flex-attention (#161667)
# Summary

### Update

API

```Py
class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False

class AuxOutput(NamedTuple):
    """Auxiliary outputs from flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None

  out_only = flex_attention(query, key, value, score_mod)
  out_max, aux_max = flex_attention(
      query,
      key,
      value,
      score_mod,
      return_aux=FlexAttentionAuxRequest(max_scores=True),
  )
  out_both, aux_both = flex_attention(
      query,
      key,
      value,
      score_mod,
      return_aux=FlexAttentionAuxRequest(lse=True, max_scores=True),
        )
```

Returns the max post mod scores from flex attention.

Not being able to break BC is kinda of annoying here since we end up with a combinatorial problem where if we need to add any more return vals we need to new kwargs that gate if they get returned by the function and need to support the 2**N additional args possible return groups.

Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. I added kwarg only now

Maybe we make a `ExtraReturns` type kwarg that can grow and we don't need to keep adding new top level args.

We could also return a Struct that holds all the extra tensors and start deprecation cycle for logsumexp eventually returning just 1 `ExtraReturns` like struct with the tensors.

### Req Grad
I currently dont return a max_scores that supports backproping grads. I think this might be feasible  but since max is essentially 1 hot 	on the inputs and a reduction we would either need to save another `max_location` from the forward or find the max_score but also only apply to first occurence if there is multiple equivalent scores (need to check if thats we define for vanilla max op in torch).

For now no grad, we can re-visit if needed.

## Perf
I am going to disable for flex_decode. Since at least initially the motivation is for training. I also more hard than it should be to have ops return nuns or optional tensors, If return max is at the false, we should probably just create a tensor of size zero so that we don't slow down the hot path.

```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,     ┆ 249.514658    ┆ 243.078974   ┆ 6.435684  ┆ 2.647569  │
│                ┆                ┆ 2048, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 57.971274     ┆ 56.633641    ┆ 1.337633  ┆ 2.361905  │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 280.71254     ┆ 275.686991   ┆ 5.025549  ┆ 1.822918  │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,    ┆ 152.970031    ┆ 150.489109   ┆ 2.480923  ┆ 1.648573  │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘

🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)  ┆ TFlops (base) ┆ TFlops (max) ┆ delta    ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                    ┆ ---           ┆ ---          ┆ ---      ┆ ---       │
│ str            ┆ str            ┆ str                    ┆ f64           ┆ f64          ┆ f64      ┆ f64       │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,      ┆ 249.514658    ┆ 243.078974   ┆ 6.435684 ┆ 2.647569  │
│                ┆                ┆ 2048, 64)              ┆               ┆              ┆          ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 57.971274     ┆ 56.633641    ┆ 1.337633 ┆ 2.361905  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 280.71254     ┆ 275.686991   ┆ 5.025549 ┆ 1.822918  │
│                ┆                ┆ 1024, 128)             ┆               ┆              ┆          ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,     ┆ 152.970031    ┆ 150.489109   ┆ 2.480923 ┆ 1.648573  │
│                ┆                ┆ 16384, 64)             ┆               ┆              ┆          ┆           │
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,      ┆ 161.031318    ┆ 158.597808   ┆ 2.43351  ┆ 1.534391  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘

🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 4,      ┆ 175.546923    ┆ 177.81205    ┆ -2.265127 ┆ -1.273888 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4,     ┆ 156.282597    ┆ 158.209134   ┆ -1.926537 ┆ -1.217715 │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16,     ┆ 232.542929    ┆ 235.140136   ┆ -2.597207 ┆ -1.104536 │
│                ┆                ┆ 2048, 128)            ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 169.652791    ┆ 171.475986   ┆ -1.823195 ┆ -1.063236 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
2025-09-05 23:21:46 +00:00
Tianren Gao
2fed4fb464 [FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861)
Fixes #159247
Issue 1: Accuracy Problem with Non-Divisible KV Sequences
---------------------------------------------------------

### Background

Paged attention in flex decoding produced inaccurate results when KV sequence length is not divisible by block size. For example, when `KV_S = 64` and `block_size = 128`, the output didn't match standard attention accuracy.

### Root Cause
The current paged attention does not apply upper mask mod when converting from logical to physical mask mod. Instead, it uses a noop_mask by default which makes all the values unmasked, leading to an accuracy mismatch. Adding a upper mask mod according to the origin actual kv_len (64 in this test case) resolves the issue.

### Solution

*   **Applied proper upper bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor with proper shape `[B, KV_S]` to provide information of actual batched KV sequence length. The function now correctly applies upper bound checks using the actual KV sequence lengths for each batch

### Files Modified
*    `torch/nn/attention/experimental/_paged_attention.py`: Added `kv_len` parameter as a tensor to `get_mask_mod` and applied upper mask to the new mask mod.
*   `test/inductor/test_flex_attention.py`: Fixed all related `kv_len` parameter call in the tests
*   `test/inductor/test_flex_decoding.py`: Fixed all related `kv_len` parameter call in the tests

Issue 2: Invalid Memory Access (IMA) in Triton Kernels
------------------------------------------------------

### Background

The Triton kernel for flex attention was experiencing invalid memory access errors when running with compute sanitizers, particularly with short KV sequences and small batch sizes.

### Root Cause

*   Kernel launches CTAs (Cooperative Thread Arrays) proportional to GPU's multi-processor count (108 via `SPLIT_KV`)
*   With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values
*   This caused out-of-bounds memory access violations

### Solution

Implemented boundary checks with early exit:

1.  **Added `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`

    *   Calculate maximum valid KV index based on actual `kv_indices` tensor size and pass it to Triton template
2.  **Added early exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`

    *   Boundary checks before accessing `kv_indices` in both normal and full blocks
    *   Idle CTAs with invalid `indices_idx` skip computation entirely

This prevents invalid memory access while reducing wasted computation on idle thread blocks.

Testing & Validation
--------------------

### Accuracy Tests

*   Added comprehensive test cases covering KV sequences not divisible by block sizes
*   Verified output matches standard attention for various sequence length combinations

### Sanitizer Results

`========= COMPUTE-SANITIZER Starting standalone test_max_autotune... Running test_max_autotune on device: cuda max_autotune config: True test_max_autotune completed successfully! Test passed! ========= ERROR SUMMARY: 0 errors`

**Before**: More than 13720 invalid memory access errors with sanitizers
**After**: Clean execution with 0 errors

Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861
Approved by: https://github.com/BoyuanFeng
2025-08-30 04:50:23 +00:00
Zhang, Liangang
3e459491b5 Enable XPU path for FlexAttention (#143553)
[#RFC153024](https://github.com/pytorch/pytorch/issues/153024)

**Motivation**

1. The Attention has been the critical performance bottleneck in the current LLM models, and FlexAttention is a good choice to cover the broad variants in the transformers series models. With FlexAttention, it is easy for us to enable the paged attention and fused SDPA  in the transformers repo on XPU device. Besides,  it also provide a candidate to process attention in LLM ecosystem libraries ., e.g., vLLM, SGLang on XPU device.
2. FlexAttention is good start point to push the intel triton based GEMM kernel to be matured. FlexAttention provide both flexattention kernel and flexdecoding kernel to cover both compute bound and memory bound GEMM computation, and  different shapes should also been supported to serve LLM inference., e.g. head_dim=64, 96, 128, 256.

**What does this PR do?**

 1. Enable the device type for Flexattention kernel  and UTs to ensure all important UTs pass on XPU device.
 2. For E2E model inference, ensure the functionality  of LLM models inference with FlexAttention to be ready.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143553
Approved by: https://github.com/EikanWang, https://github.com/drisspg

Co-authored-by: Mao Yunfei <yunfei.mao@intel.com>
Co-authored-by: Xingyuan Li <xingyuan.li@intel.com>
Co-authored-by: majing <jing1.ma@intel.com>
Co-authored-by: Xiao, Wang <wang.xiao@intel.com>
2025-08-29 23:10:58 +00:00
PyTorch MergeBot
ef0483d74c Revert "Ensure large tensor int32 -> int64 indexing is enabled (#157767)"
This reverts commit b36a20d368.

Reverted https://github.com/pytorch/pytorch/pull/157767 on behalf of https://github.com/atalman due to need to revert https://github.com/pytorch/pytorch/pull/157767 internal tests ([comment](https://github.com/pytorch/pytorch/pull/157767#issuecomment-3233558168))
2025-08-28 13:44:41 +00:00
PyTorch MergeBot
5432966253 Revert "Remove test since it ooms on CI (#161644)"
This reverts commit 443452ca2f.

Reverted https://github.com/pytorch/pytorch/pull/161644 on behalf of https://github.com/atalman due to need to revert https://github.com/pytorch/pytorch/pull/157767 internal tests ([comment](https://github.com/pytorch/pytorch/pull/161644#issuecomment-3233550883))
2025-08-28 13:41:58 +00:00
drisspg
443452ca2f Remove test since it ooms on CI (#161644)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161644
Approved by: https://github.com/BoyuanFeng
2025-08-27 19:11:29 +00:00
drisspg
b36a20d368 Ensure large tensor int32 -> int64 indexing is enabled (#157767)
Fixes: #https://github.com/pytorch/pytorch/issues/157446

I think that this delta is worth the switch form block-ptrs especially since they are deprecated

## Perf Summary

A is nightly B is this diff, so `negative` means this diff improves perf

TOP 5 differences
<img width="805" height="754" alt="Screenshot 2025-08-24 at 5 49 49 PM" src="https://github.com/user-attachments/assets/aa359cdf-ee9a-427d-be72-1b9aef6f3115" />

<details>
  <summary><strong>Full perf table (click to expand)</strong></summary>

| attn_type | dtype | shape(B,Hq,M,Hkv,N,D) | TFlops Version A | TFlops Version B |
| --- | --- | --- | --- | --- |
| noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 258.38834144791923 | 258.6353685004612 |
| causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.2192450677751 | 140.12393320464972 |
| alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 122.32683823617003 | 118.51603755647925 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.48556906165314 | 137.24259849208627 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 86.59814488695922 | 84.59431398586257 |
| noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 288.52679758135764 | 292.9174195871856 |
| causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 172.25541683643277 | 172.94326459828508 |
| alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 164.40864610599826 | 165.035129576335 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 176.54876886433945 | 175.08057670028145 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 125.22491679812626 | 121.06201152859151 |
| noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 339.11952481874283 | 339.0132835601695 |
| causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 227.58583240284406 | 228.21824999409597 |
| alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 185.98569659868966 | 182.32850843255093 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 188.9495725191772 | 180.31385312481657 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 106.25789530994302 | 106.55084959448476 |
| noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 357.6430536888533 | 363.30843452247274 |
| causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 262.3241154406613 | 265.73250045488 |
| alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 249.30498953911416 | 249.35928192833785 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 224.74126243851808 | 223.71776504077988 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 168.26977014013707 | 165.47991483333809 |
| noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 382.8178701785897 | 384.34752965862685 |
| causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 308.1449710013853 | 311.0653716044644 |
| alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 251.96365252505072 | 243.92283557225903 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 226.69316232745368 | 215.22769268913356 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 153.34142545296405 | 151.9312673939401 |
| noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 396.0998000753126 | 398.35036286102473 |
| causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 333.5198415274966 | 344.6354466169716 |
| alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 310.5955933379696 | 305.66347819546 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 260.4012412689896 | 259.758666997307 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 234.13034252182635 | 227.61676497283614 |
| noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 396.17615538477196 | 401.1419104525502 |
| causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 359.98648311998414 | 360.8285563463094 |
| alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 291.97720707257736 | 281.41694809965253 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 250.1703628419691 | 238.556760291579 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 199.50782826294306 | 191.52327358439223 |
| noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 411.0632004785396 | 413.6362648405517 |
| causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 382.9404387613185 | 397.74886235657607 |
| alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 357.0998545146633 | 350.5115200772392 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 281.8033924428203 | 281.98601309215843 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 282.56595134222135 | 277.4565795466672 |
| noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 408.89838018149516 | 405.14531386840076 |
| causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 396.07662058160264 | 393.4598228299578 |
| alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 317.8822887267849 | 304.754931401036 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 265.8801304948243 | 254.22961974295112 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 227.87390579965614 | 222.19481980110393 |
| noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 427.36821778477025 | 431.3766620314935 |
| causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 410.67994346825 | 423.4666944003808 |
| alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 381.1968748374038 | 381.77668006420424 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 292.5540046358546 | 296.5439130720502 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 321.04573768858114 | 310.7423616656888 |
| noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 427.46148866769903 | 426.162091037068 |
| causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 419.75580537687347 | 421.88640120274334 |
| alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 337.3208051798903 | 327.4912454675092 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 276.5638854539581 | 262.988360558083 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 250.82791326036886 | 245.07367032501736 |
| noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 435.8055824506086 | 441.8803729460534 |
| causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 432.02638235921006 | 450.33161016596273 |
| alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 402.25525939224883 | 393.8564689669916 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 297.5337286675904 | 297.0131881135074 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 343.8697037899545 | 329.8194073407783 |
| noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 267.58912366821056 | 256.91606054118375 |
| causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 150.81723692609629 | 146.32172267858743 |
| alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 129.51029293209245 | 122.72144394093334 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 147.627656359087 | 141.68956350566188 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 87.55100546003591 | 84.91293287692788 |
| noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 299.5931492743986 | 305.884253766691 |
| causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 179.39026367843837 | 181.64741311605096 |
| alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 173.93547669282367 | 173.23972950980564 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 185.90234171599252 | 182.80844545446686 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 128.08176696266082 | 123.27722685662111 |
| noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 340.50674552770664 | 338.9071088484576 |
| causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 225.4438318650432 | 230.22899884832975 |
| alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 194.15123248528312 | 185.02793973094865 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 200.74289714108176 | 191.76606719670647 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 107.03564946728423 | 106.82432377861258 |
| noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 371.31799283918406 | 379.7555394732925 |
| causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 275.97762744310455 | 276.71106853992995 |
| alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 261.6648679783462 | 259.4127232060398 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 237.03108223577615 | 233.92710216149527 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 172.13926800371152 | 168.74390922407585 |
| noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 381.50199487767276 | 383.9043681999597 |
| causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 307.9748883093411 | 312.2403515462001 |
| alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 251.11319684705438 | 243.17870127827277 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 236.3253127246763 | 223.81250201769552 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 154.55693991756874 | 153.11360584987685 |
| noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 407.11400078586615 | 413.53709886086557 |
| causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 348.1705797722622 | 360.09771155957367 |
| alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 321.8593280850388 | 318.2882327401255 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 270.089032013835 | 268.767323026064 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 238.07324557907788 | 228.09842078362692 |
| noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 399.8172853171901 | 401.0954526332136 |
| causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 363.4387330438581 | 364.13111024232677 |
| alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 294.1752429133857 | 283.7235663368415 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 256.8389394007649 | 246.91771015606483 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 199.3378564292656 | 192.40439590901758 |
| noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 425.5150965556111 | 430.8190098707553 |
| causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 396.00437184073013 | 411.3873625655787 |
| alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 369.92803661607815 | 361.43244467343663 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 293.4277354412933 | 295.2529537595746 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 288.0208673072841 | 281.51896404878863 |
| noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 408.3005367220567 | 408.96116482298913 |
| causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 396.90095962766304 | 396.87385456176486 |
| alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 319.0534576137999 | 302.50950358107764 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 270.3334977708081 | 258.8506349486557 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 227.46824134365394 | 222.23759438128766 |
| noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 438.24247309479694 | 437.7975163205371 |
| causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 428.34012029699227 | 433.3215899950434 |
| alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 386.52672049728875 | 388.26216893354984 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 302.71976814728083 | 302.3574867306459 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 327.39760662780986 | 308.6348428844912 |
| noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 423.31308678262695 | 426.6306972137279 |
| causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 412.6983690923106 | 419.4961977664297 |
| alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 337.41003544742273 | 324.2155049126126 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 278.7755890910794 | 265.9194286636502 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 251.55678254755364 | 244.8843180141462 |
| noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 452.5930781172308 | 457.7117122300742 |
| causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 445.05676260348116 | 463.9304535499636 |
| alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 415.78302138389415 | 406.29229555271456 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 308.0311067300895 | 304.91354721414314 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 351.43943626809335 | 329.4476923070317 |
| noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 295.1801525813241 | 291.36521287398904 |
| causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 183.23250549178067 | 182.35421238887605 |
| alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 151.56832453117747 | 151.3422139154794 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 171.02111935180432 | 160.72516856727913 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 74.05765122783826 | 74.5885345035243 |
| noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 314.3587394591763 | 319.2938677773619 |
| causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 224.57002084153177 | 225.48868542008177 |
| alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.00964804143052 | 215.39576159953486 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.1174237618258 | 214.28437413525663 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 121.08920423648368 | 119.55813661872644 |
| noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 362.2193857281911 | 360.05005804275936 |
| causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 279.8840217430121 | 279.5437918286659 |
| alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 227.76617121021982 | 222.8655938229316 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 215.43141176970562 | 207.71852284994702 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 121.35588364218539 | 121.20636565046884 |
| noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 365.1545280898012 | 373.37585444987326 |
| causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 304.360119952975 | 309.1247297936263 |
| alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 287.2603904544586 | 289.25547903162595 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 257.9852675272418 | 257.59069234098115 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 188.35158496670232 | 184.24683960154857 |
| noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 389.9744911369211 | 388.43466897254166 |
| causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 345.9228295166513 | 342.63034895210126 |
| alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 279.56334658247437 | 271.2724375402088 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 245.66477202810066 | 233.49688207371258 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 170.3270720653187 | 166.23863845657382 |
| noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 400.0041140827554 | 402.11182445396497 |
| causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 363.64641830327434 | 375.9288663364792 |
| alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 341.5776139573363 | 335.1160003213424 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 281.1811770268521 | 280.21438270014005 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 247.78716118997716 | 245.3269825179633 |
| noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 403.794126680488 | 405.2353919019577 |
| causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 387.079178426863 | 385.1461762057035 |
| alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 309.7847188173431 | 298.0443968374749 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 262.4721750159666 | 250.81679725428586 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 205.70866004479979 | 202.9620839129557 |
| noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 413.380982988662 | 418.40270594263103 |
| causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 398.450064800682 | 409.6794973994029 |
| alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 372.26297458194466 | 364.44415106552196 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 293.0818569905912 | 292.85172400643984 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 296.46717085592087 | 285.76362010612763 |
| noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 419.3186786037592 | 426.08801580934437 |
| causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 408.1648467766632 | 409.4122254207817 |
| alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 329.24396020457345 | 313.5200995121138 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 274.61257504571876 | 255.7801815432177 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 232.63806001220684 | 230.03020843492314 |
| noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 435.0785891054788 | 440.39101804225345 |
| causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 424.86925312752817 | 435.18898057396825 |
| alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 393.000417896268 | 395.11543361225256 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 297.7755459218185 | 300.7208114715287 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 331.71570861760534 | 318.07127352552885 |
| noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 424.58602747137405 | 425.84897078470715 |
| causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 422.66607285025725 | 423.5524945535485 |
| alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 344.8625760048626 | 331.6793888458635 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 282.0787281511649 | 263.7895634445868 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 252.7301927385177 | 245.41844170037427 |
| noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 437.0658069164588 | 442.9101960063628 |
| causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 433.13788271434646 | 452.3873572709863 |
| alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 404.0959191546953 | 396.7077863894884 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 300.45502211883206 | 301.3439134717943 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 344.11003202413934 | 330.8897663350314 |
| noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 298.4364205341705 | 291.6793556507056 |
| causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 187.6382133139633 | 191.05409897308772 |
| alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 156.55822078636112 | 154.178925976516 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 173.47765221825162 | 169.30862508068464 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 74.5885345035243 | 74.52689061607104 |
| noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 323.12233826013045 | 328.53889207933514 |
| causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 236.75872140126316 | 235.8378325547398 |
| alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 227.17836523816675 | 226.75357076139966 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 224.07209453308036 | 224.07209453308036 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 122.85572156047981 | 121.11642183704716 |
| noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 361.3123326658092 | 360.71014086458337 |
| causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 281.5287983927017 | 281.94301754758345 |
| alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 232.7456696285686 | 226.50976826432776 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 221.5612361744038 | 214.96188822837055 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 121.38311528944315 | 120.85441868178513 |
| noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 380.2579019244734 | 389.2520157863988 |
| causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 316.95230660496924 | 317.87597790618906 |
| alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 301.07968126657323 | 298.02424098422983 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 267.2240756921594 | 267.16353549228154 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 189.82761622494257 | 186.736450261963 |
| noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 389.88665375406805 | 387.9125133037077 |
| causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 348.70619958684887 | 346.6750499749774 |
| alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 280.5472989906087 | 271.22300822012187 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 250.02397620165968 | 241.22532776331445 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 171.67817496107645 | 166.95679280483972 |
| noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 412.626880230807 | 417.60238657950777 |
| causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 374.8829313933945 | 389.4448546468815 |
| alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 353.20410434172436 | 345.7072490717473 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 292.51045924209586 | 291.66621022138287 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 251.6264062063495 | 248.45110052911542 |
| noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 404.0155784550126 | 401.90546837237514 |
| causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 384.4389015599863 | 386.9684324594344 |
| alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 313.3731284132225 | 298.17074251037894 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 264.19199737284265 | 252.8982463999916 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 207.03696315185684 | 202.86697323136772 |
| noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 428.2436763312506 | 433.45005568619536 |
| causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 411.8516531869893 | 428.2753623461049 |
| alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 384.9095037182509 | 372.90888743000744 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 303.2438915629836 | 302.05095952914337 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 301.8689122735564 | 285.0363190513223 |
| noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 423.13592231504805 | 420.3991500185611 |
| causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 407.44527331585493 | 408.5064370765247 |
| alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 330.50050996167414 | 316.8763979925965 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 274.6833786307413 | 259.86098862141324 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 232.24019584158367 | 226.52040268160232 |
| noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 444.4596314237808 | 455.99558915752266 |
| causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 437.4245561244369 | 455.98275147271966 |
| alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 397.3350686877605 | 397.88875599028063 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 308.53809114394545 | 307.1359822042007 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 331.32379843423774 | 316.85293191675646 |
| noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 422.4622274366379 | 425.0407156418684 |
| causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 420.9547052783101 | 430.33779243510276 |
| alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 345.50265346504085 | 332.094855328957 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 280.81715528243365 | 264.6543640282054 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 252.25635200421783 | 245.46235499490305 |
| noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 452.5524207341139 | 461.7512032176736 |
| causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 445.2316469907137 | 464.4523799578466 |
| alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 416.87264016717023 | 409.17124592157046 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 309.42579489389846 | 307.9734464665731 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 350.50782004300623 | 330.98959545427294 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157767
Approved by: https://github.com/Skylion007
2025-08-27 02:45:20 +00:00
Jeff Daily
262640fd22 [ROCm][CI] restore test_flex_attention tests (#161519)
Reverts #161450 and targets specific subtests to skip on MI200.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161519
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-08-26 19:31:30 +00:00
PyTorch MergeBot
818ba434c7 Revert "Ensure large tensor int32 -> int64 indexing is enabled (#157767)"
This reverts commit fc69c2bc67.

Reverted https://github.com/pytorch/pytorch/pull/157767 on behalf of https://github.com/atalman due to internal failure, sorry will revert ([comment](https://github.com/pytorch/pytorch/pull/157767#issuecomment-3224341111))
2025-08-26 14:12:06 +00:00
drisspg
fc69c2bc67 Ensure large tensor int32 -> int64 indexing is enabled (#157767)
Fixes: #https://github.com/pytorch/pytorch/issues/157446

I think that this delta is worth the switch form block-ptrs especially since they are deprecated

## Perf Summary

A is nightly B is this diff, so `negative` means this diff improves perf

TOP 5 differences
<img width="805" height="754" alt="Screenshot 2025-08-24 at 5 49 49 PM" src="https://github.com/user-attachments/assets/aa359cdf-ee9a-427d-be72-1b9aef6f3115" />

<details>
  <summary><strong>Full perf table (click to expand)</strong></summary>

| attn_type | dtype | shape(B,Hq,M,Hkv,N,D) | TFlops Version A | TFlops Version B |
| --- | --- | --- | --- | --- |
| noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 258.38834144791923 | 258.6353685004612 |
| causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.2192450677751 | 140.12393320464972 |
| alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 122.32683823617003 | 118.51603755647925 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 142.48556906165314 | 137.24259849208627 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 86.59814488695922 | 84.59431398586257 |
| noop | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 288.52679758135764 | 292.9174195871856 |
| causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 172.25541683643277 | 172.94326459828508 |
| alibi | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 164.40864610599826 | 165.035129576335 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 176.54876886433945 | 175.08057670028145 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 125.22491679812626 | 121.06201152859151 |
| noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 339.11952481874283 | 339.0132835601695 |
| causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 227.58583240284406 | 228.21824999409597 |
| alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 185.98569659868966 | 182.32850843255093 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 188.9495725191772 | 180.31385312481657 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 64) | 106.25789530994302 | 106.55084959448476 |
| noop | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 357.6430536888533 | 363.30843452247274 |
| causal | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 262.3241154406613 | 265.73250045488 |
| alibi | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 249.30498953911416 | 249.35928192833785 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 224.74126243851808 | 223.71776504077988 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128) | 168.26977014013707 | 165.47991483333809 |
| noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 382.8178701785897 | 384.34752965862685 |
| causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 308.1449710013853 | 311.0653716044644 |
| alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 251.96365252505072 | 243.92283557225903 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 226.69316232745368 | 215.22769268913356 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 153.34142545296405 | 151.9312673939401 |
| noop | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 396.0998000753126 | 398.35036286102473 |
| causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 333.5198415274966 | 344.6354466169716 |
| alibi | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 310.5955933379696 | 305.66347819546 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 260.4012412689896 | 259.758666997307 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 234.13034252182635 | 227.61676497283614 |
| noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 396.17615538477196 | 401.1419104525502 |
| causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 359.98648311998414 | 360.8285563463094 |
| alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 291.97720707257736 | 281.41694809965253 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 250.1703628419691 | 238.556760291579 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 64) | 199.50782826294306 | 191.52327358439223 |
| noop | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 411.0632004785396 | 413.6362648405517 |
| causal | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 382.9404387613185 | 397.74886235657607 |
| alibi | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 357.0998545146633 | 350.5115200772392 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 281.8033924428203 | 281.98601309215843 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 16, 8192, 128) | 282.56595134222135 | 277.4565795466672 |
| noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 408.89838018149516 | 405.14531386840076 |
| causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 396.07662058160264 | 393.4598228299578 |
| alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 317.8822887267849 | 304.754931401036 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 265.8801304948243 | 254.22961974295112 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 227.87390579965614 | 222.19481980110393 |
| noop | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 427.36821778477025 | 431.3766620314935 |
| causal | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 410.67994346825 | 423.4666944003808 |
| alibi | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 381.1968748374038 | 381.77668006420424 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 292.5540046358546 | 296.5439130720502 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 16, 16384, 128) | 321.04573768858114 | 310.7423616656888 |
| noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 427.46148866769903 | 426.162091037068 |
| causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 419.75580537687347 | 421.88640120274334 |
| alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 337.3208051798903 | 327.4912454675092 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 276.5638854539581 | 262.988360558083 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 64) | 250.82791326036886 | 245.07367032501736 |
| noop | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 435.8055824506086 | 441.8803729460534 |
| causal | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 432.02638235921006 | 450.33161016596273 |
| alibi | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 402.25525939224883 | 393.8564689669916 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 297.5337286675904 | 297.0131881135074 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 16, 32768, 128) | 343.8697037899545 | 329.8194073407783 |
| noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 267.58912366821056 | 256.91606054118375 |
| causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 150.81723692609629 | 146.32172267858743 |
| alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 129.51029293209245 | 122.72144394093334 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 147.627656359087 | 141.68956350566188 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 64) | 87.55100546003591 | 84.91293287692788 |
| noop | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 299.5931492743986 | 305.884253766691 |
| causal | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 179.39026367843837 | 181.64741311605096 |
| alibi | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 173.93547669282367 | 173.23972950980564 |
| sliding_window | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 185.90234171599252 | 182.80844545446686 |
| document_mask | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128) | 128.08176696266082 | 123.27722685662111 |
| noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 340.50674552770664 | 338.9071088484576 |
| causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 225.4438318650432 | 230.22899884832975 |
| alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 194.15123248528312 | 185.02793973094865 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 200.74289714108176 | 191.76606719670647 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 64) | 107.03564946728423 | 106.82432377861258 |
| noop | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 371.31799283918406 | 379.7555394732925 |
| causal | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 275.97762744310455 | 276.71106853992995 |
| alibi | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 261.6648679783462 | 259.4127232060398 |
| sliding_window | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 237.03108223577615 | 233.92710216149527 |
| document_mask | torch.bfloat16 | (2, 16, 2048, 4, 2048, 128) | 172.13926800371152 | 168.74390922407585 |
| noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 381.50199487767276 | 383.9043681999597 |
| causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 307.9748883093411 | 312.2403515462001 |
| alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 251.11319684705438 | 243.17870127827277 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 236.3253127246763 | 223.81250201769552 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 64) | 154.55693991756874 | 153.11360584987685 |
| noop | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 407.11400078586615 | 413.53709886086557 |
| causal | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 348.1705797722622 | 360.09771155957367 |
| alibi | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 321.8593280850388 | 318.2882327401255 |
| sliding_window | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 270.089032013835 | 268.767323026064 |
| document_mask | torch.bfloat16 | (2, 16, 4096, 4, 4096, 128) | 238.07324557907788 | 228.09842078362692 |
| noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 399.8172853171901 | 401.0954526332136 |
| causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 363.4387330438581 | 364.13111024232677 |
| alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 294.1752429133857 | 283.7235663368415 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 256.8389394007649 | 246.91771015606483 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 64) | 199.3378564292656 | 192.40439590901758 |
| noop | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 425.5150965556111 | 430.8190098707553 |
| causal | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 396.00437184073013 | 411.3873625655787 |
| alibi | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 369.92803661607815 | 361.43244467343663 |
| sliding_window | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 293.4277354412933 | 295.2529537595746 |
| document_mask | torch.bfloat16 | (2, 16, 8192, 4, 8192, 128) | 288.0208673072841 | 281.51896404878863 |
| noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 408.3005367220567 | 408.96116482298913 |
| causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 396.90095962766304 | 396.87385456176486 |
| alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 319.0534576137999 | 302.50950358107764 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 270.3334977708081 | 258.8506349486557 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 64) | 227.46824134365394 | 222.23759438128766 |
| noop | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 438.24247309479694 | 437.7975163205371 |
| causal | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 428.34012029699227 | 433.3215899950434 |
| alibi | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 386.52672049728875 | 388.26216893354984 |
| sliding_window | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 302.71976814728083 | 302.3574867306459 |
| document_mask | torch.bfloat16 | (2, 16, 16384, 4, 16384, 128) | 327.39760662780986 | 308.6348428844912 |
| noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 423.31308678262695 | 426.6306972137279 |
| causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 412.6983690923106 | 419.4961977664297 |
| alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 337.41003544742273 | 324.2155049126126 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 278.7755890910794 | 265.9194286636502 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 64) | 251.55678254755364 | 244.8843180141462 |
| noop | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 452.5930781172308 | 457.7117122300742 |
| causal | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 445.05676260348116 | 463.9304535499636 |
| alibi | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 415.78302138389415 | 406.29229555271456 |
| sliding_window | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 308.0311067300895 | 304.91354721414314 |
| document_mask | torch.bfloat16 | (2, 16, 32768, 4, 32768, 128) | 351.43943626809335 | 329.4476923070317 |
| noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 295.1801525813241 | 291.36521287398904 |
| causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 183.23250549178067 | 182.35421238887605 |
| alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 151.56832453117747 | 151.3422139154794 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 171.02111935180432 | 160.72516856727913 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64) | 74.05765122783826 | 74.5885345035243 |
| noop | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 314.3587394591763 | 319.2938677773619 |
| causal | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 224.57002084153177 | 225.48868542008177 |
| alibi | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.00964804143052 | 215.39576159953486 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 216.1174237618258 | 214.28437413525663 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 16, 1024, 128) | 121.08920423648368 | 119.55813661872644 |
| noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 362.2193857281911 | 360.05005804275936 |
| causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 279.8840217430121 | 279.5437918286659 |
| alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 227.76617121021982 | 222.8655938229316 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 215.43141176970562 | 207.71852284994702 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64) | 121.35588364218539 | 121.20636565046884 |
| noop | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 365.1545280898012 | 373.37585444987326 |
| causal | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 304.360119952975 | 309.1247297936263 |
| alibi | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 287.2603904544586 | 289.25547903162595 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 257.9852675272418 | 257.59069234098115 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 16, 2048, 128) | 188.35158496670232 | 184.24683960154857 |
| noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 389.9744911369211 | 388.43466897254166 |
| causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 345.9228295166513 | 342.63034895210126 |
| alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 279.56334658247437 | 271.2724375402088 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 245.66477202810066 | 233.49688207371258 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 64) | 170.3270720653187 | 166.23863845657382 |
| noop | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 400.0041140827554 | 402.11182445396497 |
| causal | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 363.64641830327434 | 375.9288663364792 |
| alibi | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 341.5776139573363 | 335.1160003213424 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 281.1811770268521 | 280.21438270014005 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 16, 4096, 128) | 247.78716118997716 | 245.3269825179633 |
| noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 403.794126680488 | 405.2353919019577 |
| causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 387.079178426863 | 385.1461762057035 |
| alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 309.7847188173431 | 298.0443968374749 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 262.4721750159666 | 250.81679725428586 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 64) | 205.70866004479979 | 202.9620839129557 |
| noop | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 413.380982988662 | 418.40270594263103 |
| causal | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 398.450064800682 | 409.6794973994029 |
| alibi | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 372.26297458194466 | 364.44415106552196 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 293.0818569905912 | 292.85172400643984 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 16, 8192, 128) | 296.46717085592087 | 285.76362010612763 |
| noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 419.3186786037592 | 426.08801580934437 |
| causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 408.1648467766632 | 409.4122254207817 |
| alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 329.24396020457345 | 313.5200995121138 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 274.61257504571876 | 255.7801815432177 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 64) | 232.63806001220684 | 230.03020843492314 |
| noop | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 435.0785891054788 | 440.39101804225345 |
| causal | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 424.86925312752817 | 435.18898057396825 |
| alibi | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 393.000417896268 | 395.11543361225256 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 297.7755459218185 | 300.7208114715287 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 16, 16384, 128) | 331.71570861760534 | 318.07127352552885 |
| noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 424.58602747137405 | 425.84897078470715 |
| causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 422.66607285025725 | 423.5524945535485 |
| alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 344.8625760048626 | 331.6793888458635 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 282.0787281511649 | 263.7895634445868 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 64) | 252.7301927385177 | 245.41844170037427 |
| noop | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 437.0658069164588 | 442.9101960063628 |
| causal | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 433.13788271434646 | 452.3873572709863 |
| alibi | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 404.0959191546953 | 396.7077863894884 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 300.45502211883206 | 301.3439134717943 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 16, 32768, 128) | 344.11003202413934 | 330.8897663350314 |
| noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 298.4364205341705 | 291.6793556507056 |
| causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 187.6382133139633 | 191.05409897308772 |
| alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 156.55822078636112 | 154.178925976516 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 173.47765221825162 | 169.30862508068464 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 64) | 74.5885345035243 | 74.52689061607104 |
| noop | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 323.12233826013045 | 328.53889207933514 |
| causal | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 236.75872140126316 | 235.8378325547398 |
| alibi | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 227.17836523816675 | 226.75357076139966 |
| sliding_window | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 224.07209453308036 | 224.07209453308036 |
| document_mask | torch.bfloat16 | (4, 16, 1024, 4, 1024, 128) | 122.85572156047981 | 121.11642183704716 |
| noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 361.3123326658092 | 360.71014086458337 |
| causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 281.5287983927017 | 281.94301754758345 |
| alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 232.7456696285686 | 226.50976826432776 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 221.5612361744038 | 214.96188822837055 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 64) | 121.38311528944315 | 120.85441868178513 |
| noop | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 380.2579019244734 | 389.2520157863988 |
| causal | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 316.95230660496924 | 317.87597790618906 |
| alibi | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 301.07968126657323 | 298.02424098422983 |
| sliding_window | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 267.2240756921594 | 267.16353549228154 |
| document_mask | torch.bfloat16 | (4, 16, 2048, 4, 2048, 128) | 189.82761622494257 | 186.736450261963 |
| noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 389.88665375406805 | 387.9125133037077 |
| causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 348.70619958684887 | 346.6750499749774 |
| alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 280.5472989906087 | 271.22300822012187 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 250.02397620165968 | 241.22532776331445 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 64) | 171.67817496107645 | 166.95679280483972 |
| noop | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 412.626880230807 | 417.60238657950777 |
| causal | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 374.8829313933945 | 389.4448546468815 |
| alibi | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 353.20410434172436 | 345.7072490717473 |
| sliding_window | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 292.51045924209586 | 291.66621022138287 |
| document_mask | torch.bfloat16 | (4, 16, 4096, 4, 4096, 128) | 251.6264062063495 | 248.45110052911542 |
| noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 404.0155784550126 | 401.90546837237514 |
| causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 384.4389015599863 | 386.9684324594344 |
| alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 313.3731284132225 | 298.17074251037894 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 264.19199737284265 | 252.8982463999916 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 64) | 207.03696315185684 | 202.86697323136772 |
| noop | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 428.2436763312506 | 433.45005568619536 |
| causal | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 411.8516531869893 | 428.2753623461049 |
| alibi | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 384.9095037182509 | 372.90888743000744 |
| sliding_window | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 303.2438915629836 | 302.05095952914337 |
| document_mask | torch.bfloat16 | (4, 16, 8192, 4, 8192, 128) | 301.8689122735564 | 285.0363190513223 |
| noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 423.13592231504805 | 420.3991500185611 |
| causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 407.44527331585493 | 408.5064370765247 |
| alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 330.50050996167414 | 316.8763979925965 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 274.6833786307413 | 259.86098862141324 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64) | 232.24019584158367 | 226.52040268160232 |
| noop | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 444.4596314237808 | 455.99558915752266 |
| causal | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 437.4245561244369 | 455.98275147271966 |
| alibi | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 397.3350686877605 | 397.88875599028063 |
| sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 308.53809114394545 | 307.1359822042007 |
| document_mask | torch.bfloat16 | (4, 16, 16384, 4, 16384, 128) | 331.32379843423774 | 316.85293191675646 |
| noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 422.4622274366379 | 425.0407156418684 |
| causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 420.9547052783101 | 430.33779243510276 |
| alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 345.50265346504085 | 332.094855328957 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 280.81715528243365 | 264.6543640282054 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 64) | 252.25635200421783 | 245.46235499490305 |
| noop | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 452.5524207341139 | 461.7512032176736 |
| causal | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 445.2316469907137 | 464.4523799578466 |
| alibi | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 416.87264016717023 | 409.17124592157046 |
| sliding_window | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 309.42579489389846 | 307.9734464665731 |
| document_mask | torch.bfloat16 | (4, 16, 32768, 4, 32768, 128) | 350.50782004300623 | 330.98959545427294 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157767
Approved by: https://github.com/Skylion007
2025-08-25 22:51:00 +00:00
Angel Li
3a4140bf8e [FlexAttention] fixing learnable bias assertion error in inductor (#161170)
Users encountered unexpected behaviour when using FlexAttention with learnable biases, including assertion errors (#157677)

We traced the root cause to the registration of subgraph buffers—this caused inconsistencies in the naming and ultimately incorrect retrieval later on. This problem only arose if the model was compiled as a whole (ie using @torch.compile) since only then would there be naming conflicts.

In this PR, we register the buffers with the base graph to solve this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170
Approved by: https://github.com/drisspg
2025-08-23 06:24:22 +00:00
drisspg
a6bc296207 [FlexAttention] Update the guard semantics for divisibility (#159884)
We don't add guards unless we know (and another guard has ensured this) that this is a safe optimization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159884
Approved by: https://github.com/Chillee
2025-08-06 23:12:44 +00:00
PyTorch MergeBot
3e5e094615 Revert "Fix large_tensor_test skipping cpu (#158617)"
This reverts commit debc0591b8.

Reverted https://github.com/pytorch/pytorch/pull/158617 on behalf of https://github.com/ZainRizvi due to Sorry but this seems to be breaking trunk. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/16631113381/job/47062415099) [HUD commit link](debc0591b8) ([comment](https://github.com/pytorch/pytorch/pull/158617#issuecomment-3138387762))
2025-07-31 02:57:22 +00:00
drisspg
debc0591b8 Fix large_tensor_test skipping cpu (#158617)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158617
Approved by: https://github.com/BoyuanFeng
2025-07-30 18:48:07 +00:00
NikhilAPatel
21a95bdf7c [Inductor] [Triton] Enabling TMA for flex-attention for supported device types (#157822)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157822
Approved by: https://github.com/drisspg
ghstack dependencies: #159123
2025-07-25 23:45:26 +00:00
drisspg
a00cd8cf25 Add a way to disable compile for debugging flex-attention (#158534)
Finally got around to doing this, this flag lets us do:

```Python

#!/usr/bin/env python3
"""
FlexAttention Debug: Using breakpoints and unwrap
"""

import torch
import torch.nn.attention.flex_attention as fa

unwrap = torch._C._functorch.get_unwrapped

def score_mod(score, batch, head, q_idx, kv_idx):
    # Set breakpoint here to debug
    breakpoint()

    # In debugger, unwrap to see actual tensor values:
    # >>> actual_score = unwrap(unwrap(unwrap(unwrap(score))))
    # >>> actual_batch = unwrap(batch)
    # >>> actual_head = unwrap(head)
    # >>> actual_q_idx = unwrap(q_idx)
    # >>> actual_kv_idx = unwrap(kv_idx)
    # >>> print(actual_score)
    # >>> print(f"q_idx: {actual_q_idx}, kv_idx: {actual_kv_idx}")

    return torch.where(q_idx >= kv_idx, score, torch.tensor(float('-inf')))

def main():
    # Enable debug mode
    fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True

    # Small example
    B, H, S, D = 1, 2, 4, 8
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)

    # Run - will hit breakpoint
    output = fa.flex_attention(q, k, v, score_mod=score_mod)

    # Disable debug mode
    fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False

if __name__ == "__main__":
    main()

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158534
Approved by: https://github.com/Chillee, https://github.com/zou3519
2025-07-18 05:33:45 +00:00
drisspg
8c928372b3 Make Q Indices optional (#157997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157997
Approved by: https://github.com/BoyuanFeng, https://github.com/Chillee
2025-07-12 00:16:20 +00:00
drisspg
987314aa96 Split batch-num-heads grid dim between y and z (#157745)
for #157018

doesn't totally fix the problem but should help alot

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157745
Approved by: https://github.com/Chillee
2025-07-08 05:17:43 +00:00
Xuehai Pan
17687eb792 [BE][4/6] fix typos in test/ (test/inductor/) (#157638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157638
Approved by: https://github.com/yewentao256, https://github.com/jansel
2025-07-06 06:34:25 +00:00
Xuehai Pan
f5e6e52f25 [BE][PYFMT] migrate PYFMT for test/inductor/ to ruff format (#148186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148186
Approved by: https://github.com/jansel
2025-06-24 11:12:11 +00:00
Brian Hirsh
ccc6279b40 flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo (#151719)
This is enough to get @XilunWu 's stack in a state where his flex_attention DTensor implementations worked E2E for me. It also required these changes on the DTensor side, to properly add a DTensor rule for flex backward: P1789852198

There are two problems:

(1) in the normal dispatcher, we have a precedence ordering between modes and subclasses. Modes are dispatched to first, but modes are allowed to return NotImplemented, giving subclasses a chance to run.

This normally happens automatically in `FakeTensorMode.__torch_dispatch__` and `FunctionalTensorMode.__torch_dispatch__`. However, since HOPs implement these two modes themselves, HOPs do not get this benefit. For now, I ended up hardcoding this `NotImplemented` logic directly into the functional/fake rules for flex attention.

Having to do this for every HOP seems a bit painful. If we could plumb every HOP through `Fake[|Functional]TensorMode.__torch_dispatch__` then we would get this support. Another option could be to just assume that most HOP <> mode implementations want the same treatment by default, and hardcode this `NotImplemented` logic into `torch/_ops.py`. I'm not sure if we'd need a way for the HOP to opt out of this though.

(2) We were hardcoding a call to flex attention's fake implementation in dynamo to run fake prop. This is technically wrong for subclasses, because it doesn't give subclasses the chance to interpose on the op and desugar it before fake prop runs. I tweaked dynamo's logic to call the op, and let the dispatcher handle invoking the fake implementation.

**Testing** Xilun is adding some DTensor tests in his PR that will end up testing this logic. If folks would prefer, though, I can try to add a test that uses another subclass instead that is maybe more basic.

This is the tlparse that his DTensor test gnerated for me: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/0196c1d3-a9a2-46ea-a46d-aa21618aa060/custom/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151719
Approved by: https://github.com/ydwu4

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-06-18 07:02:04 +00:00
drisspg
80703ca332 [FlexAttention] Allow dispatch to SAC for flex (#150080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150080
Approved by: https://github.com/zou3519
2025-06-05 04:34:27 +00:00