mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[FlexAttention] Fix weird generate stride call in flex decode (#147435)
# Summary Seems like we had a redundant tuple unpack and that doesn't appear to be supported in new triton Fixes https://github.com/pytorch/pytorch/issues/147373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147435 Approved by: https://github.com/BoyuanFeng
This commit is contained in:
parent
77dbd28535
commit
303ad1916f
|
|
@ -137,7 +137,7 @@ flex_decoding_template = TritonTemplate(
|
|||
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
||||
|
||||
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
||||
stride_block_z, stride_block_h, stride_block_row, stride_block_col = {{stride("KV_NUM_BLKS")}}
|
||||
stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}}
|
||||
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
||||
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}}
|
||||
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user