pytorch/torch/nn
Daniel Vega-Myhre 7e16cb99b6 [FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (#153357)
Fixes #147336

## Context

NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.

After I brought this to the attention of Triton developer @davidberard98, he identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, which in turn caused the slowdown.

To summarize:

In flex attention, when performing the FP8 GEMM `softmax_scores @ V`, the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor is not already column-major in HBM, leading to substantial performance degradation.

This is because Triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).

That is, when loading 4 contiguous bytes from a tensor stored row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the column-major layout. The load is therefore not a candidate for pipelining with `cp.async`; instead, it moves the data to registers and then performs a series of single-byte stores.
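The rule above can be illustrated with a small, simplified model. This is only a sketch of the reasoning, not Triton's actual pipeliner logic: the helper names (`row_major_strides`, `col_major_strides`, `shared_contiguous_bytes`, `can_use_cp_async`), the block shape, and the 4-element vector width are all assumptions chosen for illustration. Only the 4-byte threshold comes from the description above.

```python
# Simplified model (an assumption, not Triton's real implementation) of why a
# row-major fp8 V in HBM cannot be copied asynchronously into a column-major
# SRAM layout, while a column-major V can.

CP_ASYNC_MIN_BYTES = 4  # minimum contiguous bytes, as described in the text


def row_major_strides(rows, cols):
    """Element strides of a (rows, cols) matrix whose rows are contiguous."""
    return (cols, 1)


def col_major_strides(rows, cols):
    """Element strides of a (rows, cols) matrix whose columns are contiguous."""
    return (1, rows)


def shared_contiguous_bytes(src_strides, dst_strides, elem_size, vec_elems):
    """Bytes that stay adjacent in both the source and destination layouts.

    If the stride-1 dimension differs between source and destination, every
    element lands at a non-adjacent destination address, so only one element
    can be copied at a time.
    """
    src_fast_dim = src_strides.index(1)
    dst_fast_dim = dst_strides.index(1)
    run_elems = vec_elems if src_fast_dim == dst_fast_dim else 1
    return run_elems * elem_size


def can_use_cp_async(src_strides, dst_strides, elem_size, vec_elems):
    """True when the copy meets the 4-byte contiguity threshold."""
    run = shared_contiguous_bytes(src_strides, dst_strides, elem_size, vec_elems)
    return run >= CP_ASYNC_MIN_BYTES


V_SHAPE = (128, 64)  # hypothetical (kv_len, head_dim) block
FP8_BYTES = 1        # one fp8 element occupies one byte
VEC_ELEMS = 4        # hypothetical 4-element (4-byte) vector per thread

sram_layout = col_major_strides(*V_SHAPE)  # layout required by the FP8 GEMM
for name, hbm_layout in [
    ("row-major V in HBM", row_major_strides(*V_SHAPE)),
    ("col-major V in HBM", col_major_strides(*V_SHAPE)),
]:
    ok = can_use_cp_async(hbm_layout, sram_layout, FP8_BYTES, VEC_ELEMS)
    print(f"{name}: cp.async eligible = {ok}")
```

In this model the row-major source yields only 1 contiguous byte per copy, which falls below the threshold, whereas the column-major source yields 4.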

## Fix summary
- To fix this, we enforce memory layouts for Q, K, and V in FlexAttention when fp8 is used, so that each tensor already exists in HBM in the layout needed for pipelined loads into SRAM ahead of the FP8 GEMMs.

## Benchmarks
Rerunning the repro, we see that fp8 runtime is reduced from roughly 121% of bf16 runtime to roughly 77% of bf16 runtime.

Before fix:

```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```
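
The relative runtimes can be recomputed from the timings logged above. The snippet below is a minimal sketch for checking that arithmetic; the variable and function names are made up for illustration, and the numbers are copied verbatim from the two benchmark runs.

```python
# Timings in microseconds, copied from the benchmark logs above.
bf16_before_us = 424.87228804347734
fp8_before_us = 515.714000000001
bf16_after_us = 423.4662032967036
fp8_after_us = 326.3694803493453


def relative_runtime_pct(fp8_us, bf16_us):
    """fp8 runtime expressed as a percentage of bf16 runtime."""
    return 100.0 * fp8_us / bf16_us


before_pct = relative_runtime_pct(fp8_before_us, bf16_before_us)
after_pct = relative_runtime_pct(fp8_after_us, bf16_after_us)
fp8_speedup = fp8_before_us / fp8_after_us  # fp8 before vs. fp8 after the fix

print(f"before fix: fp8 is {before_pct:.1f}% of bf16 runtime")
print(f"after fix:  fp8 is {after_pct:.1f}% of bf16 runtime")
print(f"fp8 kernel speedup from the fix: {fp8_speedup:.2f}x")
```

The bf16 timings differ by well under 1% between the two runs, so the change in the ratio comes almost entirely from the fp8 kernel.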

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
2025-05-16 04:56:50 +00:00
..
attention [FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (#153357) 2025-05-16 04:56:50 +00:00
backends
intrinsic
modules [dynamo] Emit warning on global module hooks when calling using output of torch.compile(module) (#152740) 2025-05-14 17:03:59 +00:00
parallel [ddp] propagate use_python_reducer to C++ reducer (#152735) 2025-05-16 01:38:03 +00:00
qat
quantizable
quantized
utils [Easy] Optimize clip_grad param description (#151532) 2025-04-17 16:47:38 +00:00
__init__.py
_reduction.py
common_types.py
cpp.py
functional.py Add pad limit of avg_poolnd and AvgPoolnd (#152680) 2025-05-04 17:25:22 +00:00
functional.pyi.in [BE] Add __all__ to torch/nn/functional.pyi and torch/return_types.pyi (#150729) 2025-05-15 19:01:57 +00:00
grad.py
init.py
parameter.py
parameter.pyi