mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code. Before: triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 | triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 | triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 | After: triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 | triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 | triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 | Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053 Approved by: https://github.com/mlazos, https://github.com/naromero77amd Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
This commit is contained in:
parent
25909d2629
commit
cdb60e44eb
|
|
@ -628,7 +628,7 @@ class ComboKernel(Kernel):
|
||||||
if heuristics == "foreach":
|
if heuristics == "foreach":
|
||||||
heuristics_line = f"""
|
heuristics_line = f"""
|
||||||
@triton_heuristics.foreach(
|
@triton_heuristics.foreach(
|
||||||
num_warps={self.num_warps},
|
filename=__file__,
|
||||||
triton_meta={triton_meta!r},
|
triton_meta={triton_meta!r},
|
||||||
inductor_meta={inductor_meta!r},
|
inductor_meta={inductor_meta!r},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3530,13 +3530,24 @@ def user_autotune(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
def foreach(triton_meta, filename=None, inductor_meta=None):
|
||||||
"""
|
"""
|
||||||
Compile a triton foreach kernel
|
Compile a triton foreach kernel
|
||||||
"""
|
"""
|
||||||
|
configs = []
|
||||||
|
|
||||||
|
# Naive autotuning path for num_warps
|
||||||
|
if not inductor_meta.get("autotune_pointwise", True) and not (
|
||||||
|
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
|
||||||
|
):
|
||||||
|
configs.append(triton.Config({}, num_stages=1, num_warps=8))
|
||||||
|
else:
|
||||||
|
for warps in [1, 2, 4, 8]:
|
||||||
|
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
|
||||||
|
|
||||||
return cached_autotune(
|
return cached_autotune(
|
||||||
None,
|
None,
|
||||||
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
configs,
|
||||||
triton_meta=triton_meta,
|
triton_meta=triton_meta,
|
||||||
inductor_meta=inductor_meta,
|
inductor_meta=inductor_meta,
|
||||||
heuristic_type=HeuristicType.TEMPLATE,
|
heuristic_type=HeuristicType.TEMPLATE,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user