[pytorch][triton] flex attention fwd kernel with TMA loads (#151923) (#152460)

Summary:

Device-side TMA loads for the Q, K, and V tensors in the flex_attention forward kernel.
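
The TMA path is opt-in through kernel_options and falls back to the existing block-pointer loads otherwise. A minimal usage sketch, mirroring the unit test added in this diff (shapes, dtype, and tolerances are illustrative; the flag only has an effect on TMA-capable CUDA devices such as H100):

```
import torch
from torch.nn.attention.flex_attention import flex_attention

# (B, H, S, D) bfloat16 inputs, same shape as the unit test below
def make():
    return torch.randn(1, 1, 256, 128, device="cuda", dtype=torch.bfloat16)

q, k, v = make(), make(), make()

flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)

# Baseline: TMA disabled (the default)
out_ref = flex_compiled(
    q, k, v, kernel_options={"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": False}
)
# Device-side TMA loads for Q/K/V
out_tma = flex_compiled(
    q, k, v, kernel_options={"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": True}
)

torch.testing.assert_close(out_tma, out_ref, atol=2e-1, rtol=2e-1)
```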

Test Plan:
Unit test:
```
buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:flex_attention -- test_tma_with_customer_kernel_options
```
https://www.internalfb.com/intern/testinfra/testrun/14355223891618726

Differential Revision: D71082691

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152460
Approved by: https://github.com/drisspg
Mandar Deshpande 2025-05-15 04:49:26 +00:00 committed by PyTorch MergeBot
parent 756fd80734
commit 3e8bda4ad5
3 changed files with 208 additions and 57 deletions


@@ -45,7 +45,7 @@ from torch.testing._internal.common_device_type import (
skipCPUIf,
skipCUDAIf,
)
from torch.utils._triton import has_triton, has_triton_tma_device
# Use this decorator only when hitting Triton bugs on H100
@@ -3908,6 +3908,34 @@ class GraphModule(torch.nn.Module):
C1, C2, atol=1e-2, rtol=1e-2
), "Warp specialized kernel result differs from reference"
@supported_platform
@skip_on_cpu
@skipCUDAIf(not has_triton_tma_device(), "Requires TMA enabled CUDA device")
def test_tma_with_customer_kernel_options(self):
make_tensor = functools.partial(
torch.ones,
(1, 1, 256, 128),
device="cuda",
dtype=torch.bfloat16,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
kernel_options_1 = {
"BLOCK_M": 128,
"BLOCK_N": 128,
"USE_TMA": False,
}
kernel_options_2 = {"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": True}
flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True)
out_compiled = flex_compile(query, key, value, kernel_options=kernel_options_1)
out_tma_compiled = flex_compile(
query, key, value, kernel_options=kernel_options_2
)
# vanilla compiled vs TMA compiled
torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1)
class TestBlockMask(InductorTestCase):
def setUp(self):


@@ -51,6 +51,7 @@ from ..select_algorithm import (
SymbolicGridFn,
TritonTemplate,
)
from ..utils import get_tma_workspace_arg
log = logging.getLogger(__name__)
@@ -388,6 +389,48 @@ compute_flex_attention = r"""
off_zq = tl.program_id(1) // HQ
off_hq = tl.program_id(1) % HQ
# Setting up the TMA descriptors for Q, K, V
desc_q = None
desc_k = None
desc_v = None
if USE_TMA:
TMA_SIZE = 128
workspace_base = ws_ptr + TMA_SIZE * 3 * (
tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
)
desc_q = workspace_base
desc_v = workspace_base + TMA_SIZE
desc_k = workspace_base + 2 * TMA_SIZE
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=desc_q,
global_address=Q,
load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
element_ty=Q.dtype.element_ty,
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=desc_v,
global_address=V,
load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED],
global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
element_ty=K.dtype.element_ty,
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=desc_k,
global_address=K,
load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
element_ty=K.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k)
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
off_zkv = off_zq % ZKV
@@ -426,16 +469,30 @@ compute_flex_attention = r"""
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
K_block_ptr = None
V_block_ptr = None
Q_block_ptr = None
if not USE_TMA:
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(Q_LEN, QK_HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(q_start * BLOCK_M, 0),
block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
order=(1, 0)
)
if USE_TMA:
q = tl._experimental_descriptor_load( # load in row major
desc_q,
[(q_start * BLOCK_M).to(tl.int32), 0],
[BLOCK_M, QK_HEAD_DIM_ROUNDED],
Q.dtype.element_ty,
)
else:
q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We don't know anything "special" about these blocks, so we need to apply
# both score_mod and mask_mod to it
@@ -444,6 +501,8 @@ compute_flex_attention = r"""
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
if not USE_TMA:
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
@@ -452,6 +511,7 @@ compute_flex_attention = r"""
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
@@ -460,13 +520,17 @@ compute_flex_attention = r"""
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
@@ -482,7 +546,7 @@ compute_flex_attention = r"""
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
if not USE_TMA:
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
@@ -503,9 +567,11 @@ compute_flex_attention = r"""
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
@@ -543,12 +609,15 @@ compute_forward_inner = r"""
@triton.jit
def forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
@@ -567,14 +636,18 @@ def forward_inner(
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
start_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
@@ -585,25 +658,30 @@ def forward_inner(
# to the last block because it's faster a lot.
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
start_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
offs_n = offs_n + offset
if not USE_TMA:
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
return acc, l_i, m_i
@@ -614,11 +692,14 @@ compute_forward_block_mn = r"""
@triton.jit
def forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
start_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
@@ -628,7 +709,18 @@ def forward_block_mn(
# -- load k --
# NB reversed order to since K is transposed
if USE_TMA:
k = tl._experimental_descriptor_load( # load in row major
desc_k,
[start_n.to(tl.int32) , kv_start],
[BLOCK_N, QK_HEAD_DIM_ROUNDED],
MATMUL_PRECISION,
)
else:
k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
if USE_TMA:
k = tl.trans(k)
# -- compute qk ---
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
if not PRESCALE_QK:
@@ -692,6 +784,14 @@ def forward_block_mn(
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
acc = acc * alpha[:, None]
if USE_TMA:
v = tl._experimental_descriptor_load( # load in row major
desc_v,
[kv_start.to(tl.int32) + start_n.to(tl.int32),0],
[BLOCK_N, V_HEAD_DIM_ROUNDED],
MATMUL_PRECISION,
)
else:
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
@@ -1542,12 +1642,29 @@ def flex_attention(
"num_buffers_warp_spec", num_buffers_warp_spec
)
# Disabling TMA by default, only explicit kernel_options supported for now
cur_kernel_options.setdefault("USE_TMA", False)
cur_kernel_options.setdefault("BLOCK_M", BLOCK_M)
cur_kernel_options.setdefault("BLOCK_N", BLOCK_N)
# Blocksparse options
cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
workspace_arg = None
if cur_kernel_options.get("USE_TMA", False):
seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q)
grid = flex_attention_grid(
Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options
)
num_programs = grid[0] * grid[1] * grid[2]
workspace_arg = get_tma_workspace_arg(
num_tma_descriptors=3,
device=query.get_device(),
num_programs=num_programs,
)
error = flex_attention_template.maybe_append_choice(
choices=choices,
input_nodes=[
@@ -1568,6 +1685,7 @@ def flex_attention(
mutated_inputs=[
logsumexp,
],
workspace_arg=workspace_arg,
call_sizes=query.get_size(),
**cur_kernel_options,
)
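
For context on the workspace sizing in the hunk above: the kernel carves the workspace into one 128-byte TMA descriptor per tensor (Q, K, V) per Triton program, indexed off ws_ptr. A rough sketch of the total size the allocation needs to cover, assuming a contiguous per-program layout (the actual allocation is handled by the get_tma_workspace_arg helper, not by this sketch):

```
TMA_DESCRIPTOR_SIZE = 128  # bytes per descriptor; matches TMA_SIZE in the kernel above
NUM_TMA_DESCRIPTORS = 3    # one each for Q, K, V

def expected_workspace_bytes(num_programs):
    # Kernel side: workspace_base = ws_ptr + TMA_SIZE * 3 * program_linear_id,
    # so each program owns a contiguous triple of descriptors.
    return num_programs * NUM_TMA_DESCRIPTORS * TMA_DESCRIPTOR_SIZE
```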


@@ -196,11 +196,12 @@ flex_decoding_template = TritonTemplate(
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
None,
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
@@ -245,11 +246,12 @@ flex_decoding_template = TritonTemplate(
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
None,
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
@@ -551,6 +553,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
"num_buffers_warp_spec", num_buffers_warp_spec
)
# Set default to False
cur_kernel_options.setdefault("USE_TMA", False)
flex_decoding_template.maybe_append_choice(
choices=choices,
input_nodes=[