Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[user triton] add support for prune_configs_by in @triton.autotune (#142207)
This PR adds support for prune_configs_by in the @triton.autotune decorator ([docs](https://triton-lang.org/main/python-api/generated/triton.autotune.html#triton.autotune)). Supporting this lets users reduce autotuning time by running user-supplied code (early_config_prune, perf_model) to prune the provided list of configs. We implement this by realizing args/kwargs in call_triton_kernel(...) and then calling kernel.prune_configs(...).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142207
Approved by: https://github.com/zou3519, https://github.com/aakhundov
This commit is contained in:
parent 479d6f2199
commit ec1f56fdcf
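For context, a minimal sketch of the user-facing pattern this PR makes traceable under torch.compile. The kernel/function names (add_one_kernel, f) and sizes are illustrative rather than taken from the PR; the prune_configs_by contract follows the linked Triton docs and the tests added below.

```python
import torch
import triton
import triton.language as tl

def early_config_prune(configs, named_args, **kwargs):
    # user-supplied pruning hook: keep only the first config
    return [configs[0]]

def perf_model(*args, **kwargs):
    # user-supplied cost model: lower is "faster", so this prefers the largest BLOCK_SIZE
    return kwargs["BLOCK_SIZE"] * -1

@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 32}),
        triton.Config(kwargs={"BLOCK_SIZE": 128}),
    ],
    key=["N"],
    # either hook (or both) can be supplied; top_k bounds how many configs survive perf_model
    prune_configs_by={"perf_model": perf_model, "top_k": 1},
)
@triton.jit
def add_one_kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    tl.store(dst + offsets, tl.load(src + offsets, mask=mask) + 1.0, mask=mask)

@torch.compile(fullgraph=True)
def f(dst, src, N):
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    add_one_kernel[grid](dst, src, N=N)
```

With this change, Dynamo wires the user's early_config_prune/perf_model through its re-wrapping of the autotuner and prunes the config list before tracing the kernel call; previously a non-default prune_configs_by hit the unsupported-autotuner-arguments check that is removed further down in this diff.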
@@ -1920,14 +1920,16 @@ def forward(self, arg0_1, arg1_1):
        def kernel(X):
            return

        @torch.compile(backend=backend)
        @torch.compile(fullgraph=True, backend=backend)
        def f(x):
            kernel[(1,)](x, num_ctas=1)
            kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
            return x

        x = torch.randn(4, device=GPU_TYPE)
        f(x)
        msg = "Passing num_ctas directly to the Triton kernel is not supported. Please use a Config in @triton.autotune instead."
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.randn(4, device=GPU_TYPE)
            f(x)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
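The message above points users at configs rather than call-site kwargs. A sketch of that workaround (assuming a Triton build whose triton.Config accepts num_ctas; the kernel is illustrative):

```python
import triton
import triton.language as tl

# Instead of kernel[(1,)](x, num_ctas=1), attach num_ctas (like num_warps/num_stages)
# to a Config so torch.compile can handle it through the autotuner wrapper.
@triton.autotune(
    configs=[triton.Config(kwargs={"BLOCK_SIZE": 64}, num_warps=4, num_ctas=1)],
    key=["N"],
)
@triton.jit
def kernel(x_ptr, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    tl.store(x_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) + 1.0, mask=mask)
```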
@@ -3649,13 +3651,198 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        # this should cause an exception, since pre_hook is not allowed
        msg = "Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported. "
        msg = "Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported."
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            add_compiled = torch.compile(
                add, mode="reduce-overhead", fullgraph=True, backend=backend
            )
            add_compiled(x, y).mean()

    @requires_gpu
    @common_utils.parametrize("non_strict", [True, False])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("with_perf_model", [True, False])
    def test_triton_kernel_prune_configs_by(self, backend, with_perf_model, non_strict):
        # for non-strict mode
        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        records = {}

        def early_config_prune(configs, named_args, **kwargs):
            # we need to save the records to the returned config
            records["run_early_config_prune"] = True
            if "N" in kwargs and kwargs["N"] == 1024:
                records["capture_kwargs"] = True
            # named args are: dst, src, add_float
            if "dst" in named_args and "src" in named_args and len(named_args) == 3:
                records["capture_named_args"] = True
            return [configs[0]]

        def perf_model(*args, **kwargs):
            records["run_perf_model"] = True
            return kwargs["BLOCK_SIZE"] * -1

        if with_perf_model:
            prune_configs_by = {"perf_model": perf_model, "top_k": 1}
        else:
            prune_configs_by = {"early_config_prune": early_config_prune}

        @triton.autotune(
            configs=[
                triton.Config(kwargs={"BLOCK_SIZE": 32}),
                triton.Config(kwargs={"BLOCK_SIZE": 128}),
            ],
            key=["N"],
            prune_configs_by=prune_configs_by,
        )
        @triton.jit
        def prune_by_kernel(
            dst,
            src,
            add_float,
            N,
            BLOCK_SIZE: tl.constexpr,
        ):
            offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            x = tl.load(src + offsets, mask=offsets < N)
            # we only modify dst if our perf_model is applied (and a BLOCK_SIZE of 128 is selected)
            if BLOCK_SIZE == 128:
                x = x + add_float
            tl.store(dst + offsets, x, mask=offsets < N)

        def f(
            dst: torch.Tensor,
            src: torch.Tensor,
            add_float: float,
            N: int,
        ) -> None:
            grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
            if non_strict:
                torch.library.wrap_triton(prune_by_kernel)[grid](
                    dst, src, add_float, N=N
                )
            else:
                prune_by_kernel[grid](dst, src, add_float, N=N)

        if non_strict:
            decorator = torch.library.triton_op(
                f"{libname}::{opname}", mutates_args={"dst"}
            )(f)
        else:
            # we can just pass the function 'f' for dynamo
            decorator = f

        compiled_f = torch.compile(decorator, backend=backend)
        N = 1024
        src = torch.randn(N, device=GPU_TYPE)
        dst = torch.empty(N, device=GPU_TYPE)
        compiled_f(dst, src, 1.5, N)

        if with_perf_model:
            # when applying the perf_model: kwargs["BLOCK_SIZE"] * -1, the largest config (BLOCK_SIZE==128) is selected
            self.assertEqual(len(records), 1)
            self.assertEqual(src + 1.5, dst)
        else:
            # without the perf_model, the BLOCK_SIZE==32, and as a result dst is not modified and remains equal to src
            self.assertEqual(src, dst)
            self.assertEqual(len(records), 3)
            self.assertTrue(records["run_early_config_prune"])
            self.assertTrue(records["capture_kwargs"])
            self.assertTrue(records["capture_named_args"])

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    @common_utils.parametrize("with_perf_model", [True, False])
    def test_triton_kernel_prune_configs_by_recompile(self, backend, with_perf_model):
        """
        We want to recompile if anyone changes configs in the autotuner object
        In short if for example the following sequence of events happens:
        1. foo = torch.compile(bar)
        1. call foo
        2. autotuner.configs = [new configs list]
        3. call foo

        A recompile event should occur, which we check with Dynamo counters
        This tests that we are installing guards on input objects properly
        """

        # We don't modify records here because we are testing whether or not
        # recompiles occur/guards are installed
        # If we modified the non-local records dict here, this would trigger
        # recompile events.
        def early_config_prune(configs, named_args, **kwargs):
            return [configs[0]]

        def perf_model(*args, **kwargs):
            return kwargs["BLOCK_SIZE"] * -1

        if with_perf_model:
            prune_configs_by = {"perf_model": perf_model, "top_k": 1}
        else:
            prune_configs_by = {"early_config_prune": early_config_prune}

        @triton.autotune(
            configs=[
                triton.Config(kwargs={"BLOCK_SIZE": 32}),
                triton.Config(kwargs={"BLOCK_SIZE": 128}),
            ],
            key=["N"],
            prune_configs_by=prune_configs_by,
        )
        @triton.jit
        def prune_by_kernel(
            dst,
            src,
            add_float,
            N,
            BLOCK_SIZE: tl.constexpr,
        ):
            offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            x = tl.load(src + offsets, mask=offsets < N)
            # Let's make sure we always select a block size of 128 based on our perf_model
            if BLOCK_SIZE == 128:
                x = x + add_float
            tl.store(dst + offsets, x, mask=offsets < N)

        torch._dynamo.reset()
        counter = torch._dynamo.testing.CompileCounterWithBackend(backend=backend)

        @torch.compile(fullgraph=True, backend=counter)
        def f(dst, src, add_float, N):
            grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
            prune_by_kernel[grid](dst, src, add_float, N=N)

        N = 1024
        src = torch.randn(N, device=GPU_TYPE)
        dst = torch.empty(N, device=GPU_TYPE)

        # first compilation, this prunes the configs
        f(dst, src, 1.5, N)

        self.assertEqual(counter.op_count, 1)

        f(dst, src, 1.5, N)

        # this should not trigger a recompilation
        # this is because we modified the test to not touch the records dict
        # as we do in test_triton_kernel_prune_configs_by. If we kept it, it would trigger a recompile here.
        self.assertEqual(counter.op_count, 1)

        # Modify the autotuner object
        prune_by_kernel.configs = [triton.Config(kwargs={"BLOCK_SIZE": 64})]

        # Calling the kernel after modifying the autotuner should
        # trigger a recompile
        f(dst, src, 1.5, N)

        self.assertEqual(counter.op_count, 2)

        # there should be no recompile here
        f(dst, src, 1.5, N)

        self.assertEqual(counter.op_count, 2)


common_utils.instantiate_parametrized_tests(KernelTests)
common_utils.instantiate_parametrized_tests(CustomOpTests)
@@ -1646,7 +1646,7 @@ class GuardBuilder(GuardBuilderBase):
        return self.FUNCTION_MATCH(guard)

    def SEQUENCE_LENGTH(self, guard):
        # This guard is used to check lenght of PySequence objects like list,
        # This guard is used to check length of PySequence objects like list,
        # tuple, collections.deque etc
        ref = self.arg_ref(guard)
        value = self.get(guard.name)
@@ -294,6 +294,7 @@ manual_torch_name_rule_map = {
    "torch._functorch.deprecated.grad_and_value": UserFunctionVariable,
    "torch._functorch.deprecated.vjp": UserFunctionVariable,
    # everything else
    "torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable,
    "torch._higher_order_ops.foreach_map.foreach_map": UserFunctionVariable,
    "torch._constrain_as_size": UserFunctionVariable,
    "torch._tensor._convert": UserFunctionVariable,
@@ -3298,6 +3299,7 @@ MOD_INLINELIST = [
    "torch._higher_order_ops.invoke_subgraph",
    "torch._higher_order_ops.scan",
    "torch._higher_order_ops.strict_mode",
    "torch._higher_order_ops.triton_kernel_wrap",
    "torch._higher_order_ops.while_loop",
    "torch._inductor.test_operators",
    "torch._library.autograd",
@@ -6,7 +6,17 @@ import functools
import inspect
import itertools
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, TypeVar
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
)
from typing_extensions import Never
from unittest.mock import patch

@@ -1098,6 +1108,52 @@ class DynamoTritonHOPifier(TritonHOPifier):
            grid = grid.call_function(tx, [meta], {})
        return grid

    # We use this function to wrap call_prune_configs
    def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable):
        from .builder import SourcelessBuilder

        wrapped_user_function = SourcelessBuilder.create(tx, user_fn)
        result = wrapped_user_function.call_function(tx, args, kwargs)
        return result

    def wrap_user_defined_obj(self, user_obj, tx, variable, name):
        from .builder import VariableBuilder

        wrapped_user_obj = VariableBuilder(
            tx, AttrSource(variable.kernel_source, f"{name}")
        )._wrap(user_obj)
        return wrapped_user_obj

    def maybe_unpack_configs(self, configs, tx):
        # unpack the list of configs
        configs = configs.unpack_var_sequence(tx)

        # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed.
        configs = [config.guard_as_python_constant() for config in configs]

        return configs

    # We need to override call_getitem here so that we can add the source in the case
    # where we call the triton kernel with a grid
    def call_getitem(
        self,
        variable: "TritonKernelVariable",
        args: Sequence[Any],
    ) -> "TritonKernelVariable":
        # __getitem__ should only be called if we don't already have a grid
        # Only grid needs to be passed
        if variable.grid is not None or len(args) != 1:
            self.raise_unsupported(
                "Triton kernels should be called with only a single grid"
            )

        return type(variable)(
            kernel=variable.kernel,
            kernel_idx=variable.kernel_idx,
            grid=args[0],
            kernel_source=variable.source,
        )

    def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
        from .constant import ConstantVariable
        from .dicts import ConstDictVariable
@@ -1172,8 +1228,10 @@ class TritonKernelVariable(VariableTracker):
    grid: "TritonGridType"
    kernel: "TritonKernelType"
    kernel_idx: Optional[int]
    kernel_source: "AttrSource"

    def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
        self.kernel_source = kwargs.pop("kernel_source", None)
        super().__init__(**kwargs)
        dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
    TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]

if has_triton():
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.autotuner import Autotuner, Config as TritonConfig
    from triton.runtime.jit import JITFunction
else:
@@ -65,7 +65,8 @@ if TYPE_CHECKING:
        pass

    TritonKernelType = Union[Autotuner, JITFunction]

    # mypy specifically complains that TritonAutotunerType is not a valid type if Autotuner is not inside of a Union.
    TritonAutotunerType = Union[Autotuner]

log = logging.getLogger("torch._dynamo")
@@ -719,7 +720,6 @@ def triton_kernel_wrapper_mutation_dense(
            *block_dims,
            element_size,
        )

    # move as many positional arguments from dicts to args as we
    # can to circumvent the bug with the kwargs and pre_/post_hook:
    # https://github.com/triton-lang/triton/issues/5082
@@ -1027,6 +1027,79 @@ class TritonHOPifier:
    ) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]:
        raise NotImplementedError("abstract method")

    def wrap_user_defined_obj(
        self,
        user_obj: Any,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
        name: str,
    ) -> Any:
        raise NotImplementedError("abstract method")

    def call_user_defined_fn(
        self,
        user_fn: Callable[..., Any],
        args: List,
        kwargs: Dict,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
    ) -> Any:
        raise NotImplementedError("abstract method")

    def maybe_unpack_configs(
        self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
    ) -> List["TritonConfig"]:
        raise NotImplementedError("abstract method")

    @staticmethod
    def do_prune_configs(  # type: ignore[no-untyped-def]
        autotuner: "TritonAutotunerType",
        early_config_prune: Optional[Callable],
        perf_model: Optional[Callable],
        top_k: float,
        configs: List,
        named_args: Dict,
        kwargs: Dict,
    ) -> List["TritonConfig"]:
        # Reimplement autotuner.prune_configs(...) here
        # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
        # We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model
        # These are both user-defined functions which can contain side effects, so we want to sandbox them in Dynamo

        if early_config_prune:
            configs = early_config_prune(configs, named_args, **kwargs)

        if perf_model:
            # we assert top_k is a float before calling this
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(configs) * top_k)
            elif not isinstance(top_k, int):
                """
                Slice index must be an integer, SupportsIndex or None
                """
                raise TypeError(
                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
                )
            if len(configs) > top_k:
                est_timing = [
                    (
                        config,
                        float(
                            perf_model(**named_args, **kwargs, **config.all_kwargs())
                        ),
                    )
                    for config in configs
                ]
                configs = [
                    config[0]
                    for config in sorted(est_timing, key=lambda x: x[1])[:top_k]
                ]
        return configs

    def call_HOP(  # type: ignore[no-untyped-def]
        self,
        variable,
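One detail of the top_k handling reimplemented above is worth calling out: a float top_k <= 1.0 is read as a fraction of the config list, while an int is an absolute count. A standalone restatement of just that branch (an illustration, not the PyTorch API):

```python
def resolve_top_k(top_k, num_configs):
    # float <= 1.0 means "keep this fraction of configs", e.g. 0.5 of 4 configs keeps 2
    if isinstance(top_k, float) and top_k <= 1.0:
        return int(num_configs * top_k)
    if isinstance(top_k, int):
        return top_k
    raise TypeError("top_k must be either a float <= 1.0 or an int")
```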
@@ -1083,11 +1156,6 @@ class TritonHOPifier:
                and defaults["rep"].default
                != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
            )
            or (
                "prune_configs_by" in defaults
                and defaults["prune_configs_by"].default
                != kernel.early_config_prune
            )
            or (
                "use_cuda_graph" in defaults
                and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
@@ -1171,17 +1239,21 @@ class TritonHOPifier:
        from triton import JITFunction
        from triton.runtime.autotuner import autotune, Autotuner, Config, Heuristics

        SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}

        # Check if num_ctas is in kwargs
        if "num_ctas" in kwargs:
            self.raise_unsupported(
                "Passing num_ctas directly to the Triton kernel is not supported. "
                "Please use a Config in @triton.autotune instead."
            )

        # Currently, if there are multiple autotuning decorators, the subsequent ones will be silently ignored.
        # We also don't support the @triton.heuristics wrapper yet.
        # We raise an error here to avoid silent incorrectness in these cases
        # Make sure the kernel has a grid
        if variable.grid is None:
            self.raise_unsupported("Triton kernels should always be called with a grid")

        """
        We also don't support the @triton.heuristics wrapper yet.
        We raise an error here to avoid silent incorrectness in these cases
        """
        iter_kernel = variable.kernel
        autotuner_count = 0
        while not isinstance(iter_kernel, JITFunction):
@@ -1189,7 +1261,7 @@ class TritonHOPifier:
            autotuner_count += 1
            if isinstance(iter_kernel, Heuristics):
                self.raise_unsupported(
                    "Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported. "
                    "Passing @triton.heuristics decorator after @triton.autotune decorator is not supported. is not supported."
                )
            if autotuner_count > 1:
                self.raise_unsupported(
@@ -1198,6 +1270,9 @@ class TritonHOPifier:
                )
            iter_kernel = iter_kernel.fn

        SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}

        # move special config names to configs out of kwargs
        special_kwargs = {}
        for name in SPECIAL_CONFIG_NAMES:
            if name in kwargs:
@@ -1212,11 +1287,20 @@ class TritonHOPifier:
                new_configs = copy.deepcopy(variable.kernel.configs)
                for config in new_configs:
                    config.__dict__.update(special_kwargs)
                new_kernel = autotune(configs=new_configs, key=[])(variable.kernel.fn)
                prune_configs_by = {
                    "perf_model": variable.kernel.perf_model,
                    "early_config_prune": variable.kernel.early_config_prune,
                    "configs_top_k": variable.kernel.configs_top_k,
                }

                new_kernel = autotune(
                    configs=new_configs, key=[], prune_configs_by=prune_configs_by
                )(variable.kernel.fn)
            else:
                # if there is no Autotuner, wrap the kernel into a
                # new one with a single config with special kwargs
                new_config = Config(kwargs={}, **special_kwargs)

                new_kernel = autotune(configs=[new_config], key=[])(variable.kernel)

            # create a new variable to contain the new (wrapped) kernel;
@@ -1250,19 +1334,83 @@ class TritonHOPifier:
                    updated = True

            if updated:
                new_kernel = autotune(configs=new_configs, key=[])(
                    variable.kernel.fn
                )
                prune_configs_by = {
                    "perf_model": variable.kernel.perf_model,
                    "early_config_prune": variable.kernel.early_config_prune,
                    "configs_top_k": variable.kernel.configs_top_k,
                }

                new_kernel = autotune(
                    configs=new_configs, prune_configs_by=prune_configs_by, key=[]
                )(variable.kernel.fn)
                new_var = type(variable)(new_kernel, None, variable.grid)
                return self.call_triton_kernel(new_var, args, kwargs, tx)

        if variable.grid is None:
            self.raise_unsupported("Triton kernels should always be called with a grid")
        # These are the default values in upstream Triton
        # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
        default_perf_model = None
        default_early_config_prune = None

        # run prune_configs_by
        if isinstance(variable.kernel, Autotuner) and (
            variable.kernel.perf_model != default_perf_model
            or variable.kernel.early_config_prune != default_early_config_prune
        ):
            # Prune the configs
            named_args = dict(zip(variable.kernel.arg_names, args))

            # The source information is important here so the guards are installed correctly

            wrapped_early_configs_prune = self.wrap_user_defined_obj(
                variable.kernel.early_config_prune,
                tx,
                variable,
                "early_config_prune",
            )

            wrapped_perf_model = self.wrap_user_defined_obj(
                variable.kernel.perf_model, tx, variable, "perf_model"
            )

            wrapped_configs_top_k = self.wrap_user_defined_obj(
                variable.kernel.configs_top_k, tx, variable, "configs_top_k"
            )

            wrapped_configs = self.wrap_user_defined_obj(
                variable.kernel.configs, tx, variable, "configs"
            )

            pruned_configs = self.call_user_defined_fn(
                self.do_prune_configs,
                [
                    variable,
                    wrapped_early_configs_prune,
                    wrapped_perf_model,
                    wrapped_configs_top_k,
                    wrapped_configs,
                    named_args,
                    kwargs,
                ],
                {},
                tx,
                variable,
            )

            pruned_configs = self.maybe_unpack_configs(pruned_configs, tx)

            # after pruning the configs, create a new autotuner object with
            # these configs and recurse.
            new_kernel = autotune(configs=pruned_configs, key=[])(variable.kernel.fn)
            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        # Both for grid's meta as well as for the kernel, we need combined
        # args and kwargs combined and normalized
        combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}

        # precompute the grid for the kernel
        configs = (
            [config.kwargs for config in variable.kernel.configs]
            if isinstance(variable.kernel, Autotuner)
@@ -1295,6 +1443,8 @@ class TritonHOPifier:
        if isinstance(variable.kernel, JITFunction):
            constexprs = variable.kernel.constexprs
        else:
            # If we are looking at an @triton.autotune decorator, the nested function should be a JITFunction
            # This is because we don't support @triton.heuristics or nested @triton.autotune decorators yet
            assert isinstance(variable.kernel, Autotuner)
            constexprs = variable.kernel.fn.constexprs
@@ -1343,6 +1493,39 @@ class TracingTritonHOPifier(TritonHOPifier):
        assert callable(grid)
        return grid(meta)

    def wrap_user_defined_obj(
        self,
        user_obj: Any,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
        name: str,
    ) -> Any:
        assert tx is None
        return user_obj

    def call_user_defined_fn(
        self,
        user_fn: Callable[..., Any],
        args: List,
        kwargs: Dict,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
    ) -> Any:
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert callable(user_fn)
        return user_fn(*args, **kwargs)

    def maybe_unpack_configs(
        self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
    ) -> List["TritonConfig"]:
        assert isinstance(configs, list)
        return configs

    def check_grid(
        self,
        grid: "TritonGridType",