pytorch/test
Tianren Gao 2fed4fb464 [FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861)
Fixes #159247
Issue 1: Accuracy Problem with Non-Divisible KV Sequences
---------------------------------------------------------

### Background

Paged attention in flex decoding produced inaccurate results when the KV sequence length is not divisible by the block size. For example, with `KV_S = 64` and `block_size = 128`, the output did not match that of standard (non-paged) attention.

### Root Cause
The current paged attention implementation does not apply an upper-bound mask mod when converting the logical mask mod to a physical one. Instead, it defaults to a `noop_mask`, which leaves every position unmasked, including the slots past the end of the real KV sequence, and this causes the accuracy mismatch. Adding an upper-bound mask mod based on the original, actual `kv_len` (64 in this test case) resolves the issue.

### Solution

*   **Applied proper upper-bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor of shape `[B, KV_S]`, which provides the actual KV sequence length of each batch. The function now applies upper-bound checks using these per-batch KV sequence lengths.
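The effect of the missing upper bound can be sketched in plain Python. This is a minimal model, not the code in `_paged_attention.py`: it assumes a FlexAttention-style `mask_mod(b, h, q_idx, kv_idx) -> bool` signature but uses Python ints instead of tensors, `noop_mask` here is a local stand-in, and `add_upper_bound` is a hypothetical helper.

```python
from typing import Callable, Sequence

# A mask_mod decides whether query q_idx may attend to key kv_idx
# (FlexAttention-style signature, modeled with plain ints).
MaskMod = Callable[[int, int, int, int], bool]


def noop_mask(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    # Local stand-in for a no-op mask: every position stays unmasked.
    return True


def add_upper_bound(mask_mod: MaskMod, kv_len: Sequence[int]) -> MaskMod:
    # Hypothetical helper: wrap a mask_mod so that any kv_idx at or beyond
    # the actual KV length of batch b is masked out.
    def bounded(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        return kv_idx < kv_len[b] and mask_mod(b, h, q_idx, kv_idx)

    return bounded


# KV_S = 64 real tokens stored in a single page of block_size = 128 slots.
KV_S, BLOCK_SIZE = 64, 128
kv_len = [KV_S]  # one batch entry

unbounded = sum(noop_mask(0, 0, 0, k) for k in range(BLOCK_SIZE))
bounded_mask = add_upper_bound(noop_mask, kv_len)
bounded = sum(bounded_mask(0, 0, 0, k) for k in range(BLOCK_SIZE))
print(unbounded, bounded)  # 128 64
```

With the no-op mask, all 128 slots of the page are attended to, including the 64 that hold no real keys; with the bound applied, only the 64 valid positions remain.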

### Files Modified
*   `torch/nn/attention/experimental/_paged_attention.py`: Added a `kv_len` tensor parameter to `get_mask_mod` and applied the upper-bound mask in the new mask mod.
*   `test/inductor/test_flex_attention.py`: Updated all related `kv_len` arguments in the tests
*   `test/inductor/test_flex_decoding.py`: Updated all related `kv_len` arguments in the tests

Issue 2: Invalid Memory Access (IMA) in Triton Kernels
------------------------------------------------------

### Background

The Triton kernel for flex attention triggered invalid memory access errors when run under compute sanitizers, particularly with short KV sequences and small batch sizes.

### Root Cause

*   The kernel launches CTAs (Cooperative Thread Arrays) in proportion to the GPU's multiprocessor count (108 in this case, via `SPLIT_KV`)
*   With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values
*   This caused out-of-bounds memory access violations
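A simplified model of the failure: assume each of the `SPLIT_KV` CTAs is handed a contiguous run of KV blocks and reads `kv_indices[indices_idx]` for each block in its run. The partitioning rule and the names `cta_indices` and `out_of_range_ctas` are hypothetical; the real Triton template computes its offsets differently, but the arithmetic shows why a tiny workload leaves most CTAs with no valid index.

```python
import math
from typing import List


def cta_indices(split_id: int, num_kv_blocks: int, split_kv: int) -> List[int]:
    # Hypothetical partitioning: each CTA covers ceil(num_kv_blocks / split_kv)
    # consecutive blocks (at least one), with no bounds check.
    per_split = max(1, math.ceil(num_kv_blocks / split_kv))
    start = split_id * per_split
    return list(range(start, start + per_split))


def out_of_range_ctas(num_kv_blocks: int, split_kv: int) -> List[int]:
    # CTAs whose every indices_idx falls past the end of kv_indices.
    return [
        s
        for s in range(split_kv)
        if all(i >= num_kv_blocks for i in cta_indices(s, num_kv_blocks, split_kv))
    ]


# One KV block (e.g. KV_S = 64 with block_size = 128), 108 CTAs.
bad = out_of_range_ctas(num_kv_blocks=1, split_kv=108)
print(len(bad))  # 107
```

Only CTA 0 has real work; the other 107 would index past the end of `kv_indices`. Python would raise `IndexError` on such a read, whereas an unguarded GPU load silently touches memory outside the tensor, which is what the sanitizer reports.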

### Solution

Implemented boundary checks with early exit:

1.  **Added `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`

    *   Calculates the maximum valid KV index from the actual size of the `kv_indices` tensor and passes it to the Triton template
2.  **Added early exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`

    *   Boundary checks before accessing `kv_indices` in both normal and full blocks
    *   Idle CTAs with invalid `indices_idx` skip computation entirely
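The guard can be sketched with the same simplified model. This is plain Python, not the Jinja/Triton template: `MAX_VALID_KV_IDX` is taken to be the length of the `kv_indices` row (an assumption about how it is derived), and `process_split` is a hypothetical stand-in for one CTA's loop over its blocks.

```python
import math
from typing import List, Sequence


def process_split(
    kv_indices: Sequence[int], split_id: int, split_kv: int, max_valid_kv_idx: int
) -> List[int]:
    # Hypothetical model of one CTA: returns the physical block ids it loads.
    per_split = max(1, math.ceil(max_valid_kv_idx / split_kv))
    start = split_id * per_split
    if start >= max_valid_kv_idx:
        return []  # early exit: idle CTA skips all computation and all loads
    end = min(start + per_split, max_valid_kv_idx)  # clamp to the valid range
    return [kv_indices[i] for i in range(start, end)]


kv_indices = [7, 3, 9]  # logical block -> physical page
MAX_VALID_KV_IDX = len(kv_indices)
loaded = [process_split(kv_indices, s, 108, MAX_VALID_KV_IDX) for s in range(108)]
print(sum(len(x) for x in loaded))  # 3
```

Every one of the 108 simulated CTAs runs without an out-of-range read: three of them load one block each, and the remaining 105 return immediately.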

This prevents invalid memory access while reducing wasted computation on idle thread blocks.

Testing & Validation
--------------------

### Accuracy Tests

*   Added test cases covering KV sequence lengths that are not divisible by the block size
*   Verified output matches standard attention for various sequence length combinations
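The kind of comparison these tests make can be illustrated with a single-query reference attention in plain Python. This is not the test code in `test_flex_decoding.py`: `attention_1q` and `max_err` are hypothetical helpers, the page's unused slots are filled with random stand-ins for leftover data, and the shapes (`KV_S = 64`, `block_size = 128`, head dim 4) are chosen for illustration.

```python
import math
import random
from typing import List, Sequence

Vec = List[float]


def attention_1q(q: Vec, keys: Sequence[Vec], values: Sequence[Vec], mask: Sequence[bool]) -> Vec:
    # Scaled dot-product attention for one query; masked slots get -inf scores.
    scale = 1.0 / math.sqrt(len(q))
    scores = [
        sum(a * b for a, b in zip(q, k)) * scale if m else float("-inf")
        for k, m in zip(keys, mask)
    ]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # exp(-inf) == 0.0
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]


def rand_vec(d: int) -> Vec:
    return [random.uniform(-1.0, 1.0) for _ in range(d)]


random.seed(0)
D, KV_S, BLOCK_SIZE = 4, 64, 128
q = rand_vec(D)
keys = [rand_vec(D) for _ in range(KV_S)]
values = [rand_vec(D) for _ in range(KV_S)]
# The page holds BLOCK_SIZE slots; slots >= KV_S contain leftover data.
page_keys = keys + [rand_vec(D) for _ in range(BLOCK_SIZE - KV_S)]
page_values = values + [rand_vec(D) for _ in range(BLOCK_SIZE - KV_S)]

reference = attention_1q(q, keys, values, [True] * KV_S)
noop = attention_1q(q, page_keys, page_values, [True] * BLOCK_SIZE)
bounded = attention_1q(q, page_keys, page_values, [i < KV_S for i in range(BLOCK_SIZE)])


def max_err(out: Vec) -> float:
    return max(abs(a - b) for a, b in zip(out, reference))


print(f"noop mask error:    {max_err(noop):.3e}")
print(f"bounded mask error: {max_err(bounded):.3e}")
```

The bounded mask assigns zero weight to the 64 padding slots, so its output matches the dense reference; the no-op mask lets those slots into the softmax and the result drifts.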

### Sanitizer Results

`========= COMPUTE-SANITIZER Starting standalone test_max_autotune... Running test_max_autotune on device: cuda max_autotune config: True test_max_autotune completed successfully! Test passed! ========= ERROR SUMMARY: 0 errors`

*   **Before**: More than 13,720 invalid memory access errors reported by the sanitizer
*   **After**: Clean execution with 0 errors

Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861
Approved by: https://github.com/BoyuanFeng
2025-08-30 04:50:23 +00:00
ao/sparsity Enable more nightly tests on s390x (#160893) 2025-08-28 22:20:55 +00:00
autograd
backends/xeon
benchmark_utils [BE][3/6] fix typos in test/ (#157637) 2025-07-17 12:08:33 +00:00
bottleneck_test
compiled_autograd_skips
cpp [TorchScript] ProfilingExecutor - RemoveProfileNodesAndSpecializeTypes None handling (#161538) 2025-08-27 23:12:15 +00:00
cpp_api_parity
cpp_extensions Add new_zeros dtype variant to the shim and as a stable op (#161597) 2025-08-28 13:57:24 +00:00
custom_backend
custom_operator Using the latest torch.library.register_fake API instead of torch.library.impl_abstract (#158839) 2025-07-25 02:37:30 +00:00
distributed [SymmMEM] Move AsyncTP tests to a seperate test class (#161820) 2025-08-30 00:40:40 +00:00
distributions [BE] fix remaining flake8 v7 warnings (#159044) 2025-07-25 02:56:34 +00:00
dynamo Revert "kill allow_complex_guards_as_runtime_asserts (#160198)" 2025-08-28 22:50:37 +00:00
dynamo_expected_failures [OrderedDict] Implement OrderedDict.popitem(last=...) (#155153) 2025-08-27 15:46:40 +00:00
dynamo_skips Move non inductor workflows to Python 3.9 -> 3.10 (#161182) 2025-08-27 02:32:24 +00:00
error_messages
expect [cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282) 2025-08-08 22:22:48 +00:00
export [export] Support complex constant in serde (#161517) 2025-08-29 08:13:21 +00:00
forward_backward_compatibility Fused RMSNorm implementation (#153666) 2025-07-22 22:25:44 +00:00
functorch remove old while_loop_schema_gen test (#161202) 2025-08-22 18:22:29 +00:00
fx Separate provenance tracking to different levels (#160383) 2025-08-15 04:59:35 +00:00
higher_order_ops [invoke_subgraph][inductor] Thread graphsafe rng input states for hops (#160713) 2025-08-21 20:41:29 +00:00
inductor [FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861) 2025-08-30 04:50:23 +00:00
inductor_expected_failures
inductor_skips
jit [cuDNN][TF32] Account for TF32 in test_super_resolution_cuda (#161662) 2025-08-28 08:42:34 +00:00
jit_hooks
lazy [BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556) 2025-07-29 03:26:09 +00:00
mobile [BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556) 2025-07-29 03:26:09 +00:00
nn [cuDNN][convolution] remove redundant conv3d 64bit test (#161177) 2025-08-25 15:01:05 +00:00
onnx [ONNX] Fix lower opset version support in dynamo=True (#161056) 2025-08-23 05:04:36 +00:00
optim Fix SequentialLR deprecate warning about invoke step(epoch) (#149392) 2025-08-29 11:45:11 +00:00
package [Torch Package] Make get names of OrderedImporters support fallback to importers (#155743) 2025-08-06 02:26:10 +00:00
profiler Update pybind11 submodule to 3.0.1 (#160754) 2025-08-27 21:15:01 +00:00
quantization [ROCm] fix numpy version detection and adjust fudge_factors for MI355 (#161429) 2025-08-28 19:32:09 +00:00
scripts [BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556) 2025-07-29 03:26:09 +00:00
strobelight/examples
test_img
torch_np [BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556) 2025-07-29 03:26:09 +00:00
typing More testing of Python arithmetic operators between tensors and scalars (see 157266) (#157632) 2025-07-05 17:48:27 +00:00
xpu Enable _int_mm on Intel GPU (#157769) 2025-08-02 05:16:01 +00:00
_test_bazel.py
allowlist_for_publicAPI.json remove guard_or_x from allowlist_for_publicAPI (#159181) 2025-07-26 01:22:17 +00:00
bench_mps_ops.py [BE] Remove macos-13 guard from bench_mps_ops (#159732) 2025-08-03 20:53:58 +00:00
conftest.py
create_dummy_torchscript_model.py
HowToWriteTestsUsingFileCheck.md
linear.py
load_torchscript_model.py
minioptest_failures_dict.json
mkl_verbose.py
mkldnn_verbose.py
pytest_shard_custom.py
run_doctests.sh
run_test.py Enable more nightly tests on s390x (#160893) 2025-08-28 22:20:55 +00:00
simulate_nccl_errors.py
slow_tests.json Update slow tests (#160870) 2025-08-18 11:53:41 +00:00
test_accelerator.py Add UT for torch.accelerator memory-related API (#155200) 2025-08-08 17:41:22 +00:00
test_ao_sparsity.py
test_appending_byte_serializer.py
test_autocast.py
test_autograd_fallback.py Fix TestAutogradFallback flaky tests under Dynamo: migrate to lib._destroy() (#159443) 2025-07-30 19:30:55 +00:00
test_autograd.py Revert "[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. (#160934)" 2025-08-28 17:56:36 +00:00
test_autoload.py
test_binary_ufuncs.py Revert "handling special case for pow(3) for GPU (#157537)" 2025-08-19 22:57:45 +00:00
test_bundled_images.py
test_bundled_inputs.py
test_ci_sanity_check_fail.py
test_comparison_utils.py
test_compile_benchmark_util.py
test_complex.py
test_content_store.py
test_cpp_api_parity.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_cpp_extensions_aot.py [build] modernize build-frontend: python setup.py develop/install -> [uv ]pip install --no-build-isolation [-e ]. (#156027) 2025-07-09 11:24:27 +00:00
test_cpp_extensions_jit.py Add ScalarType -> shim conversion, add stable::Tensor.scalar_type (#160557) 2025-08-19 22:13:47 +00:00
test_cpp_extensions_mtia_backend.py
test_cpp_extensions_stream_and_event.py
test_cuda_expandable_segments.py
test_cuda_multigpu.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_cuda_nvml_based_avail.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_cuda_primary_ctx.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_cuda_sanitizer.py
test_cuda_trace.py
test_cuda.py Revert "Generalize torch._C._set_allocator_settings to be generic (#156175)" (#161626) 2025-08-27 21:37:14 +00:00
test_custom_ops.py Add utility to get computed kernel in torch.library (#158393) 2025-08-13 21:00:59 +00:00
test_dataloader.py skip XPU for dataloader CPU only unit test (#159811) 2025-08-05 03:44:01 +00:00
test_datapipe.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_decomp.py Fix full_like decomposition to preserve strides (#158898) 2025-07-25 20:21:36 +00:00
test_determination.py
test_dispatch.py
test_dlpack.py [Testing] Add MPS to NATIVE_DEVICES (#153835) 2025-08-05 18:57:35 +00:00
test_dynamic_shapes.py Revert "[dynamic shapes] unbacked-safe slicing (#157944)" 2025-08-22 20:48:46 +00:00
test_expanded_weights.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_extension_utils.py
test_fake_tensor.py [MTIA] Allow users who know what they are doing to ignore all device mismatches in tracing and take a preferred device. (#159931) 2025-08-07 22:37:15 +00:00
test_file_check.py
test_flop_counter.py
test_foreach.py Fix requires_cuda to requires_cuda_and_triton (#160222) 2025-08-10 07:05:52 +00:00
test_function_schema.py
test_functional_autograd_benchmark.py
test_functional_optim.py
test_functionalization_of_rng_ops.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_functionalization.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_futures.py
test_fx_experimental.py [fx] fix split_module with symint (#160093) 2025-08-13 05:50:15 +00:00
test_fx_passes.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_fx_reinplace_pass.py [dynamic shapes] avoid unnecessary slices (#157528) 2025-07-10 06:34:46 +00:00
test_fx.py Extend torch function support to ALL arguments, not just scalar type (but not insides of list) (#145089) 2025-08-07 23:43:53 +00:00
test_hop_infra.py
test_hub.py
test_import_stats.py
test_indexing.py Revert "Fix index_add for int64 input + zerodim index (#161511)" 2025-08-27 15:38:11 +00:00
test_itt.py
test_jit_autocast.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_jit_disabled.py
test_jit_fuser_legacy.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_jit_fuser_te.py Remove tensorexpr tests (#158928) 2025-08-09 02:21:22 +00:00
test_jit_fuser.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_jit_legacy.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_jit_llga_fuser.py
test_jit_profiling.py
test_jit_simple.py
test_jit_string.py
test_jit.py Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)" 2025-08-04 20:37:39 +00:00
test_jiterator.py
test_kernel_launch_checks.py
test_legacy_vmap.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_license.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_linalg.py make einsum produce contiguous inputs in more cases (#161755) 2025-08-29 18:50:46 +00:00
test_logging.py
test_masked.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_maskedtensor.py Fix MaskedTensor to device ignored mask (#151205) 2025-07-21 21:44:49 +00:00
test_matmul_cuda.py [cuBLASLt][FP8] cuBLASLt appears to support float8 rowwise-scaling on H100 (#161305) 2025-08-28 17:04:25 +00:00
test_meta.py [easy][test] Add repeat_interleave opinfo that exercises binary search fusion (#161445) 2025-08-26 12:32:24 +00:00
test_metal.py
test_mkl_verbose.py
test_mkldnn_fusion.py
test_mkldnn_verbose.py
test_mkldnn.py Enable TF32 as fp32 internal precision for matmul/linear/conv (#157520) 2025-07-17 08:57:34 +00:00
test_mobile_optimizer.py
test_model_exports_to_core_aten.py
test_module_tracker.py
test_modules.py
test_monitor.py
test_mps.py [MPS] sparse add unary funcs + add for sparse tensors (#160839) 2025-08-30 01:09:00 +00:00
test_multiprocessing_spawn.py Test multiprocessing spawn timing fix (#160672) 2025-08-15 00:11:55 +00:00
test_multiprocessing.py
test_namedtensor.py
test_namedtuple_return_api.py
test_native_functions.py
test_native_mha.py
test_nestedtensor.py [cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282) 2025-08-08 22:22:48 +00:00
test_nn.py [MPS] Add grid_sampler_3d for MPS (#160541) 2025-08-15 16:19:25 +00:00
test_nnapi.py
test_numa_binding.py Allow parallel start NUMA binding (#161576) 2025-08-28 01:15:58 +00:00
test_numba_integration.py
test_numpy_interop.py Throw invalid_argument instead of RuntimeError when parameters exceed… (#158267) 2025-07-25 23:49:46 +00:00
test_openmp.py
test_openreg.py [OpenReg] Add OSX/Windows Support for OpenReg (#159441) 2025-08-25 08:03:27 +00:00
test_ops_fwd_gradients.py
test_ops_gradients.py
test_ops_jit.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_ops.py Relax unclaimed successes in dtype op tests when running under TEST_WITH_DYNAMO/TEST_WITH_INDUCTOR (#159976) 2025-08-07 02:38:45 +00:00
test_optim.py [muon] Introduce Muon optimizer to PyTorch (#160213) 2025-08-24 08:03:04 +00:00
test_out_dtype_op.py
test_overrides.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_package.py
test_per_overload_api.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_prims.py
test_proxy_tensor.py Revert "[dynamic shapes] unbacked-safe slicing (#157944)" 2025-08-22 20:48:46 +00:00
test_pruning_op.py
test_public_bindings.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_python_dispatch.py Revert "flip the list-as-tuple behavior for short lists (#160794)" 2025-08-21 16:33:30 +00:00
test_pytree.py
test_quantization.py Remove pytorch quant docs since we are moving to torchao (#157766) 2025-07-11 03:21:47 +00:00
test_reductions.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_rename_privateuse1_to_existing_device.py [Device] Add support for PrivateUse1 device type in parse_type function (#157609) 2025-07-17 01:27:44 +00:00
test_scatter_gather_ops.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_schema_check.py [inductor] slow test some Windows UTs. (#160267) 2025-08-10 18:35:42 +00:00
test_segment_reductions.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_serialization.py added class or module info for functions blocked by weight-only load (#159935) 2025-08-12 20:52:25 +00:00
test_set_default_mobile_cpu_allocator.py
test_shape_ops.py
test_show_pickle.py
test_sort_and_select.py Add dtype checks in meta dispatch for various ordering ops (#159556) 2025-08-14 17:06:27 +00:00
test_sparse_csr.py [BE] remove torch deploy - conditionals (#158288) 2025-07-29 17:40:49 +00:00
test_sparse_semi_structured.py [BE] fix remaining flake8 v7 warnings (#159044) 2025-07-25 02:56:34 +00:00
test_sparse.py [MPS] sparse add unary funcs + add for sparse tensors (#160839) 2025-08-30 01:09:00 +00:00
test_spectral_ops.py
test_stateless.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_static_runtime.py
test_subclass.py
test_sympy_utils.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_tensor_creation_ops.py Revert "Use vectorized stores for all dtypes (#161649)" 2025-08-30 03:13:40 +00:00
test_tensorboard.py
test_tensorexpr_pybind.py
test_tensorexpr.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_testing.py [Fix XPU CI][Inductor UT] Fix test cases broken by community. (#160403) 2025-08-19 00:54:51 +00:00
test_throughput_benchmark.py
test_torch.py [BE] Move indexing tests to test_indexing (#160994) 2025-08-21 00:42:55 +00:00
test_transformers_privateuse1.py Refactor and Improve the OpenReg Module (#158090) 2025-07-15 08:10:05 +00:00
test_transformers.py [ROCm] fix numpy version detection and adjust fudge_factors for MI355 (#161429) 2025-08-28 19:32:09 +00:00
test_type_hints.py
test_type_info.py
test_type_promotion.py [BE] Raise ValueError from torch.cat meta func (#158249) 2025-07-20 23:49:18 +00:00
test_typing.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_unary_ufuncs.py [inductor] slow test some Windows UTs. (#160267) 2025-08-10 18:35:42 +00:00
test_utils_config_module.py
test_utils_filelock.py
test_utils.py [docs] Decorator to create a deprecation warning (#155127) 2025-06-25 18:09:04 +00:00
test_view_ops.py unify broadcast_shapes functions and avoid duplicates (#160251) 2025-08-16 00:54:32 +00:00
test_vulkan.py
test_weak.py [BE][2/6] fix typos in test/ (test/test_*.py) (#157636) 2025-07-09 11:02:23 +00:00
test_xnnpack_integration.py
test_xpu.py Generalize support of background thread in pinned allocator (#160505) 2025-08-14 02:22:39 +00:00