[FlexAttention] Fix weird generate stride call in flex decode (#147435)

# Summary
It looks like we had a redundant tuple unpack: the strides of `KV_NUM_BLKS` were unpacked into four variables, including an unused `stride_block_col`, and that pattern no longer appears to be supported in newer Triton.

Fixes https://github.com/pytorch/pytorch/issues/147373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147435
Approved by: https://github.com/BoyuanFeng
Author: drisspg, 2025-02-18 18:57:54 -08:00 (committed by PyTorch MergeBot)
Parent: 77dbd28535
Commit: 303ad1916f

@@ -137,7 +137,7 @@ flex_decoding_template = TritonTemplate(
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
-stride_block_z, stride_block_h, stride_block_row, stride_block_col = {{stride("KV_NUM_BLKS")}}
+stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}}
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}}
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
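
For context, here is a minimal sketch of why the removed line breaks, assuming (as the change implies) that `KV_NUM_BLKS` is a 3D tensor, so `{{stride("KV_NUM_BLKS")}}` renders to a three-element stride tuple in the generated kernel source. The shapes and stride values below are illustrative, not taken from the real kernel:

```python
# Illustrative sketch only: the names and stride values here are assumptions,
# not the actual generated kernel source. In the template, {{stride("KV_NUM_BLKS")}}
# is substituted with the stride tuple of the KV_NUM_BLKS tensor.
kv_num_blks_strides = (16, 4, 1)  # assumed strides of a 3D [B, H, ROWS] tensor

# Old template line: four names for a three-element tuple. Plain Python raises
# "ValueError: not enough values to unpack (expected 4, got 3)", and per the
# summary above, newer Triton likewise rejects the mismatched unpack when it
# compiles the generated kernel.
# stride_block_z, stride_block_h, stride_block_row, stride_block_col = kv_num_blks_strides

# Fixed template line: unpack exactly the strides that exist and drop the
# unused column stride.
stride_block_z, stride_block_h, stride_block_row = kv_num_blks_strides
```

The neighboring `KV_IDX` line keeps all four names, presumably because that tensor has four dimensions and its stride tuple already matches, which is why only the `KV_NUM_BLKS` unpack needed fixing.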