[triton_heuristics] Optimize the triton launcher in pt2 (#160000)

Summary:

(Original author: Xu Zhao. Commandeered by David to land this since it is relatively urgent)

We observed a ~10us PT2-Triton launch overhead regression after the Triton pin update.

Before Triton pin-update:
 {F1980557238}

After Triton pin-update:
 {F1980557240}

The root cause is that https://github.com/pytorch/pytorch/pull/145051 added `_get_args_with_constexprs` to the cubin launcher caller function, which is on the critical path.

The motivation for `_get_args_with_constexprs` was that between Triton 3.2 and 3.3, the convention for calling Triton kernels (at the level where the non-static-cuda-launcher Inductor integration calls them) changed: previously, the callable did not take constexpr arguments as parameters; after 3.3, it does. For pointwise/reduction kernels, we don't know the constexpr values until after autotuning, so `_get_args_with_constexprs` injected the constexprs into the argument list before calling the Triton kernel. The fix in this PR is to instead inject the constexpr args into the launcher string, which avoids the cost of sorting and reordering arguments on every kernel execution.
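
As a rough illustration, here is a minimal sketch of the two approaches. This is not the actual `CachingAutotuner` code; `arg_names`, `config_kwargs`, `runner`, and the example kernel arguments are placeholders.

```
# Old approach: on every launch, splice the autotuned constexpr values back
# into the positional argument list (index lookups + inserts per call).
def args_with_constexprs(args, arg_names, config_kwargs):
    constexpr_args = sorted(
        (arg_names.index(name), val)
        for name, val in config_kwargs.items()
        if name in arg_names
    )
    new_args = list(args)
    for idx, val in constexpr_args:
        new_args.insert(idx, val)
    return new_args

# New approach (this PR, simplified): do the substitution once at codegen time
# by baking the constexpr values into the generated launcher source as literals.
def make_launcher_src(arg_names, config_kwargs):
    def_args = [a for a in arg_names if a not in config_kwargs]
    call_args = [repr(config_kwargs[a]) if a in config_kwargs else a for a in arg_names]
    return (
        f"def launcher({', '.join(def_args)}, stream):\n"
        f"    runner({', '.join(call_args)}, stream=stream)\n"
    )

if __name__ == "__main__":
    names = ["in_ptr", "out_ptr", "xnumel", "XBLOCK"]
    cfg = {"XBLOCK": 1024}  # constexpr chosen by autotuning
    print(args_with_constexprs(("in", "out", 4096), names, cfg))
    print(make_launcher_src(names, cfg))
```

With the second approach, the per-call path is just a plain positional call; all constexpr handling happens once when the launcher source is generated.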

Note that since static_cuda_launcher.py does not require constants to be passed to the cubin launcher (e96c7c4bb0/torch/_inductor/runtime/static_cuda_launcher.py (L220)), there is no need to pass constexprs to the generated launcher code.

The new launcher code needs to handle three cases:
- StaticallyLaunchedCudaKernel
- triton.compile.CompiledKernel
- AOTInductor
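
Since the constants are now spliced into the launcher source text rather than passed at call time, their values must be rendered as valid Python literals. The sketch below reuses the `_convert_constant` helper added in this PR (see the diff); the surrounding argument names and constant values are hypothetical.

```
def _convert_constant(constant):
    # String constants must become (raw) string literals in the generated
    # source; everything else round-trips through repr().
    if isinstance(constant, str):
        return "r'" + constant + "'"
    else:
        return repr(constant)

# Hypothetical example: num_warps/num_stages are implicit constants that get
# dropped from the launcher's def-args and baked into its call-args.
arg_names = ["in_ptr", "out_ptr", "xnumel", "num_warps", "num_stages"]
constants = {"num_warps": 4, "num_stages": 3}

def_args = [a for a in arg_names if a not in constants]
call_args = [_convert_constant(constants[a]) if a in constants else a for a in arg_names]
print(", ".join(def_args))   # in_ptr, out_ptr, xnumel
print(", ".join(call_args))  # in_ptr, out_ptr, xnumel, 4, 3
```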

Analysis: https://docs.google.com/document/d/1PHaSmx2w59K8qpjw5_qzKWShfEgptf_Zpv_DL7YxiWU/edit?tab=t.0

Test Plan:
Before:
```
$ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs

1.893x
```

```

$ buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency
  x_val    nop_python_function-walltime    nop_triton_kernel-walltime    nop_triton_compiled_kernel_run-walltime    nop_inductor_kernel-walltime    nop_inductor_kernel_cudagraph-walltime
-------  ------------------------------  ----------------------------  -----------------------------------------  ------------------------------  ----------------------------------------
      0                      0.00760921                       1.80298                                   0.623282                         5.25024                                  0.203722
     19                      0.00799885                       4.78223                                   1.00226                          5.8213                                   0.239084
average                      0.00780403                       3.29261                                   0.812769                         5.53577                                  0.221403
```

After:

```
buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency
  x_val    nop_python_function-walltime    nop_triton_kernel-walltime    nop_triton_compiled_kernel_run-walltime    nop_inductor_kernel-walltime    nop_inductor_kernel_cudagraph-walltime
-------  ------------------------------  ----------------------------  -----------------------------------------  ------------------------------  ----------------------------------------
      0                      0.00747067                       1.92589                                   0.726509                         4.35459                                  0.204205
     19                      0.00747823                       7.36852                                   1.26241                          6.28208                                  0.239278
average                      0.00747445                       4.6472                                    0.994459                         5.31834                                  0.221741
```

```
$ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs

1.985x
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160000
Approved by: https://github.com/jansel

Co-authored-by: Xu Zhao <xzhao9@meta.com>
Author: David Berard, 2025-08-11 17:22:40 +00:00 (committed by PyTorch MergeBot)
commit d0e2240f68, parent 9ccd0f5e31
2 changed files with 31 additions and 37 deletions

```
@@ -6630,6 +6630,9 @@ class UserDefinedTritonKernel(ExternKernel):
         for name, arg in itertools.chain(
             named_args.items(), zip(itertools.repeat(""), extra_launch_args)
         ):
+            if name in constexpr_names and triton_version_uses_attrs_dict():
+                # see #160000 - we don't pass in constexpr args to speed up runtime.
+                continue
             raw_keys_filtered.append(name)
             raw_args_filtered.append(arg)
             if isinstance(arg, IRNode):
```


```
@@ -196,8 +196,7 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
             call_kwargs[k] = v
         else:
             call_kwargs[k] = v
-    if not triton_version_uses_attrs_dict():
-        call_kwargs.update(launcher.config.kwargs)
+    call_kwargs.update(launcher.config.kwargs)
     call_kwargs["num_warps"] = launcher.config.num_warps
     call_kwargs["num_stages"] = launcher.config.num_stages
     if HAS_WARP_SPEC:
@@ -770,28 +769,6 @@ class CachingAutotuner(KernelInterface):
         return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
 
-    def _get_args_with_constexprs(self, args, launcher):
-        """
-        `args` is passed in with only the non-constexpr args (because the constexpr arg values
-        depend on the config). However, in later triton versions, the constexpr args need to be
-        added into the args list.
-        """
-        if triton_version_uses_attrs_dict():
-            # first: aggregate the constexpr args in (index, val) pairs
-            # so we can sort them by index.
-            constexpr_args: list[tuple[int, Any]] = []
-            for arg_name, arg_val in launcher.config.kwargs.items():
-                if arg_name in self.fn.arg_names:
-                    constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
-
-            constexpr_args.sort()
-            new_args = [*args]
-            for arg_idx, arg_val in constexpr_args:
-                new_args.insert(arg_idx, arg_val)
-
-            return new_args
-        return args
-
     def bench(self, launcher, *args, with_profiler=False, **kwargs):
         """Measure the performance of a given launcher"""
         # we don't skip configs with spilled registers when auto-tuning custom
@@ -820,23 +797,22 @@ class CachingAutotuner(KernelInterface):
             )
             # reset to zero before evaluating any config
             self.reset_to_zero_args(*args, **kwargs)
-            args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
             if autograd_profiler._is_profiler_enabled:
                 profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
 
                 with torch._C._profiler._RecordFunctionFast(
                     self.inductor_meta.get("kernel_name", "triton kernel"),
-                    args_with_constexprs,
+                    cloned_args,
                     profiler_kwargs,
                 ):
                     launcher(
-                        *args_with_constexprs,
+                        *cloned_args,
                         **cloned_kwargs,
                         stream=stream,
                     )
             else:
                 launcher(
-                    *args_with_constexprs,
+                    *cloned_args,
                     **cloned_kwargs,
                     stream=stream,
                 )
@@ -1240,7 +1216,6 @@ class CachingAutotuner(KernelInterface):
         # so _RecordFunctionFast need to capture the args into CachingAutotuner::run()
         # make a copy here to avoid mutating the original args
         args_without_constexprs = tuple(args)
-        args = self._get_args_with_constexprs(args, launcher)
 
         if self.dump_launch_params:
             new_args, grid = self._interpret_args_grid(args, launcher.config)
@@ -1296,6 +1271,10 @@ class _ConstRepr:
 
 
 class CompileResult(Generic[_T]):
+    """
+    Base class representing compiled result.
+    """
+
     def __init__(
         self,
         kernel: _T,
@@ -1359,21 +1338,30 @@ class CompileResult(Generic[_T]):
         )
         none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))
 
+        def _convert_constant(constant):
+            if isinstance(constant, str):
+                return "r'" + constant + "'"
+            else:
+                return repr(constant)
+
         if triton_version_uses_attrs_dict():
             call_args = arg_names
             def_args = arg_names
-            if (
-                "num_warps" in compile_meta["constants"]
-                or "num_stages" in compile_meta["constants"]
+            implicit_constants = OrderedSet(
+                (
+                    "num_warps",
+                    "num_stages",
+                )
+            ).union(OrderedSet(k for k in known_constants))
+            if implicit_constants := implicit_constants & OrderedSet(
+                compile_meta["constants"].keys()
            ):
                 # num_warps/num_stages are special implicit args that are not in the signature
                 # see test_triton_kernel_special_params
-                def_args = [
-                    arg for arg in def_args if arg not in ("num_warps", "num_stages")
-                ]
+                def_args = [arg for arg in def_args if arg not in implicit_constants]
                 repl = {
-                    k: str(compile_meta["constants"].get(k))
-                    for k in ("num_warps", "num_stages")
+                    k: _convert_constant(compile_meta["constants"].get(k))
+                    for k in implicit_constants
                 }
                 call_args = [repl.get(arg, arg) for arg in call_args]
         else:
@@ -1653,6 +1641,8 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         import math as math_lib
+        import triton as triton_lib
         import torch as torch_lib
 
         scope = {
@@ -1687,6 +1677,7 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
             "runner": get_first_attr(binary, "run", "c_wrapper"),
             "math": math_lib,
             "torch": torch_lib,
+            "triton": triton_lib,
         }
         if not hasattr(binary, "launch_metadata"):
```