[FlexAttn] Fix Paged Attention Accuracy via Upper Mask Mod and Prevent Invalid Memory Access (#160861)
Fixes #159247

Issue 1: Accuracy Problem with Non-Divisible KV Sequences
----------------------------------------------------------

### Background
Paged attention in flex decoding produced inaccurate results when the KV sequence length is not divisible by the block size. For example, when `KV_S = 64` and `block_size = 128`, the output did not match standard attention.

### Root Cause
The current paged attention does not apply an upper-bound mask mod when converting the logical mask mod to the physical one. Instead, it uses a `noop_mask` by default, which leaves all values unmasked and causes the accuracy mismatch. Adding an upper-bound mask mod based on the original actual `kv_len` (64 in this test case) resolves the issue.

### Solution
* **Applied proper upper bound masking**: Updated all calls to `convert_logical_block_mask` to pass `kv_len` as a tensor of shape `[B]` holding the actual KV sequence length (`KV_S`) for each batch. The function now applies the upper-bound check using these per-batch lengths.

### Files Modified
* `torch/nn/attention/experimental/_paged_attention.py`: Added a `kv_len` tensor parameter to `get_mask_mod` and applied the upper-bound mask in the new mask mod.
* `test/inductor/test_flex_attention.py`: Updated all related calls in the tests to pass the `kv_len` parameter.
* `test/inductor/test_flex_decoding.py`: Updated all related calls in the tests to pass the `kv_len` parameter.

Issue 2: Invalid Memory Access (IMA) in Triton Kernels
------------------------------------------------------

### Background
The Triton kernel for flex attention hit invalid memory access errors when run under compute sanitizers, particularly with short KV sequences and small batch sizes.

### Root Cause
* The kernel launches CTAs (Cooperative Thread Arrays) proportional to the GPU's multi-processor count (108 via `SPLIT_KV`).
* With small workloads, many CTAs remain idle but still attempt to access `kv_indices` with invalid `indices_idx` values.
* This caused out-of-bounds memory access violations.

### Solution
Implemented boundary checks with early exit:

1. **Added a `MAX_VALID_KV_IDX` parameter** in `torch/_inductor/kernel/flex/flex_decoding.py`
   * Calculate the maximum valid KV index from the actual `kv_indices` tensor size and pass it to the Triton template.
2. **Added early exit logic** in `torch/_inductor/kernel/flex/templates/flex_decode.py.jinja`
   * Boundary checks before accessing `kv_indices` in both normal and full blocks.
   * Idle CTAs with invalid `indices_idx` skip computation entirely.

This prevents invalid memory access while reducing wasted computation on idle thread blocks.

Testing & Validation
--------------------

### Accuracy Tests
* Added comprehensive test cases covering KV sequence lengths not divisible by the block size.
* Verified the output matches standard attention for various sequence length combinations.

### Sanitizer Results

    ========= COMPUTE-SANITIZER
    Starting standalone test_max_autotune...
    Running test_max_autotune on device: cuda
    max_autotune config: True
    test_max_autotune completed successfully!
    Test passed!
    ========= ERROR SUMMARY: 0 errors

**Before**: more than 13720 invalid memory access errors with sanitizers
**After**: clean execution with 0 errors

Both fixes work together to ensure paged attention produces accurate results while running safely without memory access violations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160861
Approved by: https://github.com/BoyuanFeng
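For orientation, here is a minimal usage sketch of the new `kv_len` argument, loosely modeled on the test helpers updated in this PR. The cache sizes, the `causal_mask`, and the per-batch `reserve` loop are illustrative assumptions rather than part of the patch; only the `kv_len=` keyword on `convert_logical_block_mask` and `get_score_mod` is the new API surface.

```python
import torch
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import create_block_mask

device = "cuda"
B, KV_S = 2, 64                # actual KV length; not divisible by a 128-token block
n_pages, page_size = 16, 128   # illustrative paged-cache geometry

# Illustrative setup: build a paged cache and reserve pages for each batch element.
paged_attention = PagedAttention(n_pages, page_size, B, device=device)
for b in range(B):
    paged_attention.reserve(
        torch.tensor(b, device=device), torch.tensor(KV_S, device=device)
    )

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B, 1, KV_S, KV_S, device=device)

# New in this change: pass the per-batch actual KV lengths (shape [B]) so the
# converted mask mod / score mod can mask out positions past the true length.
kv_len_tensor = torch.full((B,), KV_S, device=device, dtype=torch.int64)
converted_block_mask = paged_attention.convert_logical_block_mask(
    block_mask, kv_len=kv_len_tensor
)
converted_score_mod = paged_attention.get_score_mod(None, kv_len=kv_len_tensor)
```

If `kv_len` is omitted, both calls keep their previous behavior, which is exactly the case that produced the accuracy mismatch for `KV_S = 64` with 128-token blocks.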
This commit is contained in:
parent 76f81b56d3
commit 2fed4fb464
test/inductor/test_flex_attention.py
@@ -699,8 +699,13 @@ class TestFlexAttention(InductorTestCase):
         paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
 
         # convert block mask and score mod
-        converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)
-        converted_score_mod = paged_attention.get_score_mod(score_mod)
+        kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64)
+        converted_block_mask = paged_attention.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
+        converted_score_mod = paged_attention.get_score_mod(
+            score_mod, kv_len=kv_len_tensor
+        )
         return k_cache, v_cache, converted_block_mask, converted_score_mod
 
     def run_paged_attention(
@@ -2449,6 +2454,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test_with_paged_attention(
             score_mod, dtype=torch.float16, device=device
         )
+        self.run_test_with_paged_attention(
+            score_mod=score_mod,
+            dtype=torch.bfloat16,
+            KV_S=64,
+            device=device,
+        )
 
     @supported_platform
     @skip("TODO: Figure out why this is erroring")
@@ -5204,7 +5215,12 @@ class TestPagedAttention(InductorTestCase):
         block_mask = create_block_mask(
             causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=device
         )
-        new_block_mask = paged_cache.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full(
+            (max_batch_size,), max_seq_len, device=device, dtype=torch.int64
+        )
+        new_block_mask = paged_cache.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
 
         zeros = [0, 0, 0, 0]
         # Check that the new block mask is correct
@@ -5480,11 +5496,18 @@ class TestPagedAttention(InductorTestCase):
         )
         paged_cache.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
 
-        new_block_mask = paged_cache.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full(
+            (max_batch_size,), max_seq_len, device=device, dtype=torch.int64
+        )
+        new_block_mask = paged_cache.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
 
         compiled_sdpa = torch.compile(
             create_attention(
-                paged_cache.get_score_mod(score_mod), block_mask, enable_gqa=False
+                paged_cache.get_score_mod(score_mod, kv_len=kv_len_tensor),
+                block_mask,
+                enable_gqa=False,
             )
         )
         paged_out = compiled_sdpa(q, k_cache, v_cache, block_mask=new_block_mask)
test/inductor/test_flex_decoding.py
@@ -556,8 +556,13 @@ class TestFlexDecoding(InductorTestCase):
         paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
 
         # convert block mask and score mod
-        converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)
-        converted_score_mod = paged_attention.get_score_mod(score_mod)
+        kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64)
+        converted_block_mask = paged_attention.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
+        converted_score_mod = paged_attention.get_score_mod(
+            score_mod, kv_len=kv_len_tensor
+        )
 
         return k_cache, v_cache, converted_block_mask, converted_score_mod
 
@@ -1548,6 +1553,19 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
         self.run_test(score_mod, device=device)
         self.run_test_with_paged_attention(score_mod, device=device)
+        self.run_test_with_paged_attention(
+            score_mod=score_mod,
+            dtype=torch.bfloat16,
+            Q_B=4,
+            Q_H=1,
+            Q_S=1,
+            QK_D=16,
+            KV_B=4,
+            KV_H=1,
+            KV_S=64,
+            V_D=16,
+            device=device,
+        )
 
     @supported_platform
     @patch.object(torch._inductor.config, "max_autotune", True)
@@ -2016,11 +2034,18 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         input_pos = torch.tensor(prefill_length, device=device, dtype=torch.int32).view(
             max_batch_size, 1
         )
-        new_block_mask = paged_cache.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full(
+            (max_batch_size,), max_seq_len, device=device, dtype=torch.int64
+        )
+        new_block_mask = paged_cache.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
         new_block_mask.seq_lengths = (1, new_block_mask.seq_lengths[1])
         compiled_sdpa = torch.compile(
             create_attention(
-                paged_cache.get_score_mod(score_mod), new_block_mask, enable_gqa=False
+                paged_cache.get_score_mod(score_mod, kv_len=kv_len_tensor),
+                new_block_mask,
+                enable_gqa=False,
             )
         )
         paged_out = compiled_sdpa(
@@ -792,7 +792,7 @@ class CppFlexAttentionTemplate(CppTemplate):
             return ""
 
         if start_offset == -1:
-            start_offset = getattr(self, len_attr)
+            start_offset = self.len_score_other
 
         length = getattr(self, len_attr)
         for i in range(length):
@@ -995,9 +995,9 @@ class CppFlexAttentionTemplate(CppTemplate):
             value=value,
             kv_num_blocks=self.input_nodes[3],
             kv_indices=self.input_nodes[4],
-            full_kv_num_blocks=self.input_nodes[5]
-            if not self.no_full_kv_block
-            else None,
+            full_kv_num_blocks=(
+                self.input_nodes[5] if not self.no_full_kv_block else None
+            ),
             full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None,
             score_mod_other_buffers=self.score_mod_other_buffers,
             mask_mod_other_buffers=self.mask_mod_other_buffers,
torch/_inductor/kernel/flex/flex_decoding.py
@@ -354,6 +354,13 @@ def create_flex_decoding_kernel(*args, **kwargs):
             **cur_kernel_options,
         )
 
+    filtered_score_mod_buffers = [
+        buf for buf in score_mod_other_buffers if not isinstance(buf, sympy.Symbol)
+    ]
+    filtered_mask_mod_buffers = [
+        buf for buf in mask_mod_other_buffers if not isinstance(buf, sympy.Symbol)
+    ]
+
     inputs_for_flex_decoding = (
         [
             query,
@@ -366,8 +373,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
             full_kv_num_blocks,
             full_kv_indices,
         ]
-        + list(score_mod_other_buffers)
-        + list(mask_mod_other_buffers)
+        + filtered_score_mod_buffers
+        + filtered_mask_mod_buffers
     )
 
     input_gen_fns = {
torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
@@ -120,7 +120,8 @@
     # Offset the kv_indices tensor by the correct batch and head
    kv_indices = KV_IDX + sparse_idx_hz_offset
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
-    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+    MAX_KV_IDX = {{size("KV_IDX", -1)}}
+    indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
    # first kv block we're loading
@@ -156,7 +157,7 @@
        # Assign full block in a reverse order for off_t. Prioritize the last CTA.
        block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
        block_n_end = block_n_start + TILE_KV_MULTIPLE
-        indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+        indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
 
@@ -220,4 +221,4 @@
 
    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
    acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
+    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
torch/nn/attention/experimental/_paged_attention.py
@@ -198,6 +198,7 @@ class PagedAttention:
         self,
         block_mask: BlockMask,
         batch_idx: Optional[torch.Tensor] = None,
+        kv_len: Optional[torch.Tensor] = None,
     ) -> BlockMask:
         """
         Converts a logical block mask by mapping its logical kv indices to the corresponding
@@ -210,6 +211,8 @@ class PagedAttention:
                 batch dimension. This provides flexibility to convert a
                 block mask with smaller batch size than the page table;
                 shape :math:`(B)`.
+            kv_len (Optional[Tensor]): actual KV sequence length for upper bound check;
+                shape :math:`(B,)` to handle multiple batches.
         """
         B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape
 
@@ -261,7 +264,7 @@ class PagedAttention:
             .to(torch.int32)
         )
 
-        new_mask_mod = self.get_mask_mod(block_mask.mask_mod)
+        new_mask_mod = self.get_mask_mod(block_mask.mask_mod, kv_len)
 
         seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
         return BlockMask.from_kv_blocks(
@@ -275,7 +278,9 @@ class PagedAttention:
         )
 
     def get_mask_mod(
-        self, mask_mod: Optional[_mask_mod_signature]
+        self,
+        mask_mod: Optional[_mask_mod_signature],
+        kv_len: Optional[torch.Tensor] = None,
     ) -> _mask_mod_signature:
         """
         Converts a mask_mod based on mapping from the physical block index to the logical
@@ -283,6 +288,7 @@ class PagedAttention:
 
         Args:
             mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
+            kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check.
         """
         if mask_mod is None:
             mask_mod = noop_mask
@@ -297,14 +303,21 @@ class PagedAttention:
             physical_kv_offset = physical_kv_idx % self.page_size
             logical_block_idx = self.physical_to_logical[b, physical_kv_block]
             logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
-            return torch.where(
-                logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
-            )
+            live_block = logical_block_idx >= 0
+            within_upper_bound = (
+                logical_kv_idx < kv_len[b] if kv_len is not None else True
+            )
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
+
+            return torch.where(is_valid, mask_mod(b, h, q_idx, logical_kv_idx), False)
 
         return new_mask_mod
 
     def get_score_mod(
-        self, score_mod: Optional[_score_mod_signature]
+        self,
+        score_mod: Optional[_score_mod_signature],
+        kv_len: Optional[torch.Tensor] = None,
     ) -> _score_mod_signature:
         """
         Converts a score_mod based on mapping from the physical block index to the logical
@@ -312,6 +325,8 @@ class PagedAttention:
 
         Args:
            score_mod (_score_mod_signature): score_mod based on the logical block index.
+            kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check.
+
        """
        if score_mod is None:
            score_mod = _identity
@@ -327,8 +342,15 @@ class PagedAttention:
             physical_kv_offset = physical_kv_idx % self.page_size
             logical_block_idx = self.physical_to_logical[b, physical_kv_block]
             logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
+            live_block = logical_block_idx >= 0
+            within_upper_bound = (
+                logical_kv_idx < kv_len[b] if kv_len is not None else True
+            )
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
+
             return torch.where(
-                logical_block_idx >= 0,
+                is_valid,
                 score_mod(score, b, h, q_idx, logical_kv_idx),
                 float("-inf"),
             )