[triton_heuristics] Optimize the triton launcher in pt2 (#160000)
Summary:
(Original author: Xu Zhao. Commandeered by David to land this since it is relatively urgent)
We observed a ~10us PT2-Triton kernel launch overhead regression after the Triton pin update.
Before Triton pin-update:
{F1980557238}
After Triton pin-update:
{F1980557240}
The root cause is that https://github.com/pytorch/pytorch/pull/145051 added `_get_args_with_constexprs` to the cubin launcher caller function, which is on the critical path.
The motivation for `_get_args_with_constexprs` was that between Triton 3.2 and Triton 3.3, the convention for calling Triton kernels (at the level where the non-static-cuda-launcher inductor path integrates) changed: previously the callable did not take constexpr arguments as parameters, but after 3.3 it does. For pointwise/reduction kernels we don't know the constexpr values until after autotuning, so `_get_args_with_constexprs` injected the constexprs into the argument list before every call to the Triton kernel. The fix in this PR is to instead inject the constexpr args into the generated launcher string; this avoids the cost of sorting and reordering arguments that previously occurred on every kernel execution.
Note that static_cuda_launcher.py does not require constants to be passed to the cubin launcher (e96c7c4bb0/torch/_inductor/runtime/static_cuda_launcher.py (L220)), so there is no need to pass constexprs into the generated launcher code.
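To make the before/after concrete, here is a minimal sketch of the two approaches. None of the names below (`old_launch`, `make_launcher`, the toy kernel) come from inductor; they only illustrate why baking constexprs into the generated launcher source is cheaper than splicing them into the argument list on every call:
```python
# Hypothetical sketch: "old" per-call injection vs. "new" codegen-time injection.
# None of these names come from inductor; they only illustrate the idea.

def old_launch(kernel, runtime_args, constexpr_pairs):
    # Old approach: on *every* launch, sort the (index, value) constexpr pairs
    # and splice them into the argument list before calling the kernel.
    args = list(runtime_args)
    for idx, val in sorted(constexpr_pairs):
        args.insert(idx, val)          # per-call list surgery = launch overhead
    return kernel(*args)

def make_launcher(kernel, constexpr_pairs, num_runtime_args):
    # New approach: do the splicing once, at codegen time, by emitting a small
    # launcher function whose source already contains the constexpr values.
    arg_names = [f"a{i}" for i in range(num_runtime_args)]
    call_args = list(arg_names)
    for idx, val in sorted(constexpr_pairs):
        call_args.insert(idx, repr(val))   # constants become literals in the source
    src = f"def launcher({', '.join(arg_names)}):\n"
    src += f"    return kernel({', '.join(call_args)})\n"
    scope = {"kernel": kernel}
    exec(src, scope)
    return scope["launcher"]               # per-call cost is just a plain call

# Usage: a toy kernel whose constexpr BLOCK=128 lives at argument index 2.
kernel = lambda x, y, BLOCK, n: (x, y, BLOCK, n)
launcher = make_launcher(kernel, [(2, 128)], num_runtime_args=3)
assert launcher("x", "y", 7) == old_launch(kernel, ["x", "y", 7], [(2, 128)])
```
The generated-source approach pays the insertion cost once at compile time; the per-launch path is then a plain Python call.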
The new launcher code needs to handle three cases:
- StaticallyLaunchedCudaKernel
- triton.compile.CompiledKernel
- AOTInductor
Analysis: https://docs.google.com/document/d/1PHaSmx2w59K8qpjw5_qzKWShfEgptf_Zpv_DL7YxiWU/edit?tab=t.0
Test Plan:
Before:
```
$ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs
1.893x
```
```
$ buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency
x_val nop_python_function-walltime nop_triton_kernel-walltime nop_triton_compiled_kernel_run-walltime nop_inductor_kernel-walltime nop_inductor_kernel_cudagraph-walltime
------- ------------------------------ ---------------------------- ----------------------------------------- ------------------------------ ----------------------------------------
0 0.00760921 1.80298 0.623282 5.25024 0.203722
19 0.00799885 4.78223 1.00226 5.8213 0.239084
average 0.00780403 3.29261 0.812769 5.53577 0.221403
```
After:
```
$ buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency
x_val nop_python_function-walltime nop_triton_kernel-walltime nop_triton_compiled_kernel_run-walltime nop_inductor_kernel-walltime nop_inductor_kernel_cudagraph-walltime
------- ------------------------------ ---------------------------- ----------------------------------------- ------------------------------ ----------------------------------------
0 0.00747067 1.92589 0.726509 4.35459 0.204205
19 0.00747823 7.36852 1.26241 6.28208 0.239278
average 0.00747445 4.6472 0.994459 5.31834 0.221741
```
```
$ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs
1.985x
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160000
Approved by: https://github.com/jansel
Co-authored-by: Xu Zhao <xzhao9@meta.com>
parent 9ccd0f5e31
commit d0e2240f68
@@ -6630,6 +6630,9 @@ class UserDefinedTritonKernel(ExternKernel):
         for name, arg in itertools.chain(
             named_args.items(), zip(itertools.repeat(""), extra_launch_args)
         ):
+            if name in constexpr_names and triton_version_uses_attrs_dict():
+                # see #160000 - we don't pass in constexpr args to speed up runtime.
+                continue
             raw_keys_filtered.append(name)
             raw_args_filtered.append(arg)
             if isinstance(arg, IRNode):
@@ -196,8 +196,7 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
             call_kwargs[k] = v
         else:
             call_kwargs[k] = v
-    if not triton_version_uses_attrs_dict():
-        call_kwargs.update(launcher.config.kwargs)
+    call_kwargs.update(launcher.config.kwargs)
     call_kwargs["num_warps"] = launcher.config.num_warps
     call_kwargs["num_stages"] = launcher.config.num_stages
     if HAS_WARP_SPEC:
@@ -770,28 +769,6 @@ class CachingAutotuner(KernelInterface):

         return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)

-    def _get_args_with_constexprs(self, args, launcher):
-        """
-        `args` is passed in with only the non-constexpr args (because the constexpr arg values
-        depend on the config). However, in later triton versions, the constexpr args need to be
-        added into the args list.
-        """
-        if triton_version_uses_attrs_dict():
-            # first: aggregate the constexpr args in (index, val) pairs
-            # so we can sort them by index.
-            constexpr_args: list[tuple[int, Any]] = []
-            for arg_name, arg_val in launcher.config.kwargs.items():
-                if arg_name in self.fn.arg_names:
-                    constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
-
-            constexpr_args.sort()
-            new_args = [*args]
-            for arg_idx, arg_val in constexpr_args:
-                new_args.insert(arg_idx, arg_val)
-
-            return new_args
-        return args
-
     def bench(self, launcher, *args, with_profiler=False, **kwargs):
         """Measure the performance of a given launcher"""
         # we don't skip configs with spilled registers when auto-tuning custom
@@ -820,23 +797,22 @@ class CachingAutotuner(KernelInterface):
             )
             # reset to zero before evaluating any config
             self.reset_to_zero_args(*args, **kwargs)
-            args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
             if autograd_profiler._is_profiler_enabled:
                 profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
                 with torch._C._profiler._RecordFunctionFast(
                     self.inductor_meta.get("kernel_name", "triton kernel"),
-                    args_with_constexprs,
+                    cloned_args,
                     profiler_kwargs,
                 ):
                     launcher(
-                        *args_with_constexprs,
+                        *cloned_args,
                         **cloned_kwargs,
                         stream=stream,
                     )

             else:
                 launcher(
-                    *args_with_constexprs,
+                    *cloned_args,
                     **cloned_kwargs,
                     stream=stream,
                 )
@@ -1240,7 +1216,6 @@ class CachingAutotuner(KernelInterface):
         # so _RecordFunctionFast need to capture the args into CachingAutotuner::run()
         # make a copy here to avoid mutating the original args
         args_without_constexprs = tuple(args)
-        args = self._get_args_with_constexprs(args, launcher)

         if self.dump_launch_params:
             new_args, grid = self._interpret_args_grid(args, launcher.config)
@@ -1296,6 +1271,10 @@ class _ConstRepr:


 class CompileResult(Generic[_T]):
+    """
+    Base class representing compiled result.
+    """
+
     def __init__(
         self,
         kernel: _T,
@@ -1359,21 +1338,30 @@ class CompileResult(Generic[_T]):
         )
         none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))

+        def _convert_constant(constant):
+            if isinstance(constant, str):
+                return "r'" + constant + "'"
+            else:
+                return repr(constant)
+
         if triton_version_uses_attrs_dict():
             call_args = arg_names
             def_args = arg_names
-            if (
-                "num_warps" in compile_meta["constants"]
-                or "num_stages" in compile_meta["constants"]
+            implicit_constants = OrderedSet(
+                (
+                    "num_warps",
+                    "num_stages",
+                )
+            ).union(OrderedSet(k for k in known_constants))
+            if implicit_constants := implicit_constants & OrderedSet(
+                compile_meta["constants"].keys()
             ):
                 # num_warps/num_stages are special implicit args that are not in the signature
                 # see test_triton_kernel_special_params
-                def_args = [
-                    arg for arg in def_args if arg not in ("num_warps", "num_stages")
-                ]
+                def_args = [arg for arg in def_args if arg not in implicit_constants]
                 repl = {
-                    k: str(compile_meta["constants"].get(k))
-                    for k in ("num_warps", "num_stages")
+                    k: _convert_constant(compile_meta["constants"].get(k))
+                    for k in implicit_constants
                 }
                 call_args = [repl.get(arg, arg) for arg in call_args]
             else:
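As a standalone illustration of the constant handling in the hunk above: only `_convert_constant` is taken from the diff; the example constants, argument names, and the `runner(...)` call string are hypothetical stand-ins for the real launcher codegen.
```python
# Minimal reproduction of how constant values become source tokens in the
# generated launcher call. `_convert_constant` mirrors the helper in the diff;
# the surrounding substitution is a simplified stand-in for the real codegen.

def _convert_constant(constant):
    if isinstance(constant, str):
        return "r'" + constant + "'"   # strings become raw string literals
    else:
        return repr(constant)          # ints/bools/None/etc. via repr()

constants = {"num_warps": 4, "num_stages": 3, "DTYPE": "fp32"}
call_args = ["in_ptr0", "out_ptr0", "xnumel", "num_warps", "num_stages", "DTYPE"]

# Replace implicit/constexpr arg names with literal values in the call string,
# so the launcher source itself carries the constants.
repl = {k: _convert_constant(v) for k, v in constants.items()}
call_args = [repl.get(arg, arg) for arg in call_args]
print(f"runner({', '.join(call_args)})")
# -> runner(in_ptr0, out_ptr0, xnumel, 4, 3, r'fp32')
```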
@@ -1653,6 +1641,8 @@ class TritonCompileResult(CompileResult[CompiledKernel]):

         import math as math_lib

         import triton as triton_lib

+        import torch as torch_lib
+
         scope = {
@@ -1687,6 +1677,7 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
             "runner": get_first_attr(binary, "run", "c_wrapper"),
             "math": math_lib,
+            "torch": torch_lib,
             "triton": triton_lib,
         }

         if not hasattr(binary, "launch_metadata"):
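A note on the `torch` import added to the launcher exec scope in the two hunks above: this is presumably needed because constants embedded via `repr()` can render as source text that references the `torch` module (e.g. `torch.float32`), which the generated launcher can only evaluate if `torch` is present in its globals. A minimal sketch of that mechanism (the toy `kernel` and `launcher_src` are illustrative, not the real generated code):
```python
import torch

# repr() of a dtype constant produces source text that names the torch module...
const_src = repr(torch.float32)          # "torch.float32"
launcher_src = f"def launcher():\n    return kernel({const_src})\n"

# ...so the exec scope for the generated launcher must expose `torch`,
# mirroring the `"torch": torch_lib` entry added to `scope` in the diff.
scope = {"kernel": lambda dtype: dtype, "torch": torch}
exec(launcher_src, scope)
assert scope["launcher"]() is torch.float32
```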