Fixes #159247

Issue 1: Accuracy Problem with Non-Divisible KV Sequences
----------------------------------------------------------

### Background

Paged attention in flex decoding produced inaccurate results when the KV sequence length is not divisible by the block size. For example, when `KV_S = 64` and `block_size = 128`, the output did not match standard attention accuracy.

### Root Cause

The current paged attention does not apply an upper-bound mask mod when converting the logical mask mod to a physical one. Instead, it uses a noop_mask by default, which leaves all values unmasked and causes the accuracy mismatch. Adding an upper-bound mask mod based on the actual original kv_len (64 in this test case) resolves the issue (see the illustrative sketches at the end of this description).

### Solution

* **Applied proper upper-bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor with shape `[B, KV_S]`, providing the actual batched KV sequence lengths. The function now correctly applies upper-bound checks using the actual KV sequence length for each batch.

### Files Modified

* `torch/nn/attention/experimental/_paged_attention.py`: Added a `kv_len` tensor parameter to `get_mask_mod` and applied the upper-bound mask in the new mask mod.
* `test/inductor/test_flex_attention.py`: Fixed all related `kv_len` parameter calls in the tests.
* `test/inductor/test_flex_decoding.py`: Fixed all related `kv_len` parameter calls in the tests.

Issue 2: Invalid Memory Access (IMA) in Triton Kernels
------------------------------------------------------

### Background

The Triton kernel for flex attention hit invalid memory access errors when run under compute-sanitizer, particularly with short KV sequences and small batch sizes.

### Root Cause

* The kernel launches CTAs (Cooperative Thread Arrays) proportional to the GPU's multiprocessor count (108, via `SPLIT_KV`)
* With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values
* This caused out-of-bounds memory access violations

### Solution

Implemented boundary checks with early exit (see the illustrative sketches at the end of this description):

1. **Added a `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`
   * Calculate the maximum valid KV index from the actual `kv_indices` tensor size and pass it to the Triton template
2. **Added early-exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`
   * Boundary checks before accessing `kv_indices` in both normal and full blocks
   * Idle CTAs with invalid `indices_idx` skip computation entirely

This prevents invalid memory access while reducing wasted computation on idle thread blocks.

Testing & Validation
--------------------

### Accuracy Tests

* Added comprehensive test cases covering KV sequences not divisible by block sizes
* Verified output matches standard attention for various sequence length combinations

### Sanitizer Results

```
========= COMPUTE-SANITIZER
Starting standalone test_max_autotune...
Running test_max_autotune on device: cuda
max_autotune config: True
test_max_autotune completed successfully!
Test passed!
========= ERROR SUMMARY: 0 errors
```

**Before**: More than 13720 invalid memory access errors under the sanitizer
**After**: Clean execution with 0 errors

Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861
Approved by: https://github.com/BoyuanFeng
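Illustrative Sketches
---------------------

For illustration, here is a minimal sketch of the Issue 1 idea: wrapping an existing mask mod so that any `kv_idx` at or beyond the actual KV length of its batch is masked out, instead of falling back to a noop mask. The helper name `with_upper_bound` and the `[B]`-shaped `actual_kv_len` tensor are simplifying assumptions for this example, not the actual `get_mask_mod`/`convert_logical_block_mask` implementation.

```python
import torch

def with_upper_bound(mask_mod, actual_kv_len):
    # actual_kv_len: int tensor of per-batch KV lengths, shape [B] (simplified
    # here; the PR threads the lengths through convert_logical_block_mask).
    def bounded_mask_mod(b, h, q_idx, kv_idx):
        # Keep the original mask AND require kv_idx to lie within the real length.
        return mask_mod(b, h, q_idx, kv_idx) & (kv_idx < actual_kv_len[b])
    return bounded_mask_mod

def noop_mask(b, h, q_idx, kv_idx):
    # Default mask: every position is unmasked (this is what caused the mismatch).
    return q_idx >= 0

# KV_S = 64, but the paged layout rounds the physical KV up to a 128-wide block.
actual_kv_len = torch.tensor([64])
mask_mod = with_upper_bound(noop_mask, actual_kv_len)
b = torch.tensor(0)
print(mask_mod(b, 0, torch.tensor(0), torch.tensor(63)))   # tensor(True)
print(mask_mod(b, 0, torch.tensor(0), torch.tensor(100)))  # tensor(False): past kv_len
```

And a sketch of the Issue 2 guard: a Triton kernel in which each program derives its `indices_idx` and returns early once it exceeds `MAX_VALID_KV_IDX`, so idle CTAs never read `kv_indices` out of bounds. The kernel name, signature, and launch below are illustrative stand-ins, not the actual `flex_decode.py.jinja` template.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def load_kv_block_indices(kv_indices_ptr, out_ptr, MAX_VALID_KV_IDX: tl.constexpr):
    # One program (CTA) per KV split; the grid is sized from SPLIT_KV, so with
    # small workloads many programs have nothing valid to read.
    indices_idx = tl.program_id(0)
    # Early exit: idle programs past the last valid index skip computation
    # entirely instead of accessing kv_indices out of bounds.
    if indices_idx > MAX_VALID_KV_IDX:
        return
    block_idx = tl.load(kv_indices_ptr + indices_idx)
    tl.store(out_ptr + indices_idx, block_idx)

# Hypothetical launch: a 108-program grid (one per SM) over only 4 valid indices.
kv_indices = torch.arange(4, device="cuda", dtype=torch.int32)
out = torch.zeros(108, device="cuda", dtype=torch.int32)
load_kv_block_indices[(108,)](kv_indices, out, MAX_VALID_KV_IDX=kv_indices.numel() - 1)
```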