pytorch/torch/testing/_internal/triton_utils.py
Mu-Chu Lee f363fe616d [AOTInductor] Fix autotuning code's codegen (#150522)
Summary:
Codegen used to generate tmp_arg_{index} as temporary args, and index is the position of the caller.
We changed the logic of codegen such that we can reuse previously generated samples, and only delete an arg after it is no longer used. In this case, we need to make {index} unique, since different functions could reuse the same "tmp_arg_{index}" name string while corresponding to different args.

Test Plan: `python test/inductor/test_aot_inductor.py -k test_autotuning_args_reuse`

Differential Revision: D72297084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150522
Approved by: https://github.com/desertfire, https://github.com/22quinn
2025-04-03 00:08:19 +00:00

682 lines
20 KiB
Python

# mypy: ignore-errors
import unittest
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
from torch.utils._triton import has_triton
# Test decorator: the decorated test is skipped unless HAS_CUDA is truthy.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
# Test decorator: the decorated test is skipped unless HAS_GPU is truthy.
requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")
if has_triton():
import triton
from triton import language as tl
# Define here so that multiple tests can take advantage of it
@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Element-wise add: ``out_ptr[i] = in_ptr0[i] + in_ptr1[i]``.

    Each program along axis 0 covers one ``BLOCK_SIZE``-wide run of indices;
    indices at or beyond ``n_elements`` are masked out of every load and
    store.
    """
    # Indices handled by this program instance.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    tl.store(out_ptr + idx, total, mask=in_bounds)
@triton.jit
def sub_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Element-wise subtract: ``out_ptr[i] = in_ptr0[i] - in_ptr1[i]``.

    Each program along axis 0 covers one ``BLOCK_SIZE``-wide run of indices;
    indices at or beyond ``n_elements`` are masked out of every load and
    store.
    """
    # Indices handled by this program instance.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    minuend = tl.load(in_ptr0 + idx, mask=in_bounds)
    subtrahend = tl.load(in_ptr1 + idx, mask=in_bounds)
    difference = minuend - subtrahend
    tl.store(out_ptr + idx, difference, mask=in_bounds)
@triton.jit
def add_kernel_with_optional_param(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    """Add kernel whose second input is read only when ``ARGS_PASSED == "two"``.

    With ``ARGS_PASSED == "two"`` the kernel stores ``in_ptr0[i] + in_ptr1[i]``;
    for any other value it stores ``in_ptr0[i]`` and never touches ``in_ptr1``.
    Indices at or beyond ``n_elements`` are masked out.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    if ARGS_PASSED == "two":
        # Only this branch dereferences in_ptr1.
        rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
        result = lhs + rhs
    else:
        result = lhs
    tl.store(out_ptr + idx, result, mask=in_bounds)
@triton.jit
def add_kernel_with_none_param_and_equal_to_1_arg(
    in_ptr0,
    in_ptr1,  # in_ptr1 could be None
    out_ptr,
    n_elements,
    stride,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    """Strided add kernel whose second input is optional.

    ``in_ptr0`` and ``out_ptr`` are addressed at ``index * stride``; ``in_ptr1``
    is addressed at ``index`` (no stride) and is read only when
    ``ARGS_PASSED == "two"``. Indices at or beyond ``n_elements`` are masked
    out.

    NOTE(review): the "equal_to_1" in the name presumably refers to callers
    passing ``stride == 1`` -- confirm against the tests that use this kernel.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx * stride, mask=in_bounds)
    if ARGS_PASSED == "two":
        # Only this branch dereferences in_ptr1 (which may be None otherwise).
        rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
        result = lhs + rhs
    else:
        result = lhs
    tl.store(out_ptr + idx * stride, result, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Autotuned element-wise add: ``out_ptr[i] = in_ptr0[i] + in_ptr1[i]``.

    ``BLOCK_SIZE`` is supplied by one of the four autotune configs above.
    Indices at or beyond ``n_elements`` are masked out.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    tl.store(out_ptr + idx, total, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
)
@triton.jit
def sub_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Autotuned element-wise subtract: ``out_ptr[i] = in_ptr0[i] - in_ptr1[i]``.

    ``BLOCK_SIZE`` is supplied by one of the four autotune configs above.
    Indices at or beyond ``n_elements`` are masked out.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    minuend = tl.load(in_ptr0 + idx, mask=in_bounds)
    subtrahend = tl.load(in_ptr1 + idx, mask=in_bounds)
    difference = minuend - subtrahend
    tl.store(out_ptr + idx, difference, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned_weird_param_order(
    in_ptr0,
    in_ptr1,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    out_ptr,
):
    """Autotuned element-wise add with ``out_ptr`` declared last.

    out_ptr is after an autotuned param that's declared as tl.constexpr.
    This param ordering can create bugs if not handled correctly.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    tl.store(out_ptr + idx, total, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
        ),
    ],
    key=[],
)
@triton.jit
def add_kernel_2d_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    x_elements,
    y_elements,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    """Autotuned kernel over a 2D launch grid (program ids 0 and 1).

    For each in-range index pair ``(x, y)`` it stores, at
    ``out_ptr[x + x_elements * y]``, the sum of
    ``in_ptr0[x + x_elements * y]`` and ``in_ptr0[y + y_elements * x]``.

    NOTE(review): both loads read ``in_ptr0``; ``in_ptr1`` is never
    dereferenced. Preserved as-is -- confirm this is intentional.
    """
    # Column vector of x indices and its bounds mask.
    x_idx = tl.program_id(0) * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)[:, None]
    x_ok = x_idx < x_elements
    # Row vector of y indices and its bounds mask.
    y_idx = tl.program_id(1) * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)[None, :]
    y_ok = y_idx < y_elements
    first = tl.load(in_ptr0 + (x_idx + (x_elements * y_idx)), x_ok & y_ok)
    second = tl.load(in_ptr0 + (y_idx + (y_elements * x_idx)), x_ok & y_ok)
    summed = first + second
    tl.store(out_ptr + (x_idx + (x_elements * y_idx)), summed, x_ok & y_ok)
def _dummy_early_config_prune(configs, *_, **__):
return configs
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
    warmup=10,
    rep=20,
    prune_configs_by={"early_config_prune": _dummy_early_config_prune},
)
@triton.jit
def add_kernel_autotuned_with_unsupported_args(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Element-wise add whose autotuner also sets warmup/rep/prune_configs_by.

    Computes ``out_ptr[i] = in_ptr0[i] + in_ptr1[i]`` with indices at or
    beyond ``n_elements`` masked out.

    NOTE(review): "unsupported args" presumably refers to the extra
    ``triton.autotune`` keyword arguments above -- confirm against callers.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    tl.store(out_ptr + idx, total, mask=in_bounds)
@triton.jit
def add_kernel_with_scaling(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    scaling_factor,
    BLOCK_SIZE: "tl.constexpr",
):
    """Scaled add: ``out_ptr[i] = (in_ptr0[i] + in_ptr1[i]) * scaling_factor``.

    Indices at or beyond ``n_elements`` are masked out of every load and
    store.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    scaled = (lhs + rhs) * scaling_factor
    tl.store(out_ptr + idx, scaled, mask=in_bounds)
@triton.jit
def add_kernel_with_tma_1d(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    """Add two 1D blocks accessed through experimental descriptors.

    Each program along axis 0 loads a ``[BLOCK_SIZE]`` float32 block from
    both input descriptors at offset ``pid * BLOCK_SIZE``, adds them, and
    stores the sum through ``out_desc_ptr`` at the same offset.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    lhs = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [start],
        [BLOCK_SIZE],
        tl.float32,
    )
    rhs = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [start],
        [BLOCK_SIZE],
        tl.float32,
    )
    total = lhs + rhs
    tl._experimental_descriptor_store(
        out_desc_ptr,
        total,
        [start],
    )
@triton.jit
def add_kernel_with_tma_2d(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    """Add two 2D blocks accessed through experimental descriptors.

    Each program (axes 0 and 1) loads a ``[BLOCK_SIZE_X, BLOCK_SIZE_Y]``
    float32 block from both input descriptors at
    ``[pid_x * BLOCK_SIZE_X, pid_y * BLOCK_SIZE_Y]``, adds them, and stores
    the sum through ``out_desc_ptr`` at the same offsets.
    """
    start_x = tl.program_id(axis=0) * BLOCK_SIZE_X
    start_y = tl.program_id(axis=1) * BLOCK_SIZE_Y
    lhs = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [start_x, start_y],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )
    rhs = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [start_x, start_y],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )
    total = lhs + rhs
    tl._experimental_descriptor_store(
        out_desc_ptr,
        total,
        [start_x, start_y],
    )
@triton.jit
def mul2_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Doubling kernel: ``out_ptr[i] = 2 * in_ptr0[i]``.

    Indices at or beyond ``n_elements`` are masked out of the load and the
    store.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    value = tl.load(in_ptr0 + idx, mask=in_bounds)
    doubled = 2 * value
    tl.store(out_ptr + idx, doubled, mask=in_bounds)
@triton.jit
def mul2_inplace_kernel(
    ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """In-place doubling kernel: ``ptr[i] = 2 * ptr[i]``.

    The same buffer is both read and written. Indices at or beyond
    ``n_elements`` are masked out of the load and the store.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    value = tl.load(ptr + idx, mask=in_bounds)
    doubled = 2 * value
    tl.store(ptr + idx, doubled, mask=in_bounds)
@triton.jit
def zero_negs(x):
    """Keep entries of ``x`` where ``x >= 0`` and replace all others with 0."""
    keep = x >= 0
    return tl.where(keep, x, 0)
@triton.jit
def indirection_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    ACTIVATION: "tl.constexpr",
):
    """Kernel that optionally calls another kernel before copying input to output.

    Depending on ``ACTIVATION`` it first invokes ``mul2_inplace_kernel`` on
    ``in_ptr0`` or ``add_kernel`` (with ``in_ptr0`` as both inputs, writing
    ``out_ptr``); any other value calls neither. It then always copies
    ``in_ptr0[i]`` to ``out_ptr[i]`` for indices below ``n_elements``.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Select the nested kernel, if any, from the ACTIVATION string.
    if ACTIVATION == "mul2_inplace_kernel":
        mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    elif ACTIVATION == "add_kernel":
        add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    # Unconditional copy; this overwrites whatever add_kernel stored.
    copied = tl.load(in_ptr0 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, copied, mask=in_bounds)
@triton.jit
def double_strided_kernel(
    in_ptr,
    out_ptr,
    in_y_stride,
    out_y_stride,
    X_BLOCK_SIZE: "tl.constexpr",
    Y_BLOCK_SIZE: "tl.constexpr",
):
    """Double a 2D tile, reading and writing with separate y strides.

    Each program (axes 0 and 1) handles a ``Y_BLOCK_SIZE x X_BLOCK_SIZE``
    tile. Source addresses use ``in_y_stride`` per y step and destination
    addresses use ``out_y_stride``; the x step is 1 for both. No bounds mask
    is applied to the load or the store.
    """
    x_idx = tl.program_id(axis=0) * X_BLOCK_SIZE + tl.arange(0, X_BLOCK_SIZE)
    y_idx = tl.program_id(axis=1) * Y_BLOCK_SIZE + tl.arange(0, Y_BLOCK_SIZE)
    read_at = y_idx[:, None] * in_y_stride + x_idx[None, :]
    write_at = y_idx[:, None] * out_y_stride + x_idx[None, :]
    tile = tl.load(in_ptr + read_at)
    tl.store(out_ptr + write_at, tile * 2.0)
@triton.jit
def inline_asm_kernel_is_pure_true(
    X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
):
    """Apply an inline-asm ``shf.l.wrap.b32`` element-wise with ``is_pure=True``.

    Loads ``BLOCK`` elements from ``X`` and ``Y``, builds an int32 block
    filled with ``n``, passes all three to the inline assembly, and stores
    the int32 result to ``Z``. No bounds mask is applied.
    """
    lhs = tl.load(X + tl.arange(0, BLOCK))
    rhs = tl.load(Y + tl.arange(0, BLOCK))
    shift = tl.full([BLOCK], n, tl.int32)
    result = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [lhs, rhs, shift],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )
    tl.store(Z + tl.arange(0, BLOCK), result)
@triton.jit
def inline_asm_kernel_is_pure_false(
    X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
):
    """Apply an inline-asm ``shf.l.wrap.b32`` element-wise with ``is_pure=False``.

    Loads ``BLOCK`` elements from ``X`` and ``Y``, builds an int32 block
    filled with ``n``, passes all three to the inline assembly, and stores
    the int32 result to ``Z``. No bounds mask is applied.
    """
    lhs = tl.load(X + tl.arange(0, BLOCK))
    rhs = tl.load(Y + tl.arange(0, BLOCK))
    shift = tl.full([BLOCK], n, tl.int32)
    result = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [lhs, rhs, shift],
        dtype=tl.int32,
        is_pure=False,
        pack=1,
    )
    tl.store(Z + tl.arange(0, BLOCK), result)
@triton.jit
def add_kernel_with_block_ptr(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise add implemented with 1D block pointers.

    Each program along axis 0 builds a ``[BLOCK_SIZE]`` block pointer at
    offset ``pid * BLOCK_SIZE`` into each of ``x_ptr``, ``y_ptr`` and
    ``output_ptr`` (shape ``[n_elements]``, stride 1), loads both inputs,
    and stores their sum. Dimension 0 is boundary-checked on every access.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    lhs = tl.load(x_block, boundary_check=[0])
    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    rhs = tl.load(y_block, boundary_check=[0])
    total = lhs + rhs
    out_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    tl.store(out_block, total, boundary_check=[0])
@triton.jit
def kernel_with_block_ptr_2d(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy kernel implemented with 2D block pointers.

    Both buffers are viewed with shape ``[n_elements, 1]`` and strides
    ``[1, 1]``. Each program along axis 0 loads a ``[BLOCK_SIZE, 1]`` block
    at offsets ``[pid * BLOCK_SIZE, 0]`` from ``x_ptr`` and stores it
    unchanged to ``output_ptr``. Only dimension 0 is boundary-checked.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    src_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[start, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    payload = tl.load(src_block, boundary_check=[0])
    dst_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[start, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    tl.store(dst_block, payload, boundary_check=[0])
from triton.language import load, store
@triton.jit
def add_kernel_with_import(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Element-wise add using the directly imported ``load`` / ``store``.

    Same computation as a plain add kernel, but memory access goes through
    the bare ``load`` and ``store`` names rather than ``tl.load`` /
    ``tl.store``. Indices at or beyond ``n_elements`` are masked out.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = load(in_ptr0 + idx, mask=in_bounds)
    rhs = load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    store(out_ptr + idx, total, mask=in_bounds)
@triton.jit
def cond_op_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Kernel whose operation depends on the program id.

    Program 0 along axis 0 stores ``in_ptr0[i] + in_ptr1[i]``; every other
    program stores ``in_ptr0[i] * in_ptr1[i]``. Indices at or beyond
    ``n_elements`` are masked out.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    # Runtime (not constexpr) branch on the program id.
    if tl.program_id(0) == 0:
        result = lhs + rhs
    else:
        result = lhs * rhs
    tl.store(out_ptr + idx, result, mask=in_bounds)
@triton.jit
def atomic_add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Atomically add ``in_ptr0[i] + in_ptr1[i]`` into ``out_ptr[i]``.

    Unlike a plain add kernel, the result is accumulated onto the existing
    contents of ``out_ptr`` via ``tl.atomic_add``. Indices at or beyond
    ``n_elements`` are masked out.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    increment = lhs + rhs
    tl.atomic_add(out_ptr + idx, increment, mask=in_bounds)
@triton.jit
def add_4_times_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Compute and store ``in_ptr0[i] + in_ptr1[i]`` four times.

    The same sum is stored twice from a ``for`` loop and twice from a
    ``while`` loop, so the kernel contains both loop forms. Indices at or
    beyond ``n_elements`` are masked out.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    # First two stores: counted for-loop.
    for step in range(2):
        total = lhs + rhs
        tl.store(out_ptr + idx, total, mask=in_bounds)
    # Last two stores: while-loop counting down from 2.
    remaining = 2
    while remaining > 0:
        remaining -= 1
        total = lhs + rhs
        tl.store(out_ptr + idx, total, mask=in_bounds)
@triton.jit
def add_kernel_out_of_order_fn2(
    in_ptr0,
    in_ptr1,
    n_elements,
    out_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    """Element-wise add with ``n_elements`` declared before ``out_ptr``.

    Computes ``out_ptr[i] = in_ptr0[i] + in_ptr1[i]``; only the parameter
    order differs from the usual add kernel signature. Indices at or beyond
    ``n_elements`` are masked out.
    """
    first_idx = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first_idx + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    tl.store(out_ptr + idx, total, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 16,
                "BLOCK_SIZE_K": 16,
                "GROUP_SIZE_M": 4,
            },
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8,
            },
            num_stages=4,
            num_warps=4,
        ),
    ],
    key=["M_ptr", "N", "K"],
)
@triton.jit
def strange_config_matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M_ptr,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Simplified tiled matmul whose work depends on the autotune config.

    ``M`` is loaded from ``M_ptr`` at runtime. When it is 0 the behavior
    splits by config: with ``BLOCK_SIZE_M > 32`` the kernel proceeds with
    ``M = 4096``; otherwise it returns immediately. Products are accumulated
    in float32 and stored to ``c_ptr`` as float16.

    Addresses are formed by adding row and column offsets directly (no
    stride arguments), so this is a test fixture rather than a general
    matmul.
    """
    # This is a simplified matmul from Triton tutorial.
    program = tl.program_id(axis=0)
    M = tl.load(M_ptr)
    if M == 0 and BLOCK_SIZE_M > 32:
        # This will run the full matmul if BLOCK_SIZE_M > 32
        M = 4096
    elif M == 0:
        # This directly returns, which will cut short the bad config of 16-block size.
        return

    # Map the linear program id onto a (tile_m, tile_n) pair, grouped along M.
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    programs_per_group = GROUP_SIZE_M * tiles_n
    group = program // programs_per_group
    group_first_m = group * GROUP_SIZE_M
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    tile_m = group_first_m + ((program % programs_per_group) % rows_in_group)
    tile_n = (program % programs_per_group) // rows_in_group

    # Input offsets wrap modulo M / N.
    a_rows = (tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    b_cols = (tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    k_range = tl.arange(0, BLOCK_SIZE_K)
    a_tile_ptrs = a_ptr + (a_rows[:, None] + k_range[None, :])
    b_tile_ptrs = b_ptr + (k_range[:, None] + b_cols[None, :])

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask off the K tail; masked lanes contribute 0.0 to the dot product.
        a_tile = tl.load(
            a_tile_ptrs, mask=k_range[None, :] < K - k_step * BLOCK_SIZE_K, other=0.0
        )
        b_tile = tl.load(
            b_tile_ptrs, mask=k_range[:, None] < K - k_step * BLOCK_SIZE_K, other=0.0
        )
        acc = tl.dot(a_tile, b_tile, acc)
        a_tile_ptrs += BLOCK_SIZE_K
        b_tile_ptrs += BLOCK_SIZE_K

    out_tile = acc.to(tl.float16)
    c_rows = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_cols = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_tile_ptrs = c_ptr + c_rows[:, None] + c_cols[None, :]
    c_in_bounds = (c_rows[:, None] < M) & (c_cols[None, :] < N)
    tl.store(c_tile_ptrs, out_tile, mask=c_in_bounds)