Revert "[triton_heuristics] Optimize the triton launcher in pt2 (#160000)"

This reverts commit d0e2240f68.

Reverted https://github.com/pytorch/pytorch/pull/160000 on behalf of https://github.com/davidberard98 due to D80054972 failing with test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1_tdlp_1 ([comment](https://github.com/pytorch/pytorch/pull/160000#issuecomment-3180144676))
PyTorch MergeBot 2025-08-12 16:33:02 +00:00
parent 9d37c960a4
commit f7b2f3314c
2 changed files with 37 additions and 31 deletions


@@ -6630,9 +6630,6 @@ class UserDefinedTritonKernel(ExternKernel):
         for name, arg in itertools.chain(
             named_args.items(), zip(itertools.repeat(""), extra_launch_args)
         ):
-            if name in constexpr_names and triton_version_uses_attrs_dict():
-                # see #160000 - we don't pass in constexpr args to speed up runtime.
-                continue
             raw_keys_filtered.append(name)
             raw_args_filtered.append(arg)
             if isinstance(arg, IRNode):
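
Note: the three lines removed in this hunk are the runtime shortcut from #160000: when the installed Triton keys kernel signatures by an attrs dict, constexpr arguments are skipped rather than forwarded at launch. A minimal standalone sketch of that filtering, with hypothetical names (filter_runtime_args is illustrative, not the PR's code):

    def filter_runtime_args(named_args, constexpr_names, uses_attrs_dict):
        # Drop constexpr parameters from the runtime argument list; their
        # values are already baked into the compiled kernel's config.
        keys, vals = [], []
        for name, arg in named_args.items():
            if name in constexpr_names and uses_attrs_dict:
                continue
            keys.append(name)
            vals.append(arg)
        return keys, vals

    # filter_runtime_args({"x_ptr": 1, "BLOCK": 64}, {"BLOCK"}, True) -> (["x_ptr"], [1])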


@@ -196,7 +196,8 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
             call_kwargs[k] = v
         else:
             call_kwargs[k] = v
-    call_kwargs.update(launcher.config.kwargs)
+    if not triton_version_uses_attrs_dict():
+        call_kwargs.update(launcher.config.kwargs)
     call_kwargs["num_warps"] = launcher.config.num_warps
     call_kwargs["num_stages"] = launcher.config.num_stages
     if HAS_WARP_SPEC:
@@ -769,6 +770,28 @@ class CachingAutotuner(KernelInterface):
         return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)

+    def _get_args_with_constexprs(self, args, launcher):
+        """
+        `args` is passed in with only the non-constexpr args (because the constexpr arg values
+        depend on the config). However, in later triton versions, the constexpr args need to be
+        added into the args list.
+        """
+        if triton_version_uses_attrs_dict():
+            # first: aggregate the constexpr args in (index, val) pairs
+            # so we can sort them by index.
+            constexpr_args: list[tuple[int, Any]] = []
+            for arg_name, arg_val in launcher.config.kwargs.items():
+                if arg_name in self.fn.arg_names:
+                    constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
+
+            constexpr_args.sort()
+            new_args = [*args]
+            for arg_idx, arg_val in constexpr_args:
+                new_args.insert(arg_idx, arg_val)
+
+            return new_args
+        return args
+
     def bench(self, launcher, *args, with_profiler=False, **kwargs):
         """Measure the performance of a given launcher"""
         # we don't skip configs with spilled registers when auto-tuning custom
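
Note: the restored _get_args_with_constexprs splices the config's constexpr values back into the positional argument list at the indices dictated by the kernel signature. A self-contained sketch of that merge, with an illustrative helper name and example signature:

    def insert_constexprs(args, arg_names, config_kwargs):
        # (index, value) pairs for every config kwarg that names a kernel
        # parameter, sorted by its position in the signature.
        pairs = sorted(
            (arg_names.index(name), val)
            for name, val in config_kwargs.items()
            if name in arg_names
        )
        new_args = [*args]
        for idx, val in pairs:
            # ascending indices keep earlier inserts valid as the list grows
            new_args.insert(idx, val)
        return new_args

    # kernel(x_ptr, BLOCK: tl.constexpr, y_ptr) launched with BLOCK=64:
    # insert_constexprs(["x", "y"], ["x_ptr", "BLOCK", "y_ptr"], {"BLOCK": 64})
    #   -> ["x", 64, "y"]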
@@ -797,22 +820,23 @@ class CachingAutotuner(KernelInterface):
             )

             # reset to zero before evaluating any config
             self.reset_to_zero_args(*args, **kwargs)
+            args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
             if autograd_profiler._is_profiler_enabled:
                 profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
                 with torch._C._profiler._RecordFunctionFast(
                     self.inductor_meta.get("kernel_name", "triton kernel"),
-                    cloned_args,
+                    args_with_constexprs,
                     profiler_kwargs,
                 ):
                     launcher(
-                        *cloned_args,
+                        *args_with_constexprs,
                         **cloned_kwargs,
                         stream=stream,
                     )
             else:
                 launcher(
-                    *cloned_args,
+                    *args_with_constexprs,
                     **cloned_kwargs,
                     stream=stream,
                 )
@@ -1216,6 +1240,7 @@ class CachingAutotuner(KernelInterface):
         # so _RecordFunctionFast need to capture the args into CachingAutotuner::run()
         # make a copy here to avoid mutating the original args
         args_without_constexprs = tuple(args)
+        args = self._get_args_with_constexprs(args, launcher)

         if self.dump_launch_params:
             new_args, grid = self._interpret_args_grid(args, launcher.config)
@@ -1271,10 +1296,6 @@ class _ConstRepr:
 class CompileResult(Generic[_T]):
     """
     Base class representing compiled result.
     """

     def __init__(
         self,
         kernel: _T,
@@ -1338,30 +1359,21 @@ class CompileResult(Generic[_T]):
         )
         none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))

-        def _convert_constant(constant):
-            if isinstance(constant, str):
-                return "r'" + constant + "'"
-            else:
-                return repr(constant)
-
         if triton_version_uses_attrs_dict():
             call_args = arg_names
             def_args = arg_names
-            implicit_constants = OrderedSet(
-                (
-                    "num_warps",
-                    "num_stages",
-                )
-            ).union(OrderedSet(k for k in known_constants))
-            if implicit_constants := implicit_constants & OrderedSet(
-                compile_meta["constants"].keys()
+            if (
+                "num_warps" in compile_meta["constants"]
+                or "num_stages" in compile_meta["constants"]
             ):
                 # num_warps/num_stages are special implicit args that are not in the signature
                 # see test_triton_kernel_special_params
-                def_args = [arg for arg in def_args if arg not in implicit_constants]
+                def_args = [
+                    arg for arg in def_args if arg not in ("num_warps", "num_stages")
+                ]
                 repl = {
-                    k: _convert_constant(compile_meta["constants"].get(k))
-                    for k in implicit_constants
+                    k: str(compile_meta["constants"].get(k))
+                    for k in ("num_warps", "num_stages")
                 }
                 call_args = [repl.get(arg, arg) for arg in call_args]
         else:
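
Note: the restored branch special-cases only num_warps and num_stages, which can appear among the argument names even though they are not real signature parameters; they are dropped from the generated launcher's parameter list and their literal values are substituted into the emitted call. A rough standalone sketch under those assumptions (build_launcher_args is illustrative, not the actual generated source):

    def build_launcher_args(arg_names, constants):
        # Parameters the generated launcher accepts: everything except the
        # implicit num_warps/num_stages.
        def_args = [a for a in arg_names if a not in ("num_warps", "num_stages")]
        # In the emitted call, replace those names with their literal values.
        repl = {
            k: str(constants[k]) for k in ("num_warps", "num_stages") if k in constants
        }
        call_args = [repl.get(a, a) for a in arg_names]
        return def_args, call_args

    # build_launcher_args(["x_ptr", "num_warps"], {"num_warps": 4})
    #   -> (["x_ptr"], ["x_ptr", "4"])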
@@ -1641,8 +1653,6 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         import math as math_lib

-        import triton as triton_lib
-
         import torch as torch_lib

         scope = {
@@ -1677,7 +1687,6 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
             "runner": get_first_attr(binary, "run", "c_wrapper"),
             "math": math_lib,
             "torch": torch_lib,
-            "triton": triton_lib,
         }

         if not hasattr(binary, "launch_metadata"):