[ROCm] Incorporate ROCm triton specific tuning parameters (#148437)

Splitting https://github.com/pytorch/pytorch/pull/147315 into two PRs. This PR adds general support for the kpack and waves_per_eu Triton kernel args on the AMD backend; more detail is in the PR linked above.

A follow-up PR will update the configs used by ROCm, but this requires https://github.com/pytorch/pytorch/pull/147452 to land first.
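
For context, these are the same knobs that hand-written Triton kernels expose on ROCm through their autotune configs. A minimal sketch, with illustrative block sizes and values rather than the configs this PR ships:

```python
import triton

# Illustrative ROCm-flavored config: matrix_instr_nonkdim picks the MFMA
# instruction shape, waves_per_eu hints wavefront scheduling per execution
# unit, and kpack sets the k-dimension packing factor.
rocm_config = triton.Config(
    {
        "BLOCK_M": 128,
        "BLOCK_N": 128,
        "BLOCK_K": 64,
        "matrix_instr_nonkdim": 16,
        "waves_per_eu": 2,
        "kpack": 2,
    },
    num_warps=8,
    num_stages=2,
)
```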

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148437
Approved by: https://github.com/eellison, https://github.com/jansel
commit 8059ead823 (parent a3b77d434a)
Jack Taylor, 2025-03-07 18:09:42 +00:00, committed by PyTorch MergeBot
5 changed files with 30 additions and 3 deletions


@@ -643,6 +643,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
num_stages: int,
num_warps: int,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit
kpack: int = 0, # ROCm specific gemm parameter
workspace_arg: Optional[WorkspaceArg] = None,
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
@@ -652,6 +654,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
self.num_stages = num_stages
self.num_warps = num_warps
self.matrix_instr_nonkdim = matrix_instr_nonkdim
self.waves_per_eu = waves_per_eu
self.kpack = kpack
self.workspace_arg = workspace_arg
def make_run_fn(


@@ -1449,6 +1449,10 @@ def flex_attention(
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
# ROCm specific considerations
if torch.version.hip:
kernel_options["kpack"] = 2
# Note, we don't need to pass in the captured buffers explicitly
# because they're implicitly added by the score_mod function
# We do need to explicitly pass it in for autotuning though.
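
For illustration, kernel_options here is the same dict a caller can seed through the public flex_attention API. A hypothetical usage sketch (shapes and option values arbitrary); note that on HIP builds the lowering above sets kpack=2 itself:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Caller-supplied options land in the same kernel_options dict mutated
# above; on a HIP build the lowering then adds kpack=2 on top.
compiled = torch.compile(flex_attention)
out = compiled(q, k, v, kernel_options={"BLOCK_M": 64, "BLOCK_N": 64})
```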


@@ -89,7 +89,7 @@ def filtered_configs(
),
min_block_size_k,
)
used = OrderedSet[tuple[int, int, int, int, int, int]]()
used = OrderedSet[tuple[int, ...]]()
for block_m, block_n, block_k, num_stages, num_warps in configs:
# shrink configs for small sizes
block_m = max(min(int(block_m * scale), m), min_block_size)
@@ -102,6 +102,7 @@ def filtered_configs(
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
if torch.version.hip:
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
@@ -109,6 +110,7 @@ def filtered_configs(
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
block_m,
block_n,
@@ -116,6 +118,7 @@ def filtered_configs(
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
@@ -127,6 +130,7 @@ def filtered_configs(
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
)
)
yield triton_config(
@@ -136,6 +140,7 @@ def filtered_configs(
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
else:
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
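
Pulled out of the diff, the HIP branch above amounts to the following standalone sketch. It is simplified: the block-size shrinking, num_warps clamping, and max_mm_configs cap are omitted, and the block_n half of the divisibility check is inferred from the comment in the hunk:

```python
def expand_hip_configs(base_configs):
    """Expand base (block_m, block_n, block_k, num_stages, num_warps)
    tuples with the ROCm-specific knobs, deduplicating on the full key."""
    used = set()
    kpack = 2  # fixed packing factor chosen for the HIP branch
    for block_m, block_n, block_k, num_stages, num_warps in base_configs:
        for matrix_instr_nonkdim in (0, 16):
            # block_m and block_n must be multiples of matrix_instr_nonkdim
            if matrix_instr_nonkdim != 0 and (
                block_m % matrix_instr_nonkdim != 0
                or block_n % matrix_instr_nonkdim != 0
            ):
                continue
            key = (block_m, block_n, block_k, num_stages, num_warps,
                   matrix_instr_nonkdim, kpack)
            if key not in used:
                used.add(key)
                yield key

# One base config expands to two variants (matrix_instr_nonkdim 0 and 16).
assert len(list(expand_hip_configs([(64, 64, 32, 2, 4)]))) == 2
```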


@@ -20,6 +20,8 @@ def get_field(config, name):
return config.num_warps
elif name == "num_stages":
return config.num_stages
elif name == "waves_per_eu":
return config.kwargs.get(name, int(8 // config.num_warps))
else:
return config.kwargs.get(name, None)
@@ -97,6 +99,8 @@ class CoordescTuner:
]
if self.is_mm:
out.append("num_stages")
if self.inductor_meta.get("is_hip") is True:
out.append("waves_per_eu")
return out
@@ -107,6 +111,8 @@ class CoordescTuner:
return val > self.get_config_max(prefix)
if name == "num_warps":
return val > self.get_warpsmax()
if name == "waves_per_eu":
return val > 8
return False
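
A worked example of the two additions above: the tuner seeds waves_per_eu from num_warps and rejects candidates above 8.

```python
num_warps = 4
# Default used by get_field when the config carries no explicit value.
waves_per_eu = int(8 // num_warps)
assert waves_per_eu == 2
# The `val > 8` check above: values past 8 are out of tuning range.
assert not waves_per_eu > 8
```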


@@ -433,9 +433,15 @@ class TritonTemplateKernel(TritonKernel):
triton_meta["configs"] = [config_of(signature)]
for arg_num in equal_1_arg_indices(signature): # type: ignore[index]
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr]
matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
if matrix_instr_nonkdim != 0:
matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None)
waves_per_eu = self.meta.get("waves_per_eu", None)
kpack = self.meta.get("kpack", None)
if matrix_instr_nonkdim:
triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
if waves_per_eu:
triton_meta["waves_per_eu"] = waves_per_eu
if kpack:
triton_meta["kpack"] = kpack
self.triton_meta = triton_meta
@@ -1215,6 +1221,8 @@ class TritonTemplate(KernelTemplate):
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
waves_per_eu=kwargs.get("waves_per_eu", 0),
kpack=kwargs.get("kpack", 2),
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type]
output_tensor_meta=TensorMeta.from_irnodes(layout),
workspace_arg=workspace_arg,
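
Taken together with the truthiness checks in the first hunk of this file, the defaults above behave like this sketch: unset ROCm knobs fall back to 0 (kpack to 2), and only truthy values reach triton_meta.

```python
kwargs = {}  # hypothetical template kwargs with no ROCm tuning supplied
meta = {
    "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
    "waves_per_eu": kwargs.get("waves_per_eu", 0),
    "kpack": kwargs.get("kpack", 2),
}
# Mirrors the `if matrix_instr_nonkdim: ...` gating above.
triton_meta = {k: v for k, v in meta.items() if v}
assert triton_meta == {"kpack": 2}
```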