[user triton] add support for prune_configs_by in @triton.autotune (#142207)

This PR adds support for the prune_configs_by argument of the @triton.autotune decorator ([docs](https://triton-lang.org/main/python-api/generated/triton.autotune.html#triton.autotune)). Supporting it lets users reduce autotuning time by running user-supplied code (early_config_prune, perf_model) to prune the provided list of configs before autotuning.
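For reference, a minimal usage sketch of what this enables (the kernel body, function names, and config values below are illustrative, mirroring the tests added in this PR):

```python
import triton
import triton.language as tl


def early_config_prune(configs, named_args, **kwargs):
    # user-defined filter: drop configs we know are bad for these args
    return [c for c in configs if c.kwargs["BLOCK_SIZE"] <= 128]


def perf_model(*args, **kwargs):
    # user-defined cost model: estimated runtime per config; lower is better
    return 1.0 / kwargs["BLOCK_SIZE"]


@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 32}),
        triton.Config(kwargs={"BLOCK_SIZE": 128}),
    ],
    key=["N"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": perf_model,
        "top_k": 1,
    },
)
@triton.jit
def copy_kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + offsets, mask=offsets < N)
    tl.store(dst + offsets, x, mask=offsets < N)
```

As in the Triton docs, early_config_prune receives (configs, named_args, **kwargs) and returns the configs to keep, while perf_model returns an estimated runtime used to keep the top_k cheapest configs; with torch.compile, both hooks are now executed under Dynamo when the kernel is traced.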

We implement this by realizing args/kwargs in call_triton_kernel(...) and then pruning the configs with do_prune_configs, a reimplementation of Autotuner.prune_configs, so that the user-supplied early_config_prune and perf_model run sandboxed under Dynamo.
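Conceptually, the pruning step (do_prune_configs in torch/_higher_order_ops/triton_kernel_wrap.py below) follows Triton's Autotuner.prune_configs. A simplified, self-contained sketch of the semantics, assuming a float top_k is interpreted as a fraction of the config list as in the reimplementation below:

```python
from typing import Any, Callable, Dict, List, Optional, Union


def prune_configs(
    configs: List[Any],
    named_args: Dict[str, Any],
    kwargs: Dict[str, Any],
    early_config_prune: Optional[Callable] = None,
    perf_model: Optional[Callable] = None,
    top_k: Union[int, float] = 1.0,
) -> List[Any]:
    # 1) user-supplied filter over the candidate configs
    if early_config_prune:
        configs = early_config_prune(configs, named_args, **kwargs)
    # 2) user-supplied cost model: keep the top_k cheapest configs
    if perf_model:
        if isinstance(top_k, float) and top_k <= 1.0:
            top_k = int(len(configs) * top_k)  # a float is a fraction of the list
        if len(configs) > top_k:
            est_timing = [
                (config, float(perf_model(**named_args, **kwargs, **config.all_kwargs())))
                for config in configs
            ]
            configs = [config for config, _ in sorted(est_timing, key=lambda t: t[1])[:top_k]]
    return configs
```

Dynamo wraps early_config_prune, perf_model, configs_top_k, and configs with sources (so guards are installed on them), runs this pruning logic under the tracer, and then re-wraps the kernel with only the surviving configs before recursing into call_triton_kernel.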

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142207
Approved by: https://github.com/zou3519, https://github.com/aakhundov
Sam Ginzburg 2025-01-02 20:15:18 -08:00 committed by PyTorch MergeBot
parent 479d6f2199
commit ec1f56fdcf
5 changed files with 456 additions and 26 deletions

View File

@@ -1920,12 +1920,14 @@ def forward(self, arg0_1, arg1_1):
def kernel(X):
return
@torch.compile(backend=backend)
@torch.compile(fullgraph=True, backend=backend)
def f(x):
kernel[(1,)](x, num_ctas=1)
kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
return x
msg = "Passing num_ctas directly to the Triton kernel is not supported. Please use a Config in @triton.autotune instead."
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
x = torch.randn(4, device=GPU_TYPE)
f(x)
@@ -3649,13 +3651,198 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
# this should cause an exception, since pre_hook is not allowed
msg = "Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported. "
msg = "Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported."
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
add_compiled = torch.compile(
add, mode="reduce-overhead", fullgraph=True, backend=backend
)
add_compiled(x, y).mean()
@requires_gpu
@common_utils.parametrize("non_strict", [True, False])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@common_utils.parametrize("with_perf_model", [True, False])
def test_triton_kernel_prune_configs_by(self, backend, with_perf_model, non_strict):
# for non-strict mode
libname = "my_cool_namespace"
opname = "my_triton_operator"
records = {}
def early_config_prune(configs, named_args, **kwargs):
# record that this hook ran so the test can assert on it below
records["run_early_config_prune"] = True
if "N" in kwargs and kwargs["N"] == 1024:
records["capture_kwargs"] = True
# named args are: dst, src, add_float
if "dst" in named_args and "src" in named_args and len(named_args) == 3:
records["capture_named_args"] = True
return [configs[0]]
def perf_model(*args, **kwargs):
records["run_perf_model"] = True
return kwargs["BLOCK_SIZE"] * -1
if with_perf_model:
prune_configs_by = {"perf_model": perf_model, "top_k": 1}
else:
prune_configs_by = {"early_config_prune": early_config_prune}
@triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 32}),
triton.Config(kwargs={"BLOCK_SIZE": 128}),
],
key=["N"],
prune_configs_by=prune_configs_by,
)
@triton.jit
def prune_by_kernel(
dst,
src,
add_float,
N,
BLOCK_SIZE: tl.constexpr,
):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
# we only modify dst if our perf_model is applied (and a BLOCK_SIZE of 128 is selected)
if BLOCK_SIZE == 128:
x = x + add_float
tl.store(dst + offsets, x, mask=offsets < N)
def f(
dst: torch.Tensor,
src: torch.Tensor,
add_float: float,
N: int,
) -> None:
grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
if non_strict:
torch.library.wrap_triton(prune_by_kernel)[grid](
dst, src, add_float, N=N
)
else:
prune_by_kernel[grid](dst, src, add_float, N=N)
if non_strict:
decorator = torch.library.triton_op(
f"{libname}::{opname}", mutates_args={"dst"}
)(f)
else:
# we can just pass the function 'f' for dynamo
decorator = f
compiled_f = torch.compile(decorator, backend=backend)
N = 1024
src = torch.randn(N, device=GPU_TYPE)
dst = torch.empty(N, device=GPU_TYPE)
compiled_f(dst, src, 1.5, N)
if with_perf_model:
# when applying the perf_model: kwargs["BLOCK_SIZE"] * -1, the largest config (BLOCK_SIZE==128) is selected
self.assertEqual(len(records), 1)
self.assertEqual(src + 1.5, dst)
else:
# without the perf_model, the BLOCK_SIZE==32, and as a result dst is not modified and remains equal to src
self.assertEqual(src, dst)
self.assertEqual(len(records), 3)
self.assertTrue(records["run_early_config_prune"])
self.assertTrue(records["capture_kwargs"])
self.assertTrue(records["capture_named_args"])
@requires_gpu
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@common_utils.parametrize("with_perf_model", [True, False])
def test_triton_kernel_prune_configs_by_recompile(self, backend, with_perf_model):
"""
We want to recompile if anyone changes the configs on the autotuner object.
In short, if the following sequence of events happens:
1. foo = torch.compile(bar)
2. call foo
3. autotuner.configs = [new configs list]
4. call foo
a recompile event should occur, which we check with Dynamo counters.
This tests that we are installing guards on input objects properly.
"""
# We don't modify records here because we are testing whether or not
# recompiles occur/guards are installed
# If we modified the non-local records dict here, this would trigger
# recompile events.
def early_config_prune(configs, named_args, **kwargs):
return [configs[0]]
def perf_model(*args, **kwargs):
return kwargs["BLOCK_SIZE"] * -1
if with_perf_model:
prune_configs_by = {"perf_model": perf_model, "top_k": 1}
else:
prune_configs_by = {"early_config_prune": early_config_prune}
@triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 32}),
triton.Config(kwargs={"BLOCK_SIZE": 128}),
],
key=["N"],
prune_configs_by=prune_configs_by,
)
@triton.jit
def prune_by_kernel(
dst,
src,
add_float,
N,
BLOCK_SIZE: tl.constexpr,
):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
# Let's make sure we always select a block size of 128 based on our perf_model
if BLOCK_SIZE == 128:
x = x + add_float
tl.store(dst + offsets, x, mask=offsets < N)
torch._dynamo.reset()
counter = torch._dynamo.testing.CompileCounterWithBackend(backend=backend)
@torch.compile(fullgraph=True, backend=counter)
def f(dst, src, add_float, N):
grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
prune_by_kernel[grid](dst, src, add_float, N=N)
N = 1024
src = torch.randn(N, device=GPU_TYPE)
dst = torch.empty(N, device=GPU_TYPE)
# first compilation, this prunes the configs
f(dst, src, 1.5, N)
self.assertEqual(counter.op_count, 1)
f(dst, src, 1.5, N)
# this should not trigger a recompilation
# unlike test_triton_kernel_prune_configs_by, the hooks here do not mutate a nonlocal records dict;
# if they did, that mutation would trigger a recompile here
self.assertEqual(counter.op_count, 1)
# Modify the autotuner object
prune_by_kernel.configs = [triton.Config(kwargs={"BLOCK_SIZE": 64})]
# Calling the kernel after modifying the autotuner should
# trigger a recompile
f(dst, src, 1.5, N)
self.assertEqual(counter.op_count, 2)
# there should be no recompile here
f(dst, src, 1.5, N)
self.assertEqual(counter.op_count, 2)
common_utils.instantiate_parametrized_tests(KernelTests)
common_utils.instantiate_parametrized_tests(CustomOpTests)

View File

@@ -1646,7 +1646,7 @@ class GuardBuilder(GuardBuilderBase):
return self.FUNCTION_MATCH(guard)
def SEQUENCE_LENGTH(self, guard):
# This guard is used to check lenght of PySequence objects like list,
# This guard is used to check length of PySequence objects like list,
# tuple, collections.deque etc
ref = self.arg_ref(guard)
value = self.get(guard.name)

View File

@@ -294,6 +294,7 @@ manual_torch_name_rule_map = {
"torch._functorch.deprecated.grad_and_value": UserFunctionVariable,
"torch._functorch.deprecated.vjp": UserFunctionVariable,
# everything else
"torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable,
"torch._higher_order_ops.foreach_map.foreach_map": UserFunctionVariable,
"torch._constrain_as_size": UserFunctionVariable,
"torch._tensor._convert": UserFunctionVariable,
@@ -3298,6 +3299,7 @@ MOD_INLINELIST = [
"torch._higher_order_ops.invoke_subgraph",
"torch._higher_order_ops.scan",
"torch._higher_order_ops.strict_mode",
"torch._higher_order_ops.triton_kernel_wrap",
"torch._higher_order_ops.while_loop",
"torch._inductor.test_operators",
"torch._library.autograd",

View File

@@ -6,7 +6,17 @@ import functools
import inspect
import itertools
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, TypeVar
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
)
from typing_extensions import Never
from unittest.mock import patch
@@ -1098,6 +1108,52 @@ class DynamoTritonHOPifier(TritonHOPifier):
grid = grid.call_function(tx, [meta], {})
return grid
# We use this function to call do_prune_configs (and the user-defined hooks it invokes) under Dynamo
def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable):
from .builder import SourcelessBuilder
wrapped_user_function = SourcelessBuilder.create(tx, user_fn)
result = wrapped_user_function.call_function(tx, args, kwargs)
return result
def wrap_user_defined_obj(self, user_obj, tx, variable, name):
from .builder import VariableBuilder
wrapped_user_obj = VariableBuilder(
tx, AttrSource(variable.kernel_source, f"{name}")
)._wrap(user_obj)
return wrapped_user_obj
def maybe_unpack_configs(self, configs, tx):
# unpack the list of configs
configs = configs.unpack_var_sequence(tx)
# guard_as_python_constant inserts guards for Dynamo to check if the configs object changed.
configs = [config.guard_as_python_constant() for config in configs]
return configs
# We need to override call_getitem here so that we can add the source in the case
# where we call the triton kernel with a grid
def call_getitem(
self,
variable: "TritonKernelVariable",
args: Sequence[Any],
) -> "TritonKernelVariable":
# __getitem__ should only be called if we don't already have a grid
# Only grid needs to be passed
if variable.grid is not None or len(args) != 1:
self.raise_unsupported(
"Triton kernels should be called with only a single grid"
)
return type(variable)(
kernel=variable.kernel,
kernel_idx=variable.kernel_idx,
grid=args[0],
kernel_source=variable.source,
)
def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
from .constant import ConstantVariable
from .dicts import ConstDictVariable
@@ -1172,8 +1228,10 @@ class TritonKernelVariable(VariableTracker):
grid: "TritonGridType"
kernel: "TritonKernelType"
kernel_idx: Optional[int]
kernel_source: "AttrSource"
def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
self.kernel_source = kwargs.pop("kernel_source", None)
super().__init__(**kwargs)
dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)

View File

@@ -54,7 +54,7 @@ if TYPE_CHECKING:
TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]
if has_triton():
from triton.runtime.autotuner import Autotuner
from triton.runtime.autotuner import Autotuner, Config as TritonConfig
from triton.runtime.jit import JITFunction
else:
@@ -65,7 +65,8 @@ if TYPE_CHECKING:
pass
TritonKernelType = Union[Autotuner, JITFunction]
# mypy specifically complains that TritonAutotunerType is not a valid type if Autotuner is not inside of a Union.
TritonAutotunerType = Union[Autotuner]
log = logging.getLogger("torch._dynamo")
@@ -719,7 +720,6 @@ def triton_kernel_wrapper_mutation_dense(
*block_dims,
element_size,
)
# move as many positional arguments from dicts to args as we
# can to circumvent the bug with the kwargs and pre_/post_hook:
# https://github.com/triton-lang/triton/issues/5082
@@ -1027,6 +1027,79 @@ class TritonHOPifier:
) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]:
raise NotImplementedError("abstract method")
def wrap_user_defined_obj(
self,
user_obj: Any,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
],
name: str,
) -> Any:
raise NotImplementedError("abstract method")
def call_user_defined_fn(
self,
user_fn: Callable[..., Any],
args: List,
kwargs: Dict,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
],
) -> Any:
raise NotImplementedError("abstract method")
def maybe_unpack_configs(
self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
) -> List["TritonConfig"]:
raise NotImplementedError("abstract method")
@staticmethod
def do_prune_configs( # type: ignore[no-untyped-def]
autotuner: "TritonAutotunerType",
early_config_prune: Optional[Callable],
perf_model: Optional[Callable],
top_k: float,
configs: List,
named_args: Dict,
kwargs: Dict,
) -> List["TritonConfig"]:
# Reimplement autotuner.prune_configs(...) here
# see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
# We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model
# These are both user-defined functions which can contain side effects, so we want to sandbox them in Dynamo
if early_config_prune:
configs = early_config_prune(configs, named_args, **kwargs)
if perf_model:
# a float top_k (<= 1.0) is interpreted as a fraction of the config list
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(configs) * top_k)
elif not isinstance(top_k, int):
"""
Slice index must be an integer, SupportsIndex or None
"""
raise TypeError(
"Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
)
if len(configs) > top_k:
est_timing = [
(
config,
float(
perf_model(**named_args, **kwargs, **config.all_kwargs())
),
)
for config in configs
]
configs = [
config[0]
for config in sorted(est_timing, key=lambda x: x[1])[:top_k]
]
return configs
def call_HOP( # type: ignore[no-untyped-def]
self,
variable,
@@ -1083,11 +1156,6 @@ class TritonHOPifier:
and defaults["rep"].default
!= torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
)
or (
"prune_configs_by" in defaults
and defaults["prune_configs_by"].default
!= kernel.early_config_prune
)
or (
"use_cuda_graph" in defaults
and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
@@ -1171,17 +1239,21 @@ class TritonHOPifier:
from triton import JITFunction
from triton.runtime.autotuner import autotune, Autotuner, Config, Heuristics
SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}
# Check if num_ctas is in kwargs
if "num_ctas" in kwargs:
self.raise_unsupported(
"Passing num_ctas directly to the Triton kernel is not supported. "
"Please use a Config in @triton.autotune instead."
)
# Currently, if there are multiple autotuning decorators, the subsequent ones will be silently ignored.
# We also don't support the @triton.heuristics wrapper yet.
# We raise an error here to avoid silent incorrectness in these cases
# Make sure the kernel has a grid
if variable.grid is None:
self.raise_unsupported("Triton kernels should always be called with a grid")
"""
We also don't support the @triton.heuristics wrapper yet.
We raise an error here to avoid silent incorrectness in these cases
"""
iter_kernel = variable.kernel
autotuner_count = 0
while not isinstance(iter_kernel, JITFunction):
@@ -1189,7 +1261,7 @@ class TritonHOPifier:
autotuner_count += 1
if isinstance(iter_kernel, Heuristics):
self.raise_unsupported(
"Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported. "
"Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported."
)
if autotuner_count > 1:
self.raise_unsupported(
@@ -1198,6 +1270,9 @@ class TritonHOPifier:
)
iter_kernel = iter_kernel.fn
SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}
# move special config names to configs out of kwargs
special_kwargs = {}
for name in SPECIAL_CONFIG_NAMES:
if name in kwargs:
@@ -1212,11 +1287,20 @@ class TritonHOPifier:
new_configs = copy.deepcopy(variable.kernel.configs)
for config in new_configs:
config.__dict__.update(special_kwargs)
new_kernel = autotune(configs=new_configs, key=[])(variable.kernel.fn)
prune_configs_by = {
"perf_model": variable.kernel.perf_model,
"early_config_prune": variable.kernel.early_config_prune,
"configs_top_k": variable.kernel.configs_top_k,
}
new_kernel = autotune(
configs=new_configs, key=[], prune_configs_by=prune_configs_by
)(variable.kernel.fn)
else:
# if there is no Autotuner, wrap the kernel into a
# new one with a single config with special kwargs
new_config = Config(kwargs={}, **special_kwargs)
new_kernel = autotune(configs=[new_config], key=[])(variable.kernel)
# create a new variable to contain the new (wrapped) kernel;
@@ -1250,19 +1334,83 @@ class TritonHOPifier:
updated = True
if updated:
new_kernel = autotune(configs=new_configs, key=[])(
variable.kernel.fn
)
prune_configs_by = {
"perf_model": variable.kernel.perf_model,
"early_config_prune": variable.kernel.early_config_prune,
"configs_top_k": variable.kernel.configs_top_k,
}
new_kernel = autotune(
configs=new_configs, prune_configs_by=prune_configs_by, key=[]
)(variable.kernel.fn)
new_var = type(variable)(new_kernel, None, variable.grid)
return self.call_triton_kernel(new_var, args, kwargs, tx)
if variable.grid is None:
self.raise_unsupported("Triton kernels should always be called with a grid")
# These are the default values in upstream Triton
# see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
default_perf_model = None
default_early_config_prune = None
# run prune_configs_by
if isinstance(variable.kernel, Autotuner) and (
variable.kernel.perf_model != default_perf_model
or variable.kernel.early_config_prune != default_early_config_prune
):
# Prune the configs
named_args = dict(zip(variable.kernel.arg_names, args))
# The source information is important here so the guards are installed correctly
wrapped_early_configs_prune = self.wrap_user_defined_obj(
variable.kernel.early_config_prune,
tx,
variable,
"early_config_prune",
)
wrapped_perf_model = self.wrap_user_defined_obj(
variable.kernel.perf_model, tx, variable, "perf_model"
)
wrapped_configs_top_k = self.wrap_user_defined_obj(
variable.kernel.configs_top_k, tx, variable, "configs_top_k"
)
wrapped_configs = self.wrap_user_defined_obj(
variable.kernel.configs, tx, variable, "configs"
)
pruned_configs = self.call_user_defined_fn(
self.do_prune_configs,
[
variable,
wrapped_early_configs_prune,
wrapped_perf_model,
wrapped_configs_top_k,
wrapped_configs,
named_args,
kwargs,
],
{},
tx,
variable,
)
pruned_configs = self.maybe_unpack_configs(pruned_configs, tx)
# after pruning the configs, create a new autotuner object with
# these configs and recurse.
new_kernel = autotune(configs=pruned_configs, key=[])(variable.kernel.fn)
# create a new variable to contain the new (wrapped) kernel;
# skip kernel_idx to get a new record in the kernel side table
new_var = type(variable)(new_kernel, None, variable.grid)
return self.call_triton_kernel(new_var, args, kwargs, tx)
# Both for grid's meta as well as for the kernel, we need combined
# args and kwargs combined and normalized
combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}
# precompute the grid for the kernel
configs = (
[config.kwargs for config in variable.kernel.configs]
if isinstance(variable.kernel, Autotuner)
@@ -1295,6 +1443,8 @@ class TritonHOPifier:
if isinstance(variable.kernel, JITFunction):
constexprs = variable.kernel.constexprs
else:
# If we are looking at an @triton.autotune decorator, the nested function should be a JITFunction
# This is because we don't support @triton.heuristics or nested @triton.autotune decorators yet
assert isinstance(variable.kernel, Autotuner)
constexprs = variable.kernel.fn.constexprs
@@ -1343,6 +1493,39 @@ class TracingTritonHOPifier(TritonHOPifier):
assert callable(grid)
return grid(meta)
def wrap_user_defined_obj(
self,
user_obj: Any,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
],
name: str,
) -> Any:
assert tx is None
return user_obj
def call_user_defined_fn(
self,
user_fn: Callable[..., Any],
args: List,
kwargs: Dict,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
],
) -> Any:
assert isinstance(args, list)
assert isinstance(kwargs, dict)
assert callable(user_fn)
return user_fn(*args, **kwargs)
def maybe_unpack_configs(
self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
) -> List["TritonConfig"]:
assert isinstance(configs, list)
return configs
def check_grid(
self,
grid: "TritonGridType",