mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Device-side TMA for the flex_attention forward kernel (Q, K, V tensors).

Test Plan: Unit test:
```
buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:flex_attention -- test_tma_with_customer_kernel_options
```
https://www.internalfb.com/intern/testinfra/testrun/14355223891618726

Differential Revision: D71082691
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152460
Approved by: https://github.com/drisspg
This commit is contained in:
parent 756fd80734
commit 3e8bda4ad5
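A minimal usage sketch (not part of the diff): with this change, device-side TMA for the Q/K/V loads is opt-in through explicit `kernel_options`, mirroring the unit test added in the diff below. The shapes, block sizes, and tolerances here are illustrative only, and `"USE_TMA": True` assumes a CUDA device where `has_triton_tma_device()` reports TMA support.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative inputs, shaped like the ones in the new unit test.
q = torch.ones(1, 1, 256, 128, device="cuda", dtype=torch.bfloat16)
k = torch.ones(1, 1, 256, 128, device="cuda", dtype=torch.bfloat16)
v = torch.ones(1, 1, 256, 128, device="cuda", dtype=torch.bfloat16)

flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)

# Baseline path: block-pointer loads (USE_TMA stays at its default of False).
out_ref = flex_compiled(
    q, k, v, kernel_options={"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": False}
)

# TMA path: Q/K/V are loaded through device-side TMA descriptors.
out_tma = flex_compiled(
    q, k, v, kernel_options={"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": True}
)

torch.testing.assert_close(out_tma, out_ref, atol=2e-1, rtol=2e-1)
```

TMA stays disabled unless the flag is passed explicitly: both the flex_attention and flex_decoding lowerings call `cur_kernel_options.setdefault("USE_TMA", False)`, so existing call sites are unaffected.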
@@ -45,7 +45,7 @@ from torch.testing._internal.common_device_type import (
     skipCPUIf,
     skipCUDAIf,
 )
-from torch.utils._triton import has_triton
+from torch.utils._triton import has_triton, has_triton_tma_device
 
 
 # Use this decorator only when hitting Triton bugs on H100
@@ -3908,6 +3908,34 @@ class GraphModule(torch.nn.Module):
                 C1, C2, atol=1e-2, rtol=1e-2
             ), "Warp specialized kernel result differs from reference"
 
+    @supported_platform
+    @skip_on_cpu
+    @skipCUDAIf(not has_triton_tma_device(), "Requires TMA enabled CUDA device")
+    def test_tma_with_customer_kernel_options(self):
+        make_tensor = functools.partial(
+            torch.ones,
+            (1, 1, 256, 128),
+            device="cuda",
+            dtype=torch.bfloat16,
+        )
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+        kernel_options_1 = {
+            "BLOCK_M": 128,
+            "BLOCK_N": 128,
+            "USE_TMA": False,
+        }
+        kernel_options_2 = {"BLOCK_M": 128, "BLOCK_N": 128, "USE_TMA": True}
+
+        flex_compile = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+        out_compiled = flex_compile(query, key, value, kernel_options=kernel_options_1)
+        out_tma_compiled = flex_compile(
+            query, key, value, kernel_options=kernel_options_2
+        )
+
+        # vanilla compiled vs TMA compiled
+        torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1)
+
 
 class TestBlockMask(InductorTestCase):
     def setUp(self):
@@ -51,6 +51,7 @@ from ..select_algorithm import (
     SymbolicGridFn,
     TritonTemplate,
 )
+from ..utils import get_tma_workspace_arg
 
 
 log = logging.getLogger(__name__)
@@ -388,6 +389,48 @@ compute_flex_attention = r"""
     off_zq = tl.program_id(1) // HQ
     off_hq = tl.program_id(1) % HQ
 
+
+    # Setting up the TMA descriptors for Q, K, V
+    desc_q = None
+    desc_k = None
+    desc_v = None
+    if USE_TMA:
+        TMA_SIZE = 128
+        workspace_base = ws_ptr + TMA_SIZE * 3 * (
+            tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
+        )
+        desc_q = workspace_base
+        desc_v = workspace_base + TMA_SIZE
+        desc_k = workspace_base + 2 * TMA_SIZE
+
+        triton.language.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=desc_q,
+            global_address=Q,
+            load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
+            global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
+            element_ty=Q.dtype.element_ty,
+        )
+        triton.language.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=desc_v,
+            global_address=V,
+            load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+            global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
+            element_ty=K.dtype.element_ty,
+        )
+
+        triton.language.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=desc_k,
+            global_address=K,
+            load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+            global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
+            element_ty=K.dtype.element_ty,
+        )
+
+
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k)
+
+
     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
     # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
     off_zkv = off_zq % ZKV
@@ -426,16 +469,30 @@ compute_flex_attention = r"""
     sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
     sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
     sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
-    Q_block_ptr = tl.make_block_ptr(
-        base=Q,
-        shape=(Q_LEN, QK_HEAD_DIM),
-        strides=(stride_qm, stride_qk),
-        offsets=(q_start * BLOCK_M, 0),
-        block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
-        order=(1, 0)
-    )
-    q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    K_block_ptr = None
+    V_block_ptr = None
+    Q_block_ptr = None
+
+    if not USE_TMA:
+        Q_block_ptr = tl.make_block_ptr(
+            base=Q ,
+            shape=(Q_LEN, QK_HEAD_DIM),
+            strides=(stride_qm, stride_qk),
+            offsets=(q_start * BLOCK_M, 0),
+            block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
+            order=(1, 0)
+        )
+
+    if USE_TMA:
+        q = tl._experimental_descriptor_load( # load in row major
+            desc_q,
+            [(q_start * BLOCK_M).to(tl.int32), 0],
+            [BLOCK_M, QK_HEAD_DIM_ROUNDED],
+            Q.dtype.element_ty,
+        )
+    else:
+        q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
 
     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # We don't know anything "special" about these blocks, so we need to apply
     # both score_mod and mask_mod to it
@@ -444,6 +501,8 @@ compute_flex_attention = r"""
     kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
     block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
 
+
+    if not USE_TMA:
         K_block_ptr = tl.make_block_ptr(
             base=K,
             shape=(QK_HEAD_DIM, KV_LEN),
@@ -452,6 +511,7 @@ compute_flex_attention = r"""
             block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
             order=(0, 1)
         )
+
         V_block_ptr = tl.make_block_ptr(
             base=V,
             shape=(KV_LEN, V_HEAD_DIM),
@@ -460,13 +520,17 @@ compute_flex_attention = r"""
             block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
             order=(1, 0)
         )
+
     offs_n = kv_start + tl.arange(0, BLOCK_N)
 
+
     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+        q, K_block_ptr, V_block_ptr,
+        desc_k, desc_v, Q_LEN, KV_LEN,
         acc, l_i, m_i,
         off_zq, off_hq, offs_m[:, None], offs_n[None, :],
+        kv_start,
         kv_indices, kv_num_blocks,
         0, block_n_end,
         MATMUL_PRECISION,
@@ -482,7 +546,7 @@ compute_flex_attention = r"""
         kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
         kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
         block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
-
+        if not USE_TMA:
             K_block_ptr = tl.make_block_ptr(
                 base=K,
                 shape=(QK_HEAD_DIM, KV_LEN),
@@ -503,9 +567,11 @@ compute_flex_attention = r"""
 
         acc, l_i, m_i = forward_inner(
             {{gen_argdefs()}},
-            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+            q, K_block_ptr, V_block_ptr,
+            desc_k, desc_v, Q_LEN, KV_LEN,
             acc, l_i, m_i,
             off_zq, off_hq, offs_m[:, None], offs_n[None, :],
+            kv_start,
             kv_indices, kv_num_blocks,
             0, block_n_end,
             MATMUL_PRECISION,
@@ -543,12 +609,15 @@ compute_forward_inner = r"""
 @triton.jit
 def forward_inner(
     {{gen_argdefs()}},
-    q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+    q, K_block_ptr, V_block_ptr,
+    desc_k, desc_v, Q_LEN, KV_LEN,
     # accumulated values
     acc, l_i, m_i,
     # Offsets used as inputs to score_mod & mask_mod
     # of size [BLOCK_M, BLOCK_N] or scalar.
     off_z, off_h, offs_m, offs_n,
+    # Offsets needed for TMA loads
+    kv_start,
     # blocksparse data
     kv_indices, kv_num_blocks,
     # start kv and end kv block
@@ -567,14 +636,18 @@ def forward_inner(
 
     # loop over k, v and update accumulator until block_n_end
     for start_n in range(block_n_start, block_n_end):
+        # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                {{gen_argdefs()}},
-               q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+               q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
+               # Offsets needed for TMA loads
+               kv_start,
+               start_n,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )
@@ -585,25 +658,30 @@ def forward_inner(
            # to the last block because it's faster a lot.
            acc, l_i, m_i = forward_block_mn(
                {{gen_argdefs()}},
-               q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+               q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
+               # Offsets needed for TMA loads
+               kv_start,
+               start_n,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
 
-        # update pointers
 
        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )
 
-        V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
-        K_block_ptr = tl.advance(K_block_ptr, (0, offset))
 
        offs_n = offs_n + offset
+        if not USE_TMA:
+            K_block_ptr = tl.advance(K_block_ptr, (0, offset))
+            V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
+
 
    return acc, l_i, m_i
@@ -614,11 +692,14 @@ compute_forward_block_mn = r"""
 @triton.jit
 def forward_block_mn(
     {{gen_argdefs()}},
-    q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+    q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
     # accumulated values
     acc, l_i, m_i,
     # Offsets
     off_z, off_h, offs_m, offs_n,
+    # Offsets needed for TMA loads
+    kv_start,
+    start_n,
     MATMUL_PRECISION, RCP_LN2,
     IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
@@ -628,7 +709,18 @@ def forward_block_mn(
 
     # -- load k --
     # NB reversed order to since K is transposed
+    if USE_TMA:
+        k = tl._experimental_descriptor_load( # load in row major
+            desc_k,
+            [start_n.to(tl.int32) , kv_start],
+            [BLOCK_N, QK_HEAD_DIM_ROUNDED],
+            MATMUL_PRECISION,
+        )
+    else:
         k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
+
+    if USE_TMA:
+        k = tl.trans(k)
     # -- compute qk ---
     qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
     if not PRESCALE_QK:
@@ -692,6 +784,14 @@ def forward_block_mn(
     l_i = l_i * alpha + tl.sum(p, 1)
     # # -- scale and update acc --
     acc = acc * alpha[:, None]
+    if USE_TMA:
+        v = tl._experimental_descriptor_load( # load in row major
+            desc_v,
+            [kv_start.to(tl.int32) + start_n.to(tl.int32),0],
+            [BLOCK_N, V_HEAD_DIM_ROUNDED],
+            MATMUL_PRECISION,
+        )
+    else:
         v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
 
@@ -1542,12 +1642,29 @@ def flex_attention(
                 "num_buffers_warp_spec", num_buffers_warp_spec
             )
 
+        # Disabling TMA by default, only explicit kernel_options supported for now
+        cur_kernel_options.setdefault("USE_TMA", False)
+
         cur_kernel_options.setdefault("BLOCK_M", BLOCK_M)
         cur_kernel_options.setdefault("BLOCK_N", BLOCK_N)
         # Blocksparse options
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
+        workspace_arg = None
+        if cur_kernel_options.get("USE_TMA", False):
+            seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q)
+
+            grid = flex_attention_grid(
+                Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options
+            )
+
+            num_programs = grid[0] * grid[1] * grid[2]
+            workspace_arg = get_tma_workspace_arg(
+                num_tma_descriptors=3,
+                device=query.get_device(),
+                num_programs=num_programs,
+            )
         error = flex_attention_template.maybe_append_choice(
             choices=choices,
             input_nodes=[
@@ -1568,6 +1685,7 @@ def flex_attention(
             mutated_inputs=[
                 logsumexp,
             ],
+            workspace_arg=workspace_arg,
             call_sizes=query.get_size(),
             **cur_kernel_options,
         )
@@ -196,11 +196,12 @@ flex_decoding_template = TritonTemplate(
 
         acc, l_i, m_i = forward_inner(
             {{gen_argdefs()}},
-            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+            q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
             # accumulatd values
             acc, l_i, m_i,
             #offsets
             off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
+            None,
             #block sparse data
             kv_indices, kv_num_blocks,
             block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
@@ -245,11 +246,12 @@ flex_decoding_template = TritonTemplate(
 
         acc, l_i, m_i = forward_inner(
             {{gen_argdefs()}},
-            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
+            q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
             # accumulatd values
             acc, l_i, m_i,
             #offsets
             off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
+            None,
             #block sparse data
             kv_indices, kv_num_blocks,
             block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
@@ -551,6 +553,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
         "num_buffers_warp_spec", num_buffers_warp_spec
     )
 
+    # Set default to False
+    cur_kernel_options.setdefault("USE_TMA", False)
+
    flex_decoding_template.maybe_append_choice(
        choices=choices,
        input_nodes=[