[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support (#155182)
Follow-up to #154858. Triton 3.4 will provide a different API for TMA than Triton 3.3 does; the TMA shim in triton_helpers dispatches to whichever API the installed Triton provides.

First, this refactors the TMA shim to drop the args that can't be supported uniformly from Triton 3.2 through Triton 3.4: strides (the Triton 3.2 API doesn't accept non-contiguous inputs, so on Triton 3.4 we simply infer contiguous strides) and element_ty (the Triton 3.4 API doesn't take this arg, so on Triton 3.2 we infer it from base_ptr).

Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.

Differential Revision: [D76318784](https://our.internmc.facebook.com/intern/diff/D76318784)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155182
Approved by: https://github.com/drisspg
This commit is contained in:
parent 8153340d10
commit b07725a951
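For orientation, here is a condensed sketch of the shim after this refactor, pieced together from the triton_helpers hunks below. It is a sketch, not the verbatim PyTorch source: in particular, the hasattr probe assigned to HAS_NEW_TMA_API is an assumption standing in for however triton_helpers actually detects the Triton 3.4 API.

# Hedged sketch (not verbatim source): the two branches of the TMA shim.
import triton
import triton.language as tl

HAS_NEW_TMA_API = hasattr(tl, "make_tensor_descriptor")  # assumed probe

if HAS_NEW_TMA_API:

    @triton.jit
    def make_tensor_descriptor(base_ptr, global_shape, block_shape, desc_ptr):
        # Triton 3.4 path: desc_ptr is accepted but unused; contiguous
        # strides are inferred from global_shape (hence the 2D restriction).
        tl.static_assert(len(global_shape) == 2, "Only 2D tensors are supported")
        return tl.make_tensor_descriptor(
            base=base_ptr,
            shape=global_shape,
            strides=[global_shape[1], 1],
            block_shape=block_shape,
        )

else:

    @triton.jit
    def make_tensor_descriptor(base_ptr, global_shape, block_shape, desc_ptr):
        # Pre-3.4 path: element_ty is inferred from base_ptr, and the
        # descriptor is written into caller-provided workspace at desc_ptr.
        tl.extra.cuda.experimental_device_tensormap_create2d(
            desc_ptr=desc_ptr,
            global_address=base_ptr,
            load_size=block_shape,
            global_size=global_shape,
            element_ty=base_ptr.dtype.element_ty,
        )
        return desc_ptr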
@@ -399,34 +399,28 @@ compute_flex_attention = r"""
     workspace_base = ws_ptr + TMA_SIZE * 3 * (
         tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
     )
-    desc_q = workspace_base
-    desc_v = workspace_base + TMA_SIZE
-    desc_k = workspace_base + 2 * TMA_SIZE
+    desc_q_ptr = workspace_base
+    desc_v_ptr = workspace_base + TMA_SIZE
+    desc_k_ptr = workspace_base + 2 * TMA_SIZE

-    triton_helpers.make_tensor_descriptor(
+    desc_q = triton_helpers.make_tensor_descriptor(
         base_ptr=Q,
         global_shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
-        strides=[QK_HEAD_DIM, 1],
         block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
-        desc_ptr=desc_q,
-        element_ty=Q.dtype.element_ty,
+        desc_ptr=desc_q_ptr,
     )
-    triton_helpers.make_tensor_descriptor(
+    desc_v = triton_helpers.make_tensor_descriptor(
         base_ptr=V,
         global_shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
-        strides=[V_HEAD_DIM, 1],
         block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
-        desc_ptr=desc_v,
-        element_ty=K.dtype.element_ty,
+        desc_ptr=desc_v_ptr,
     )

-    triton_helpers.make_tensor_descriptor(
+    desc_k = triton_helpers.make_tensor_descriptor(
         base_ptr=K,
         global_shape=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
-        strides=[QK_HEAD_DIM, 1],
         block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
-        desc_ptr=desc_k,
-        element_ty=K.dtype.element_ty,
+        desc_ptr=desc_k_ptr,
     )

@@ -265,23 +265,21 @@ persistent_tma_mm_template = TritonTemplate(
     a_desc_ptr = workspace_base
     b_desc_ptr = workspace_base + TMA_SIZE

-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    a_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=a_desc_ptr,
-        global_address=A,
-        load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
-        global_size=[M, K] if A_ROW_MAJOR else [K, M],
-        element_ty=A.dtype.element_ty,
+        base_ptr=A,
+        block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
+        global_shape=[M, K] if A_ROW_MAJOR else [K, M],
     )
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    b_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=b_desc_ptr,
-        global_address=B,
-        load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
-        global_size=[K, N] if B_ROW_MAJOR else [N, K],
-        element_ty=B.dtype.element_ty,
+        base_ptr=B,
+        block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
+        global_shape=[K, N] if B_ROW_MAJOR else [N, K],
     )

-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    triton_helpers.tensormap_fenceproxy_acquire(a_desc)
+    triton_helpers.tensormap_fenceproxy_acquire(b_desc)

     pid_m = 0
     pid_n = 0
@@ -303,14 +301,14 @@ persistent_tma_mm_template = TritonTemplate(

         rk = ki * BLOCK_K

-        a = tl._experimental_descriptor_load(
-            a_desc_ptr,
+        a = triton_helpers.load_tensor_descriptor(
+            a_desc,
             [rm, rk] if A_ROW_MAJOR else [rk, rm],
             [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
             A.dtype.element_ty,
         )
-        b = tl._experimental_descriptor_load(
-            b_desc_ptr,
+        b = triton_helpers.load_tensor_descriptor(
+            b_desc,
             [rk, rn] if B_ROW_MAJOR else [rn, rk],
             [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
             B.dtype.element_ty,
@@ -416,23 +414,21 @@ device_tma = r"""
     a_desc_ptr = workspace_base
     b_desc_ptr = workspace_base + TMA_SIZE

-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    a_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=a_desc_ptr,
-        global_address=A,
-        load_size=[BLOCK_M, BLOCK_K],
-        global_size=[M, K],
-        element_ty=A.dtype.element_ty,
+        base_ptr=A,
+        block_shape=[BLOCK_M, BLOCK_K],
+        global_shape=[M, K],
     )
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    b_desc = triton_helpers.make_tensor_descriptor(
         desc_ptr=b_desc_ptr,
-        global_address=B,
-        load_size=[BLOCK_N, BLOCK_K],
-        global_size=[N, K],
-        element_ty=B.dtype.element_ty,
+        base_ptr=B,
+        block_shape=[BLOCK_N, BLOCK_K],
+        global_shape=[N, K],
     )

-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    triton_helpers.tensormap_fenceproxy_acquire(a_desc)
+    triton_helpers.tensormap_fenceproxy_acquire(b_desc)

     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
@@ -465,11 +461,11 @@ device_tma = r"""

         offs_k = ki * BLOCK_K

-        a = tl._experimental_descriptor_load(
-            a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
+        a = triton_helpers.load_tensor_descriptor(
+            a_desc, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
         )
-        b = tl._experimental_descriptor_load(
-            b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
+        b = triton_helpers.load_tensor_descriptor(
+            b_desc, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
         )
         if USE_FAST_ACCUM:
             accumulator = tl.dot(a, b.T, accumulator)
@@ -179,22 +179,20 @@ triton_scaled_grouped_mm_source = r"""
     a_desc_ptr = workspace_base
     b_desc_ptr = workspace_base + TMA_SIZE

-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    desc_a = triton_helpers.make_tensor_descriptor(
         desc_ptr=a_desc_ptr,
-        global_address=a_ptr,
-        load_size=[BLOCK_M, BLOCK_K],
-        global_size=[M, K],
-        element_ty=a_ptr.dtype.element_ty,
+        base_ptr=a_ptr,
+        block_shape=[BLOCK_M, BLOCK_K],
+        global_shape=[M, K],
     )
-    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+    desc_b = triton_helpers.make_tensor_descriptor(
         desc_ptr=b_desc_ptr,
-        global_address=b_ptr,
-        load_size=[BLOCK_N, BLOCK_K],
-        global_size=[N * G, K],
-        element_ty=b_ptr.dtype.element_ty,
+        base_ptr=b_ptr,
+        block_shape=[BLOCK_N, BLOCK_K],
+        global_shape=[N * G, K],
     )
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    triton_helpers.tensormap_fenceproxy_acquire(desc_a)
+    triton_helpers.tensormap_fenceproxy_acquire(desc_b)

     M_end_offset = 0
     iterated_tiles = 0
@@ -224,14 +222,14 @@ triton_scaled_grouped_mm_source = r"""
             m_offset = (M_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
             n_offset = (N_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
             for k_offset in range(0, K, BLOCK_K):
-                a = tl._experimental_descriptor_load(
-                    a_desc_ptr,
+                a = triton_helpers.load_tensor_descriptor(
+                    desc_a,
                     [m_offset, k_offset],
                     [BLOCK_M, BLOCK_K],
                     dtype,
                 )
-                b = tl._experimental_descriptor_load(
-                    b_desc_ptr,
+                b = triton_helpers.load_tensor_descriptor(
+                    desc_b,
                     [n_offset, k_offset],
                     [BLOCK_N, BLOCK_K],
                     dtype,
@@ -743,15 +743,18 @@ if HAS_NEW_TMA_API:
     def make_tensor_descriptor(
         base_ptr,
         global_shape,
-        strides,
         block_shape,
         desc_ptr,
-        element_ty,
     ):
+        # note: this static assert isn't a fundamental limitation,
+        # inferring strides just isn't yet implemented for >2 dimensions.
+        tl.static_assert(
+            len(global_shape) == 2, "Only 2D tensors are supported by current API"
+        )
         return tl.make_tensor_descriptor(
             base=base_ptr,
             shape=global_shape,
-            strides=strides,
+            strides=[global_shape[1], 1],
             block_shape=block_shape,
         )

@@ -784,17 +787,15 @@ else:
     def make_tensor_descriptor(
         base_ptr,
         global_shape,
-        strides,
         block_shape,
         desc_ptr,
-        element_ty,
     ):
         tl.extra.cuda.experimental_device_tensormap_create2d(
             desc_ptr=desc_ptr,
             global_address=base_ptr,
             load_size=block_shape,
             global_size=global_shape,
-            element_ty=element_ty,
+            element_ty=base_ptr.dtype.element_ty,
         )

         return desc_ptr
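For a caller-side view, this is what the updated templates now do, condensed from the mm.py device-TMA hunks above. A sketch, not verbatim template code: A, M, K, BLOCK_M, BLOCK_K, offs_am, offs_k, and workspace_base are kernel-template names taken from that diff.

# Condensed usage inside a kernel template (sketch, not verbatim):
a_desc_ptr = workspace_base                 # slot in the TMA workspace
a_desc = triton_helpers.make_tensor_descriptor(
    desc_ptr=a_desc_ptr,                    # consumed only on the pre-3.4 path
    base_ptr=A,
    block_shape=[BLOCK_M, BLOCK_K],
    global_shape=[M, K],                    # strides inferred as contiguous
)
triton_helpers.tensormap_fenceproxy_acquire(a_desc)   # publish before use
a = triton_helpers.load_tensor_descriptor(
    a_desc, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
)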