pytorch/torch
Daniel Vega-Myhre 881a598a1e [FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (#153357)
Fixes #147336

## Context

NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.

We brought this to the attention of Triton developer @davidberard98, who identified the memory layout of the tensor in HBM as the cause of the non-pipelined loads into SRAM, and hence the slowdown.

To summarize:

In flex attention, the FP8 GEMM `softmax_scores @ V` requires the right operand V to be in a column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined unless the V tensor is already column-major in HBM, leading to substantial performance degradation.
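
For concreteness, here is a minimal PyTorch sketch (shapes and dtype are illustrative) of the layout difference: a row-major V is contiguous along the head dimension, while the column-major layout the `softmax_scores @ V` GEMM wants is contiguous along the sequence dimension, which can be produced by materializing the transpose and viewing it back.

```python
import torch

# Illustrative [seq_kv, head_dim] slice of V (the last two dims of a [B, H, S, D] tensor).
v = torch.randn(1024, 128).to(torch.float8_e4m3fn)
print(v.stride())            # (128, 1): row-major, stride 1 along head_dim

# Materialize the transpose, then view it back: same logical shape, column-major storage.
v_col_major = v.t().contiguous().t()
print(v_col_major.stride())  # (1, 1024): column-major, stride 1 along seq_kv
```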

This is because Triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).

That is, when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we would have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the column-major layout. The load is therefore not a candidate for pipelining with `cp.async`; instead, it moves the data to registers and then performs a series of single-byte stores.
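
A back-of-the-envelope view of that rule (illustrative only; the actual check lives in the Triton pipeliner code linked above, and the 16-element run below is just an example vector width):

```python
CP_ASYNC_MIN_CONTIGUOUS_BYTES = 4   # threshold from the Triton check referenced above
FP8_ITEMSIZE = 1                    # torch.float8_e4m3fn is 1 byte per element

# V row-major in HBM but consumed column-major in SRAM: consecutive destination
# elements come from different rows, so only one element at a time is contiguous.
row_major_src_bytes = 1 * FP8_ITEMSIZE    # 1 < 4  -> no cp.async, load is not pipelined

# V already column-major in HBM: consecutive destination elements are also
# consecutive in the source, so whole runs of elements copy contiguously.
col_major_src_bytes = 16 * FP8_ITEMSIZE   # 16 >= 4 -> cp.async eligible, load can be pipelined
```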

## Fix summary
- To fix this, enforce memory layouts for Q, K, and V in FlexAttention when fp8 is used, so that each tensor resides in HBM in the layout needed for pipelined loads into SRAM ahead of the FP8 GEMMs (see the sketch below).
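
As a rough user-level illustration of the idea (not the PR's internal implementation; the helper below is hypothetical, and the exact layouts the kernels expect are an assumption), one can cast Q, K, and V to fp8 and force V into a column-major layout in its last two dims before calling the compiled FlexAttention path:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def cast_and_layout_qkv(q, k, v, dtype=torch.float8_e4m3fn):
    # Hypothetical helper, for illustration only. Layouts here are an assumption:
    # Q and K row-major, V column-major in its last two dims so the
    # softmax_scores @ V load can be pipelined.
    q = q.to(dtype).contiguous()
    k = k.to(dtype).contiguous()
    v = v.to(dtype).transpose(-2, -1).contiguous().transpose(-2, -1)
    return q, k, v

q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))
q8, k8, v8 = cast_and_layout_qkv(q, k, v)
out = torch.compile(flex_attention)(q8, k8, v8)  # fp8 flex attention path, as in the repro
```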

## Benchmarks
Rerunning the repro, we see the fp8 runtime reduced from ~120% of the bf16 runtime to ~76% of it.

Before fix:

```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
2025-05-15 02:41:38 +00:00