[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support (#155182)

Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.
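
Roughly, the shim picks an implementation at import time. Below is a minimal sketch of the dispatch pattern; the hasattr-based detection is illustrative (the real HAS_NEW_TMA_API flag in triton_helpers may be computed differently), and the branch bodies mirror the shim diff at the end of this commit:

    import triton
    import triton.language as tl

    # Illustrative feature detection; the actual flag computation may differ.
    HAS_NEW_TMA_API = hasattr(tl, "make_tensor_descriptor")

    if HAS_NEW_TMA_API:
        @triton.jit
        def make_tensor_descriptor(base_ptr, global_shape, block_shape, desc_ptr):
            # Triton 3.4: the descriptor is a value returned to the caller;
            # desc_ptr (workspace scratch) is accepted but unused.
            return tl.make_tensor_descriptor(
                base=base_ptr,
                shape=global_shape,
                strides=[global_shape[1], 1],  # contiguous 2D strides
                block_shape=block_shape,
            )
    else:
        @triton.jit
        def make_tensor_descriptor(base_ptr, global_shape, block_shape, desc_ptr):
            # Triton 3.3: build a device-side tensormap in workspace memory
            # and hand the pointer back as the "descriptor".
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=desc_ptr,
                global_address=base_ptr,
                load_size=block_shape,
                global_size=global_shape,
                element_ty=base_ptr.dtype.element_ty,
            )
            return desc_ptr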

First, this refactors the TMA shim to drop args that aren't supported by both Triton 3.3 and Triton 3.4: strides (the Triton 3.3 API doesn't accept non-contiguous inputs anyway, so on Triton 3.4 we just infer contiguous strides) and element_ty (Triton 3.4 doesn't accept this arg, so on Triton 3.3 we just infer it from base_ptr).
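
In caller terms this shrinks the shim's surface from six args to four. A hedged before/after sketch of a call site, with names borrowed from the mm template diff below:

    # before: the caller spelled out strides and element_ty
    triton_helpers.make_tensor_descriptor(
        base_ptr=A,
        global_shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_M, BLOCK_K],
        desc_ptr=a_desc_ptr,
        element_ty=A.dtype.element_ty,
    )

    # after: strides and element_ty are inferred inside the shim,
    # and the descriptor is whatever the shim returns
    a_desc = triton_helpers.make_tensor_descriptor(
        base_ptr=A,
        global_shape=[M, K],
        block_shape=[BLOCK_M, BLOCK_K],
        desc_ptr=a_desc_ptr,
    )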

Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.
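
The resulting call pattern in those templates is the make/fence/load trio sketched below, condensed from the mm.py diff (the desc_ptr workspace and the trailing dtype arg are presumably only meaningful on the Triton 3.3 path):

    a_desc = triton_helpers.make_tensor_descriptor(
        desc_ptr=a_desc_ptr,  # workspace scratch; only used on the old API
        base_ptr=A,
        block_shape=[BLOCK_M, BLOCK_K],
        global_shape=[M, K],
    )
    triton_helpers.tensormap_fenceproxy_acquire(a_desc)
    a = triton_helpers.load_tensor_descriptor(
        a_desc, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
    )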

Differential Revision: [D76318784](https://our.internmc.facebook.com/intern/diff/D76318784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155182
Approved by: https://github.com/drisspg
Author: David Berard
Date: 2025-06-09 16:23:22 -07:00
Committed by: PyTorch MergeBot
parent 8153340d10
commit b07725a951
4 changed files with 58 additions and 69 deletions

flex_attention.py

@@ -399,34 +399,28 @@ compute_flex_attention = r"""
     workspace_base = ws_ptr + TMA_SIZE * 3 * (
         tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
     )
-    desc_q = workspace_base
-    desc_v = workspace_base + TMA_SIZE
-    desc_k = workspace_base + 2 * TMA_SIZE
+    desc_q_ptr = workspace_base
+    desc_v_ptr = workspace_base + TMA_SIZE
+    desc_k_ptr = workspace_base + 2 * TMA_SIZE
-    triton_helpers.make_tensor_descriptor(
+    desc_q = triton_helpers.make_tensor_descriptor(
         base_ptr=Q,
         global_shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
-        strides=[QK_HEAD_DIM, 1],
         block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
-        desc_ptr=desc_q,
-        element_ty=Q.dtype.element_ty,
+        desc_ptr=desc_q_ptr,
     )
-    triton_helpers.make_tensor_descriptor(
+    desc_v = triton_helpers.make_tensor_descriptor(
         base_ptr=V,
         global_shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
-        strides=[V_HEAD_DIM, 1],
         block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
-        desc_ptr=desc_v,
-        element_ty=K.dtype.element_ty,
+        desc_ptr=desc_v_ptr,
     )
-    triton_helpers.make_tensor_descriptor(
+    desc_k = triton_helpers.make_tensor_descriptor(
         base_ptr=K,
         global_shape=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
-        strides=[QK_HEAD_DIM, 1],
         block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
-        desc_ptr=desc_k,
-        element_ty=K.dtype.element_ty,
+        desc_ptr=desc_k_ptr,
     )

mm.py

@@ -265,23 +265,21 @@ persistent_tma_mm_template = TritonTemplate(
     a_desc_ptr = workspace_base
     b_desc_ptr = workspace_base + TMA_SIZE
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    a_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=a_desc_ptr,
-        global_address=A,
-        load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
-        global_size=[M, K] if A_ROW_MAJOR else [K, M],
-        element_ty=A.dtype.element_ty,
+        base_ptr=A,
+        block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
+        global_shape=[M, K] if A_ROW_MAJOR else [K, M],
     )
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    b_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=b_desc_ptr,
-        global_address=B,
-        load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
-        global_size=[K, N] if B_ROW_MAJOR else [N, K],
-        element_ty=B.dtype.element_ty,
+        base_ptr=B,
+        block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
+        global_shape=[K, N] if B_ROW_MAJOR else [N, K],
     )
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    triton_helpers.tensormap_fenceproxy_acquire(a_desc)
+    triton_helpers.tensormap_fenceproxy_acquire(b_desc)
     pid_m = 0
     pid_n = 0
@@ -303,14 +301,14 @@ persistent_tma_mm_template = TritonTemplate(
         rk = ki * BLOCK_K
-        a = tl._experimental_descriptor_load(
-            a_desc_ptr,
+        a = triton_helpers.load_tensor_descriptor(
+            a_desc,
             [rm, rk] if A_ROW_MAJOR else [rk, rm],
             [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
             A.dtype.element_ty,
         )
-        b = tl._experimental_descriptor_load(
-            b_desc_ptr,
+        b = triton_helpers.load_tensor_descriptor(
+            b_desc,
             [rk, rn] if B_ROW_MAJOR else [rn, rk],
             [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
             B.dtype.element_ty,
@@ -416,23 +414,21 @@ device_tma = r"""
     a_desc_ptr = workspace_base
     b_desc_ptr = workspace_base + TMA_SIZE
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    a_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=a_desc_ptr,
-        global_address=A,
-        load_size=[BLOCK_M, BLOCK_K],
-        global_size=[M, K],
-        element_ty=A.dtype.element_ty,
+        base_ptr=A,
+        block_shape=[BLOCK_M, BLOCK_K],
+        global_shape=[M, K],
     )
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    b_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=b_desc_ptr,
-        global_address=B,
-        load_size=[BLOCK_N, BLOCK_K],
-        global_size=[N, K],
-        element_ty=B.dtype.element_ty,
+        base_ptr=B,
+        block_shape=[BLOCK_N, BLOCK_K],
+        global_shape=[N, K],
     )
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    triton_helpers.tensormap_fenceproxy_acquire(a_desc)
+    triton_helpers.tensormap_fenceproxy_acquire(b_desc)
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
@@ -465,11 +461,11 @@ device_tma = r"""
         offs_k = ki * BLOCK_K
-        a = tl._experimental_descriptor_load(
-            a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
+        a = triton_helpers.load_tensor_descriptor(
+            a_desc, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
         )
-        b = tl._experimental_descriptor_load(
-            b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
+        b = triton_helpers.load_tensor_descriptor(
+            b_desc, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
         )
         if USE_FAST_ACCUM:
             accumulator = tl.dot(a, b.T, accumulator)

mm_scaled_grouped.py

@@ -179,22 +179,20 @@ triton_scaled_grouped_mm_source = r"""
     a_desc_ptr = workspace_base
     b_desc_ptr = workspace_base + TMA_SIZE
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    a_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=a_desc_ptr,
-        global_address=a_ptr,
-        load_size=[BLOCK_M, BLOCK_K],
-        global_size=[M, K],
-        element_ty=a_ptr.dtype.element_ty,
+        base_ptr=a_ptr,
+        block_shape=[BLOCK_M, BLOCK_K],
+        global_shape=[M, K],
     )
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    b_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=b_desc_ptr,
-        global_address=b_ptr,
-        load_size=[BLOCK_N, BLOCK_K],
-        global_size=[N * G, K],
-        element_ty=b_ptr.dtype.element_ty,
+        base_ptr=b_ptr,
+        block_shape=[BLOCK_N, BLOCK_K],
+        global_shape=[N * G, K],
     )
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    triton_helpers.tensormap_fenceproxy_acquire(a_desc)
+    triton_helpers.tensormap_fenceproxy_acquire(b_desc)
     M_end_offset = 0
     iterated_tiles = 0
@@ -224,14 +222,14 @@ triton_scaled_grouped_mm_source = r"""
         m_offset = (M_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
         n_offset = (N_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
         for k_offset in range(0, K, BLOCK_K):
-            a = tl._experimental_descriptor_load(
-                a_desc_ptr,
+            a = triton_helpers.load_tensor_descriptor(
+                a_desc,
                 [m_offset, k_offset],
                 [BLOCK_M, BLOCK_K],
                 dtype,
             )
-            b = tl._experimental_descriptor_load(
-                b_desc_ptr,
+            b = triton_helpers.load_tensor_descriptor(
+                b_desc,
                 [n_offset, k_offset],
                 [BLOCK_N, BLOCK_K],
                 dtype,

triton_helpers.py

@@ -743,15 +743,18 @@ if HAS_NEW_TMA_API:
     def make_tensor_descriptor(
         base_ptr,
         global_shape,
-        strides,
         block_shape,
         desc_ptr,
-        element_ty,
     ):
+        # note: this static assert isn't a fundamental limitation,
+        # inferring strides just isn't yet implemented for >2 dimensions.
+        tl.static_assert(
+            len(global_shape) == 2, "Only 2D tensors are supported by current API"
+        )
         return tl.make_tensor_descriptor(
             base=base_ptr,
             shape=global_shape,
-            strides=strides,
+            strides=[global_shape[1], 1],
             block_shape=block_shape,
         )
@@ -784,17 +787,15 @@ else:
     def make_tensor_descriptor(
         base_ptr,
         global_shape,
-        strides,
         block_shape,
         desc_ptr,
-        element_ty,
     ):
         tl.extra.cuda.experimental_device_tensormap_create2d(
             desc_ptr=desc_ptr,
             global_address=base_ptr,
             load_size=block_shape,
             global_size=global_shape,
-            element_ty=element_ty,
+            element_ty=base_ptr.dtype.element_ty,
         )
         return desc_ptr