[ROCm] Incorporate ROCm triton specific tuning parameters (#148437)
Splitting https://github.com/pytorch/pytorch/pull/147315 into two PRs. This PR adds general support for the kpack and waves_per_eu Triton kernel args for the AMD backend. More detail is in the PR above. A follow-up PR will update the configs used by ROCm, but this requires https://github.com/pytorch/pytorch/pull/147452 to land first.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148437
Approved by: https://github.com/eellison, https://github.com/jansel
This commit is contained in: parent a3b77d434a, commit 8059ead823
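For orientation before the hunks: the PR threads three AMD-only Triton tuning knobs through inductor. The snippet below is only an illustration with hypothetical values (not recommendations); the comments paraphrase the inline comments that appear in the diff itself.

```python
# Hypothetical values, for illustration only; comments paraphrase the diff's
# own inline notes on what each ROCm-specific knob controls.
rocm_tuning_knobs = {
    "matrix_instr_nonkdim": 16,  # hip only: shape of the mfma instruction (0 = let the compiler choose)
    "waves_per_eu": 2,           # hip only: waves to schedule per execution unit (0 = unset)
    "kpack": 2,                  # ROCm-specific gemm parameter (0 = unset)
}
```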
@@ -643,6 +643,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
         num_stages: int,
         num_warps: int,
         matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
+        waves_per_eu: int = 0,  # only used for hip to schedule waves per execution unit
+        kpack: int = 0,  # ROCm specific gemm parameter
         workspace_arg: Optional[WorkspaceArg] = None,
     ) -> None:
         super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
@@ -652,6 +654,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
         self.num_stages = num_stages
         self.num_warps = num_warps
         self.matrix_instr_nonkdim = matrix_instr_nonkdim
+        self.waves_per_eu = waves_per_eu
+        self.kpack = kpack
         self.workspace_arg = workspace_arg

     def make_run_fn(
@@ -1449,6 +1449,10 @@ def flex_attention(
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
     SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)

+    # ROCm specific considerations
+    if torch.version.hip:
+        kernel_options["kpack"] = 2
+
     # Note, we don't need to pass in the captured buffers explicitly
     # because they're implicitly added by the score_mod function
     # We do need to explicitly pass it in for autotuning though.
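The flex attention change hard-codes kpack=2 whenever the build is a HIP one. Below is a hypothetical standalone helper (not part of this PR) showing the same gating pattern; note it uses setdefault so a caller-supplied value would win, whereas the hunk above assigns unconditionally.

```python
import torch


def with_rocm_flex_defaults(kernel_options: dict) -> dict:
    # Apply ROCm-only knobs on HIP builds of PyTorch; on CUDA builds
    # torch.version.hip is None and the options pass through untouched.
    if torch.version.hip:
        kernel_options.setdefault("kpack", 2)
    return kernel_options


# Usage sketch: with_rocm_flex_defaults({"BLOCK_M": 64})
```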
@@ -89,7 +89,7 @@ def filtered_configs(
         ),
         min_block_size_k,
     )
-    used = OrderedSet[tuple[int, int, int, int, int, int]]()
+    used = OrderedSet[tuple[int, ...]]()
     for block_m, block_n, block_k, num_stages, num_warps in configs:
         # shrink configs for small sizes
         block_m = max(min(int(block_m * scale), m), min_block_size)

@@ -102,6 +102,7 @@ def filtered_configs(
         # each warp computes 16x16 tile = 256
         num_warps = min(num_warps, block_m * block_n // 256)
         if torch.version.hip:
+            kpack = 2
             for matrix_instr_nonkdim in [0, 16]:
                 if matrix_instr_nonkdim != 0 and (
                     block_m % matrix_instr_nonkdim != 0

@@ -109,6 +110,7 @@ def filtered_configs(
                 ):
                     # block_m and block_n must be a multiple of matrix_instr_nonkdim
                     continue
+
                 if (
                     block_m,
                     block_n,

@@ -116,6 +118,7 @@ def filtered_configs(
                     num_stages,
                     num_warps,
                     matrix_instr_nonkdim,
+                    kpack,
                 ) not in used and (
                     max_mm_configs is None or len(used) < max_mm_configs
                 ):

@@ -127,6 +130,7 @@ def filtered_configs(
                             num_stages,
                             num_warps,
                             matrix_instr_nonkdim,
+                            kpack,
                         )
                     )
                     yield triton_config(

@@ -136,6 +140,7 @@ def filtered_configs(
                         num_stages=num_stages,
                         num_warps=num_warps,
                         matrix_instr_nonkdim=matrix_instr_nonkdim,
+                        kpack=kpack,
                     )
         else:
             if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
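Condensed from the hunks above: on ROCm the heuristic fixes kpack = 2, sweeps matrix_instr_nonkdim over [0, 16], skips tile shapes that are not multiples of matrix_instr_nonkdim, and dedupes on the full 7-tuple. The sketch below restates that branch as a self-contained generator yielding plain dicts instead of inductor's triton_config; inputs and names are hypothetical.

```python
from typing import Iterator


def rocm_gemm_configs(configs: list[tuple[int, int, int, int, int]]) -> Iterator[dict]:
    # Standalone rewording of the ROCm branch of the config filter above.
    used: set[tuple[int, ...]] = set()
    kpack = 2  # fixed by the heuristic in the diff
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        for matrix_instr_nonkdim in [0, 16]:
            if matrix_instr_nonkdim != 0 and (
                block_m % matrix_instr_nonkdim != 0
                or block_n % matrix_instr_nonkdim != 0
            ):
                # block_m and block_n must be multiples of matrix_instr_nonkdim
                continue
            key = (block_m, block_n, block_k, num_stages, num_warps,
                   matrix_instr_nonkdim, kpack)
            if key in used:
                continue
            used.add(key)
            yield {
                "BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k,
                "num_stages": num_stages, "num_warps": num_warps,
                "matrix_instr_nonkdim": matrix_instr_nonkdim, "kpack": kpack,
            }


# Example: list(rocm_gemm_configs([(64, 64, 32, 2, 4)])) yields two candidates,
# one with matrix_instr_nonkdim=0 and one with matrix_instr_nonkdim=16.
```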
@@ -20,6 +20,8 @@ def get_field(config, name):
         return config.num_warps
     elif name == "num_stages":
         return config.num_stages
+    elif name == "waves_per_eu":
+        return config.kwargs.get(name, int(8 // config.num_warps))
     else:
         return config.kwargs.get(name, None)


@@ -97,6 +99,8 @@ class CoordescTuner:
         ]
         if self.is_mm:
             out.append("num_stages")
+        if self.inductor_meta.get("is_hip") is True:
+            out.append("waves_per_eu")

         return out


@@ -107,6 +111,8 @@ class CoordescTuner:
             return val > self.get_config_max(prefix)
         if name == "num_warps":
             return val > self.get_warpsmax()
+        if name == "waves_per_eu":
+            return val > 8

         return False

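The waves_per_eu default in get_field is tied to the warp count, int(8 // config.num_warps), and value_too_large caps coordinate-descent moves at 8. Evaluating the default for a few warp counts (plain arithmetic, not tuning advice):

```python
# get_field's fallback for waves_per_eu, evaluated for common warp counts.
for num_warps in (1, 2, 4, 8, 16):
    print(num_warps, int(8 // num_warps))
# 1 -> 8, 2 -> 4, 4 -> 2, 8 -> 1, 16 -> 0
```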
@@ -433,9 +433,15 @@ class TritonTemplateKernel(TritonKernel):
         triton_meta["configs"] = [config_of(signature)]
         for arg_num in equal_1_arg_indices(signature):  # type: ignore[index]
             triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index,union-attr]
-        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
-        if matrix_instr_nonkdim != 0:
+        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None)
+        waves_per_eu = self.meta.get("waves_per_eu", None)
+        kpack = self.meta.get("kpack", None)
+        if matrix_instr_nonkdim:
             triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
+        if waves_per_eu:
+            triton_meta["waves_per_eu"] = waves_per_eu
+        if kpack:
+            triton_meta["kpack"] = kpack

         self.triton_meta = triton_meta


@@ -1215,6 +1221,8 @@ class TritonTemplate(KernelTemplate):
             num_stages=num_stages,
             num_warps=num_warps,
             matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
+            waves_per_eu=kwargs.get("waves_per_eu", 0),
+            kpack=kwargs.get("kpack", 2),
             input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),  # type: ignore[arg-type]
             output_tensor_meta=TensorMeta.from_irnodes(layout),
             workspace_arg=workspace_arg,
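Taken together, the two select_algorithm hunks mean the template kwargs carry the knobs down with defaults (0 for matrix_instr_nonkdim and waves_per_eu, 2 for kpack), while triton_meta only picks up knobs that are truthy. A standalone sketch of that forwarding with hypothetical meta values:

```python
# Hypothetical meta values; mirrors the truthiness-based forwarding above.
meta = {"matrix_instr_nonkdim": 16, "waves_per_eu": 2, "kpack": 2}

triton_meta: dict = {}
for knob in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
    value = meta.get(knob, None)
    if value:  # unset (None) or zero-valued knobs are not forwarded
        triton_meta[knob] = value

assert triton_meta == {"matrix_instr_nonkdim": 16, "waves_per_eu": 2, "kpack": 2}
```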