Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[triton_heuristics] Optimize the triton launcher in pt2 (#160000)"
This reverts commit d0e2240f68.
Reverted https://github.com/pytorch/pytorch/pull/160000 on behalf of https://github.com/davidberard98 due to D80054972 failing with test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1_tdlp_1 ([comment](https://github.com/pytorch/pytorch/pull/160000#issuecomment-3180144676))
parent 9d37c960a4
commit f7b2f3314c
@@ -6630,9 +6630,6 @@ class UserDefinedTritonKernel(ExternKernel):
         for name, arg in itertools.chain(
             named_args.items(), zip(itertools.repeat(""), extra_launch_args)
         ):
-            if name in constexpr_names and triton_version_uses_attrs_dict():
-                # see #160000 - we don't pass in constexpr args to speed up runtime.
-                continue
             raw_keys_filtered.append(name)
             raw_args_filtered.append(arg)
             if isinstance(arg, IRNode):
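For context, the three removed lines were the wrapper-codegen half of the #160000 optimization: constexpr parameters were filtered out of the runtime launch args. A minimal standalone sketch of that filtering; the names and values (named_args, constexpr_names, and the stub triton_version_uses_attrs_dict) are invented for illustration:

import itertools

# Invented stand-ins for the kernel's named args and its constexpr parameter names.
named_args = {"in_ptr0": "buf0", "out_ptr0": "buf1", "XBLOCK": 1024}
extra_launch_args = []                      # extra positional launch args, if any
constexpr_names = {"XBLOCK"}

def triton_version_uses_attrs_dict() -> bool:
    return True                             # pretend we are on a newer Triton

raw_keys_filtered, raw_args_filtered = [], []
for name, arg in itertools.chain(
    named_args.items(), zip(itertools.repeat(""), extra_launch_args)
):
    if name in constexpr_names and triton_version_uses_attrs_dict():
        # behavior removed by this revert: constexprs were dropped from the launch args
        continue
    raw_keys_filtered.append(name)
    raw_args_filtered.append(arg)

print(raw_keys_filtered)                    # ['in_ptr0', 'out_ptr0']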
@@ -196,7 +196,8 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
             call_kwargs[k] = v
         else:
             call_kwargs[k] = v
-    call_kwargs.update(launcher.config.kwargs)
+    if not triton_version_uses_attrs_dict():
+        call_kwargs.update(launcher.config.kwargs)
     call_kwargs["num_warps"] = launcher.config.num_warps
     call_kwargs["num_stages"] = launcher.config.num_stages
     if HAS_WARP_SPEC:
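After the revert, the dump helper merges launcher.config.kwargs into the dumped kwargs only on Triton versions without attrs-dict signatures, presumably because the constexpr values arrive through the runtime args again on newer versions. A small sketch of the restored merge, using invented stand-ins for the launcher and its config:

from types import SimpleNamespace

# Invented stand-ins for a compiled launcher and its autotuning config.
launcher = SimpleNamespace(
    config=SimpleNamespace(kwargs={"XBLOCK": 1024}, num_warps=4, num_stages=2)
)

def triton_version_uses_attrs_dict() -> bool:
    return False                            # pretend we are on an older Triton

call_kwargs = {"xnumel": 4096}              # kwargs collected from the actual launch
if not triton_version_uses_attrs_dict():
    call_kwargs.update(launcher.config.kwargs)
call_kwargs["num_warps"] = launcher.config.num_warps
call_kwargs["num_stages"] = launcher.config.num_stages

print(call_kwargs)
# {'xnumel': 4096, 'XBLOCK': 1024, 'num_warps': 4, 'num_stages': 2}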
@@ -769,6 +770,28 @@ class CachingAutotuner(KernelInterface):
 
         return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
 
+    def _get_args_with_constexprs(self, args, launcher):
+        """
+        `args` is passed in with only the non-constexpr args (because the constexpr arg values
+        depend on the config). However, in later triton versions, the constexpr args need to be
+        added into the args list.
+        """
+        if triton_version_uses_attrs_dict():
+            # first: aggregate the constexpr args in (index, val) pairs
+            # so we can sort them by index.
+            constexpr_args: list[tuple[int, Any]] = []
+            for arg_name, arg_val in launcher.config.kwargs.items():
+                if arg_name in self.fn.arg_names:
+                    constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
+
+            constexpr_args.sort()
+            new_args = [*args]
+            for arg_idx, arg_val in constexpr_args:
+                new_args.insert(arg_idx, arg_val)
+
+            return new_args
+        return args
+
     def bench(self, launcher, *args, with_profiler=False, **kwargs):
         """Measure the performance of a given launcher"""
         # we don't skip configs with spilled registers when auto-tuning custom
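The docstring above states the contract; to make it concrete, here is a self-contained rendition of the restored helper with made-up argument names (not the class method itself): the constexpr values from the autotuning config are spliced back into the positional args at the slots given by the kernel's arg_names.

from typing import Any

def get_args_with_constexprs(args, arg_names, config_kwargs):
    # standalone, simplified version of the restored method, for illustration only
    constexpr_args: list[tuple[int, Any]] = []
    for arg_name, arg_val in config_kwargs.items():
        if arg_name in arg_names:
            constexpr_args.append((arg_names.index(arg_name), arg_val))
    constexpr_args.sort()                   # insert in increasing index order
    new_args = [*args]
    for arg_idx, arg_val in constexpr_args:
        new_args.insert(arg_idx, arg_val)
    return new_args

# hypothetical kernel signature: (in_ptr0, out_ptr0, xnumel, XBLOCK), XBLOCK a constexpr
arg_names = ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK"]
print(get_args_with_constexprs(("buf0", "buf1", 4096), arg_names, {"XBLOCK": 1024}))
# ['buf0', 'buf1', 4096, 1024]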
@@ -797,22 +820,23 @@ class CachingAutotuner(KernelInterface):
             )
             # reset to zero before evaluating any config
             self.reset_to_zero_args(*args, **kwargs)
+            args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
             if autograd_profiler._is_profiler_enabled:
                 profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
                 with torch._C._profiler._RecordFunctionFast(
                     self.inductor_meta.get("kernel_name", "triton kernel"),
-                    cloned_args,
+                    args_with_constexprs,
                     profiler_kwargs,
                 ):
                     launcher(
-                        *cloned_args,
+                        *args_with_constexprs,
                         **cloned_kwargs,
                         stream=stream,
                     )
 
             else:
                 launcher(
-                    *cloned_args,
+                    *args_with_constexprs,
                     **cloned_kwargs,
                     stream=stream,
                 )
@@ -1216,6 +1240,7 @@ class CachingAutotuner(KernelInterface):
         # so _RecordFunctionFast need to capture the args into CachingAutotuner::run()
         # make a copy here to avoid mutating the original args
         args_without_constexprs = tuple(args)
+        args = self._get_args_with_constexprs(args, launcher)
 
         if self.dump_launch_params:
             new_args, grid = self._interpret_args_grid(args, launcher.config)
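As the comments in this hunk note, run() keeps the original tuple for the profiler record while the launcher receives the expanded args. A minimal sketch of that split, with a dummy launcher standing in for the compiled kernel and an invented constexpr value:

def launcher(*args, stream=None):
    # dummy stand-in for the compiled-kernel launcher
    print("launching with", args, "on stream", stream)

args = ("buf0", "buf1", 4096)               # runtime args without constexprs

args_without_constexprs = tuple(args)       # what the profiler record would capture
args = [*args, 1024]                        # constexpr value (e.g. XBLOCK) spliced back in

launcher(*args, stream=0)
print("profiled args:", args_without_constexprs)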
@@ -1271,10 +1296,6 @@ class _ConstRepr:
 
 
 class CompileResult(Generic[_T]):
-    """
-    Base class representing compiled result.
-    """
-
     def __init__(
         self,
         kernel: _T,
@@ -1338,30 +1359,21 @@ class CompileResult(Generic[_T]):
         )
         none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))
 
-        def _convert_constant(constant):
-            if isinstance(constant, str):
-                return "r'" + constant + "'"
-            else:
-                return repr(constant)
-
         if triton_version_uses_attrs_dict():
             call_args = arg_names
             def_args = arg_names
-            implicit_constants = OrderedSet(
-                (
-                    "num_warps",
-                    "num_stages",
-                )
-            ).union(OrderedSet(k for k in known_constants))
-            if implicit_constants := implicit_constants & OrderedSet(
-                compile_meta["constants"].keys()
+            if (
+                "num_warps" in compile_meta["constants"]
+                or "num_stages" in compile_meta["constants"]
             ):
                 # num_warps/num_stages are special implicit args that are not in the signature
                 # see test_triton_kernel_special_params
-                def_args = [arg for arg in def_args if arg not in implicit_constants]
+                def_args = [
+                    arg for arg in def_args if arg not in ("num_warps", "num_stages")
+                ]
                 repl = {
-                    k: _convert_constant(compile_meta["constants"].get(k))
-                    for k in implicit_constants
+                    k: str(compile_meta["constants"].get(k))
+                    for k in ("num_warps", "num_stages")
                 }
                 call_args = [repl.get(arg, arg) for arg in call_args]
         else:
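The restored branch only special-cases num_warps and num_stages (rather than the reverted implicit_constants set): they are dropped from the generated def args and their literal values are substituted into the call args. A toy sketch of that substitution, with an invented signature and constants:

# Hypothetical kernel signature where num_warps/num_stages leaked into the constants.
arg_names = ["in_ptr0", "out_ptr0", "xnumel", "num_warps", "num_stages"]
constants = {"num_warps": 4, "num_stages": 3}

call_args = list(arg_names)
def_args = list(arg_names)

if "num_warps" in constants or "num_stages" in constants:
    # drop the implicit args from the def args and inline their values in the call args
    def_args = [arg for arg in def_args if arg not in ("num_warps", "num_stages")]
    repl = {k: str(constants.get(k)) for k in ("num_warps", "num_stages")}
    call_args = [repl.get(arg, arg) for arg in call_args]

print(def_args)   # ['in_ptr0', 'out_ptr0', 'xnumel']
print(call_args)  # ['in_ptr0', 'out_ptr0', 'xnumel', '4', '3']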
@@ -1641,8 +1653,6 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
 
         import math as math_lib
 
-        import triton as triton_lib
-
         import torch as torch_lib
 
         scope = {
@@ -1677,7 +1687,6 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
             "runner": get_first_attr(binary, "run", "c_wrapper"),
             "math": math_lib,
             "torch": torch_lib,
-            "triton": triton_lib,
         }
 
         if not hasattr(binary, "launch_metadata"):