[import][inductor] Simplify grid handling (#147583)

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where `grid=` was passed as a callable (a function closure) argument. This PR removes the grid argument:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
Instead, the grid computation is now generated directly into the kernel launcher, which looks something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since it removes several function and dict calls per launch and lets us specialize the grid computation for each `triton.Config`.
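
For intuition, the `>> 10` in the launcher above is just a ceiling division by the config's `XBLOCK=1024` written as a shift. Below is a minimal, illustrative sketch (the helper name is hypothetical, not the PR's actual codegen) of how the burned-in grid expression changes with the chosen block size:

```py
# Illustrative only: emit the ceiling-division grid expression that a launcher
# specialized for a given power-of-two XBLOCK would contain.
def make_launcher_grid_expr(xblock: int) -> str:
    shift = xblock.bit_length() - 1  # log2(xblock) for power-of-two block sizes
    return f"((xnumel + {xblock - 1}) >> {shift})"

print(make_launcher_grid_expr(1024))  # ((xnumel + 1023) >> 10)
print(make_launcher_grid_expr(128))   # ((xnumel + 127) >> 7)
```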

It also allows us to unify grid handling between the Python and C++ wrapper code. Before this PR, the C++ wrapper didn't actually support dynamic grid sizes and instead burned a static grid into the generated code.
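
The shared abstraction is a grid-expression object that renders the same ceiling-division computation as either a Python or a C++ expression (in the PR this is handled by the `GridExpr` classes in `torch._inductor.runtime.triton_heuristics`, which the C++ wrapper invokes with `mode="cpp"`). A minimal sketch of the idea, using a hypothetical `CeilDivGrid` class rather than the real API:

```py
import dataclasses

@dataclasses.dataclass
class CeilDivGrid:
    numel: str  # symbolic size, e.g. "xnumel"
    block: int  # block size taken from the chosen triton.Config

    def render(self, mode: str) -> str:
        # The same grid computation, rendered for either wrapper backend.
        if mode == "python":
            return f"(({self.numel} + {self.block - 1}) // {self.block})"
        assert mode == "cpp"
        return f"(({self.numel} + {self.block - 1}LL) / {self.block}LL)"

grid = CeilDivGrid("xnumel", 1024)
print(grid.render("python"))  # ((xnumel + 1023) // 1024)
print(grid.render("cpp"))     # ((xnumel + 1023LL) / 1024LL)
```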

This unification allows this PR to be a net deletion of code.

Note that the attached diff contains some minor fbcode-only changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147583
Approved by: https://github.com/eellison, https://github.com/shunting314
Jason Ansel 2025-03-02 07:31:07 +00:00 committed by PyTorch MergeBot
parent dae3fbfe97
commit b59776d857
42 changed files with 970 additions and 1319 deletions

View File

@ -4086,10 +4086,9 @@ class AOTInductorTestsTemplate:
# input u0 was defined as int32_t initially, verify for every kernel var args downstream,
# it gets explicitly declared using its data types in the cpp wrapper codegen code.
expected_scalar_args = [
"int64_t var_1 = u0;",
"int64_t var_4 = u0;",
"int64_t var_7 = u0;",
"int64_t var_12 = u0;",
"buf3, u0",
"buf4, u0",
"buf3, buf4, buf2, u0",
]
# check the new behavior of codegen is expected
result, code = run_and_get_cpp_code(

View File

@ -54,22 +54,6 @@ class TestCppWrapperHipify(TestCase):
} \\
} while (0);
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
static inline hipFunction_t loadKernel(
std::string filePath,
const std::string &funcName,

View File

@ -550,7 +550,7 @@ class CudaReproTests(TestCase):
"""
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
from torch._inductor.runtime.hints import AttrsDescriptorWrapper, HeuristicType
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
def autotune(configs, meta):
def decorator(fn):
@ -564,6 +564,7 @@ class CudaReproTests(TestCase):
reset_to_zero_arg_names=[],
optimize_mem=True,
heuristic_type=HeuristicType.POINTWISE,
inductor_meta={"grid_type": "Grid1D"},
)
return decorator
@ -603,8 +604,8 @@ class CudaReproTests(TestCase):
inout2 = inout1.clone()
stream0 = get_cuda_stream(0)
kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)
kernel.run(inout1, in0, xnumel, stream=stream0)
kernel.run(inout2, in0, xnumel, stream=stream0)
assert same(
inout1, inout2, tol=0.001, equal_nan=True

View File

@ -59,11 +59,16 @@ class TestKernelBenchmark(TestCase):
def verify_compiled_kernels(self, GB_count=1):
compiled_module = self.get_compiled_module()
# now run the compiled module in subprocess and check its output
bench_out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__} -kc".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
).decode()
try:
bench_out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__} -kc".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
).decode()
except subprocess.CalledProcessError as e:
print("Failed when running output code", e)
print(e.output.decode())
raise e
# make sure we have the bandwidth information in the output
FileCheck().check_count(
@ -112,11 +117,16 @@ class TestKernelBenchmark(TestCase):
def check_bandwidth(self, compiled_module, num_gb):
# now run the compiled module in subprocess and check its output
bench_out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__} -k".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
).decode()
try:
bench_out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__} -k".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
).decode()
except subprocess.CalledProcessError as e:
print("Failed when running output code", e)
print(e.output.decode())
raise e
# make sure we have the bandwidth information in the output
FileCheck().check_count(
@ -156,7 +166,7 @@ class TestKernelBenchmark(TestCase):
self.verify_compiled_kernels()
@config.patch(
max_autotune=True, max_autotune_gemm_backends="TRITON", force_shape_pad=True
max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
)
@fresh_inductor_cache()
def test_mm_triton_kernel_benchmark(self):
@ -175,28 +185,7 @@ class TestKernelBenchmark(TestCase):
f(a, b)
GB_count = 3
# pad_mm is not enabled on XPU, so there is only one kernel.
if GPU_TYPE == "xpu":
GB_count = 1
self.verify_compiled_kernels(GB_count=GB_count)
# make sure we correctly generate the grid info
compiled_module = self.get_compiled_module()
with open(compiled_module.__file__) as f:
source_code = f.read()
lines = source_code.split("\n")
meta = [l for l in lines if "meta0 = {" in l]
scope = {}
from torch._inductor.kernel.mm_common import mm_grid
exec(meta[0], scope)
grid = mm_grid(M, N, scope["meta0"])
FileCheck().check_count(
f"grid={grid}",
2,
exactly=1,
).run(source_code)
self.verify_compiled_kernels(GB_count=1)
def test_matmul_bandwidth_computation(self):
"""

View File

@ -53,7 +53,7 @@ def _get_func_call() -> str:
def _get_kernel_launch() -> str:
return "launchKernel(" if config.cpp_wrapper else ".run("
return "call_triton_" if config.cpp_wrapper else ".run("
def benchmark_choice(choice, args, out, expected_out, timings):

View File

@ -265,9 +265,6 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
self.assertEqual(args["kernel_backend"], "triton", msg=f"event = {e}")
self.assertTrue("stream" in args, msg=f"event = {e}")
self.assertTrue("grid" in args, msg=f"event = {e}")
self.assertTrue(args["grid"].startswith("grid"), msg=f"event = {e}")
self.assertTrue("kernel_file" in args, msg=f"event = {e}")
kernel_file = args["kernel_file"]
self.assertTrue(os.path.isfile(kernel_file), msg=f"event = {e}")

View File

@ -353,7 +353,6 @@ class TestSelectAlgorithm(TestCase):
module_path=module_path,
module_cache_key=None,
kernel_name=None,
grid=None,
extra_args=None,
num_stages=None,
num_warps=None,

View File

@ -3985,7 +3985,7 @@ class CommonTemplate:
_, code = run_and_get_code(foo, grouped_conv, input_tensor)
# no to channels last permuting before kernel
if config.cpp_wrapper:
FileCheck().check_not("launchKernel(triton").check("_convolution(").run(
FileCheck().check_not(" call_triton").check("_convolution(").run(
code[0]
)
else:

View File

@ -3635,8 +3635,10 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
output = "\n".join(record.getMessage() for record in log.records)
# correct grid example values updated per block size
FileCheck().check("Compile-time auto-tuning block:").check(
"grid_wrapper_for_op_zeros_0"
).check_next("return (256").check_next("return (64").run(output)
"PrecomputedGrid"
).check("(31 + _launcher_s0) // 32").check("(127 + _launcher_s0) // 128").run(
output
)
# Triton 3.2.0 adds the required flags to the Autotuner object for this test
# PR: https://github.com/triton-lang/triton/pull/5092

View File

@ -639,7 +639,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
extra_args: Iterable[Any],
module_path: str, # the path of the module defining the triton kernel
module_cache_key: str,
grid: list[int],
num_stages: int,
num_warps: int,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
@ -648,7 +647,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.module_path = module_path
self.module_cache_key = module_cache_key
self.grid = grid
self.num_stages = num_stages
self.num_warps = num_warps
self.matrix_instr_nonkdim = matrix_instr_nonkdim
@ -700,16 +698,15 @@ class TritonBenchmarkRequest(BenchmarkRequest):
)
# Handle zero initialization if needed
if workspace_arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL:
if workspace_arg.zero_mode != WorkspaceZeroMode.UNINITIALIZED:
workspace_tensor.zero_()
# Run the kernel with workspace
run_method(
*input_tensors,
output_tensor,
*extra_args,
workspace_tensor,
grid=self.grid,
*extra_args,
**warmup_arg,
stream=stream,
benchmark_run=True,
@ -725,7 +722,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
*input_tensors,
output_tensor,
*extra_args,
grid=self.grid,
**warmup_arg,
stream=stream,
)
@ -735,7 +731,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
*input_tensors,
output_tensor,
*extra_args,
grid=self.grid,
**warmup_arg,
stream=stream,
benchmark_run=True,

View File

@ -1358,7 +1358,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
@clear_on_fresh_inductor_cache
class CudaKernelParamCache:
cache: dict[str, dict[str, str]] = {}
cache: dict[str, dict[str, Any]] = {}
cache_clear = staticmethod(cache.clear)
@classmethod
@ -1376,7 +1376,7 @@ class CudaKernelParamCache:
cls.cache[key] = params
@classmethod
def get(cls, key: str) -> Optional[dict[str, str]]:
def get(cls, key: str) -> Optional[dict[str, Any]]:
return cls.cache.get(key, None)
@classmethod

View File

@ -1550,7 +1550,7 @@ class KernelArgs:
def python_argdefs(
self,
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[torch.dtype]]:
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
arg_defs: list[ArgName] = []
call_args: list[str] = []
arg_types: list[torch.dtype] = []

View File

@ -5204,7 +5204,7 @@ class KernelGroup:
def call_kernel(self, wrapper, kernel_name):
_, call_args, arg_types = self.args.cpp_argdefs()
wrapper.generate_kernel_call(
kernel_name, call_args, gpu=False, triton=False, arg_types=arg_types
kernel_name, call_args, triton=False, arg_types=arg_types
)

View File

@ -118,9 +118,7 @@ class CppTemplateKernel(CppKernel):
def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
wrapper = V.graph.wrapper_code
_, call_args, arg_types = self.args.cpp_argdefs()
wrapper.generate_kernel_call(
name, call_args, triton=False, gpu=False, arg_types=arg_types
)
wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types)
def dtype(self, node: ir.Buffer) -> str:
return DTYPE_TO_CPP[node.get_dtype()]

View File

@ -24,7 +24,6 @@ from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import get_device_op_overrides, IndentedBuffer, Kernel
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .triton_utils import should_unwrap_unspec_arg
from .wrapper import (
EnterSubgraphLine,
ExitSubgraphLine,
@ -89,27 +88,20 @@ class CppWrapperCpu(PythonWrapperCodegen):
self,
kernel_name: str,
call_args,
grid=None,
device_index=None,
gpu=False,
triton=False,
*,
device=None,
triton=True,
arg_types=None,
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
Generates kernel call code.
gpu: Defines whether the backend is GPU. Otherwise the backend is CPU.
triton: Defines whether the GPU backend uses Triton for codegen.
Otherwise it uses the CUDA language for codegen.
Only valid when cuda == True.
"""
assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
assert arg_types is not None and len(call_args) == len(arg_types), (
"Mismatch call_args and arg_types in generate_kernel_call"
)
@ -853,27 +845,32 @@ class CppWrapperCpu(PythonWrapperCodegen):
def generate(self, is_inference):
with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True):
if V.graph.aot_mode and not V.graph.is_const_graph:
self.codegen_model_kernels()
self.codegen_model_constructor()
self.codegen_const_run_driver()
self.write_wrapper_decl()
return super().generate(is_inference)
def finalize_prefix(self):
cached_dtypes_buffer = IndentedBuffer()
prior = self.prefix
self.prefix = aot_mode_decls = IndentedBuffer()
if V.graph.aot_mode and not V.graph.is_const_graph:
aot_mode_decls.writeline("namespace torch::aot_inductor {")
self.codegen_model_kernels()
self.codegen_model_constructor()
self.codegen_const_run_driver()
aot_mode_decls.writeline("} // namespace torch::aot_inductor")
aot_mode_decls.writeline("using namespace torch::aot_inductor;")
self.prefix = cache_decls = IndentedBuffer()
for dtype in self.used_cached_dtypes:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
cache_decls.writeline(f"CACHE_TORCH_DTYPE({dtype});")
for device in self.used_cached_devices:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
cache_decls.writeline(f"CACHE_TORCH_DEVICE({device});")
for layout in self.used_cached_layouts:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
cache_decls.writeline(f"CACHE_TORCH_LAYOUT({layout});")
for memory_format in self.used_cached_memory_formats:
cached_dtypes_buffer.writeline(
f"CACHE_TORCH_MEMORY_FORMAT({memory_format});"
)
cached_dtypes_buffer.splice(self.prefix)
self.prefix = cached_dtypes_buffer
cache_decls.writeline(f"CACHE_TORCH_MEMORY_FORMAT({memory_format});")
self.prefix.splice(aot_mode_decls)
self.prefix.splice(prior)
def define_kernel(
self,
@ -1264,24 +1261,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
# it suffices as a type hint for the purposes of producing the correct code for this type.
return SymbolicCallArg(expr, tree.numel)
def prepare_triton_kernel_call(self, device_index, call_args):
def wrap_arg(arg):
if isinstance(arg, str):
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg
elif isinstance(arg, (int, float, bool, SymbolicCallArg)):
return str(arg)
else:
return cexpr(V.graph.sizevars.simplify(arg))
call_args = [wrap_arg(arg) for arg in call_args]
if device_index is None:
current_device = V.graph.get_current_device_or_throw()
device_index = current_device.index
return device_index, call_args
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")

View File

@ -106,27 +106,21 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
self,
kernel_name: str,
call_args,
grid=None,
device_index=None,
gpu=False,
triton=False,
*,
device=None,
triton=True,
arg_types=None,
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
Generates kernel call code.
gpu: Defines whether the backend is GPU. Otherwise the backend is CPU.
triton: Defines whether the GPU backend uses Triton for codegen.
Otherwise it uses the CUDA language for codegen.
Only valid when cuda == True.
"""
assert not gpu, (
assert not triton, (
"CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
)
assert arg_types is not None and len(call_args) == len(arg_types), (

View File

@ -1,188 +1,194 @@
# mypy: allow-untyped-defs
import os
from itertools import chain, count, zip_longest
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from __future__ import annotations
import dataclasses
import re
from itertools import count, zip_longest
from typing import Any, Optional, Union
from typing_extensions import Self
import sympy
from torch import dtype as torch_dtype
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
from .. import config
from ..codecache import CudaKernelParamCache
from ..ir import GraphPartitionSignature, IRNode, TensorBox
from ..utils import (
cache_on_self,
DeferredLineBase,
get_gpu_type,
GPU_ALIGN_BYTES,
triton_version_uses_attrs_dict,
)
from ..ir import GraphPartitionSignature, TensorBox
from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import get_device_op_overrides
from .cpp_utils import cexpr
from .cpp_wrapper_cpu import CppWrapperCpu
from .multi_kernel import MultiKernelCall
from .triton_utils import should_unwrap_unspec_arg
from .wrapper import PythonWrapperCodegen, SymbolicCallArg
if TYPE_CHECKING:
from collections.abc import Hashable
from ..graph import GraphLowering
_cpp_string_literal_escapes = {
"\\": "\\\\",
'"': '\\"',
"\n": "\\n",
"\t": "\\t",
"\r": "\\r",
}
_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]')
class DeferredGpuKernelLine(DeferredLineBase):
def cpp_string_literal(s: str) -> str:
escaped = _cpp_string_literal_pattern.sub(
lambda match: _cpp_string_literal_escapes[match.group(0)], s
)
return f'"{escaped}"'
@dataclasses.dataclass
class DeferredTritonCallWrapper:
"""
When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
to be tuned and stored as cubin files, so use a deferred line to backfill those information
to be tuned and stored as cubin files, so use a deferred generating the final wrapper around
the triton kernel until right before the prefix is written.
"""
def __init__(
self,
kernel_name: str,
line_template: str,
keys: tuple[str, ...],
additional_files: list[str],
):
super().__init__(line_template)
assert not isinstance(line_template, DeferredLineBase)
self.additional_files = additional_files
self.kernel_name = kernel_name
self.line_template = line_template
self.keys = keys
wrapper_name: str
kernel_name: str
arg_types: list[Any]
def __call__(self):
def generate(self, wrapper: CppWrapperGpu):
prefix = wrapper.prefix
if self.kernel_name.startswith("multi_kernel_"):
# MultiKernel will select one kernel after running the autotune block
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
params = CudaKernelParamCache.get(self.kernel_name)
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
for key in self.keys:
assert key in params, (
f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
assert params, f"CudaKernelParamCache not populated for {self.kernel_name}"
def_args = params["def_args"]
arg_types = self.arg_types
inductor_meta = params["inductor_meta"]
if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types):
# extra_launcher_args should already be in def_args
assert len(def_args) == len(arg_types) - len(
inductor_meta["extra_launcher_args"]
)
arg_types = arg_types + [SymbolicCallArg] * len(
inductor_meta["extra_launcher_args"]
)
if key == get_cpp_wrapper_cubin_path_name():
assert os.path.exists(params[key]), f"{params[key]} does not exist"
self.additional_files.append(params[key])
return self.line_template % tuple(params[key] for key in self.keys)
if not V.graph.aot_mode:
prefix.writeline(
maybe_hipify_code_wrapper(
f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;"
)
)
kernel_var_name = self.kernel_name
else:
kernel_var_name = f"kernels_.{self.kernel_name}"
def _new_line(self, line):
return DeferredGpuKernelLine(
self.kernel_name, line, self.keys, self.additional_files
# tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types
template_types = [
f"typename {name}_type_"
for name, arg_type in zip(def_args, arg_types)
if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg))
]
if V.graph.aot_mode:
template_types.append("typename kernels_type_")
if template_types:
prefix.writeline(f"template <{', '.join(template_types)}>")
prefix.writeline(f"static inline void {self.wrapper_name}(")
with prefix.indent():
assert len(def_args) == len(arg_types), (def_args, arg_types)
for name, arg_type in zip(def_args, arg_types):
if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)):
prefix.writeline(f"const {name}_type_& {name},")
elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)):
prefix.writeline(f"int64_t {name},")
elif arg_type is float:
prefix.writeline(f"float {name},")
elif arg_type is bool:
prefix.writeline(f"bool {name},")
else:
raise ValueError(f"Unexpected arg type {arg_type}")
prefix.writeline(f"{wrapper.device_codegen.cpp_stream_type()} stream_,")
if V.graph.aot_mode:
prefix.writeline("kernels_type_& kernels_,")
prefix.writeline(
"const std::optional<std::string>& cubin_dir_ = std::nullopt"
)
prefix.writeline("){")
with prefix.indent():
self.generate_grid(prefix, inductor_meta, params)
self.generate_load_kernel(prefix, kernel_var_name, params)
self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params)
prefix.writeline("}")
# Ensure the cubin file is included in the package
V.graph.wrapper_code.additional_files.append(
params[get_cpp_wrapper_cubin_path_name()]
)
class DeferredGpuDefaultGrid:
"""
A container for the default grid, which may be used by DeferredGpuGridLine
"""
def __init__(
def generate_grid(
self,
kernel_name: str,
grid,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwargs,
prefix: IndentedBuffer,
inductor_meta: dict[str, Any],
params: dict[str, Any],
):
self.kernel_name = kernel_name
self.grid = grid
self.grid_callable = grid_callable
self.grid_extra_kwargs = grid_extra_kwargs
from ..runtime.triton_heuristics import GridExpr
def __iter__(self):
# DeferredGpuDefaultGrid can be passed to the base class, PythonWrapperCodegen,
# to generate the autotune code block, and thus we need this iterator
return iter(self.grid)
def _process_grid(self, grid: Union[list[Any], tuple[Any, ...]]):
if isinstance(grid, (list, tuple)):
return [self._process_grid(e) for e in grid]
else:
return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid
def __call__(self):
if self.kernel_name.startswith("multi_kernel_"):
# MultiKernel will select one kernel after running the autotune block
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
grid = self.grid
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
grid = self._process_grid(grid)
assert self.grid_callable is not None, "grid_callable can't be None"
if not self.grid_extra_kwargs:
grid_fn = self.grid_callable(*grid)
else:
grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
params = CudaKernelParamCache.get(self.kernel_name)
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp")
for line in grid.prefix:
prefix.writeline(line)
prefix.splice(
f"""\
uint32_t grid_0 = {grid.x_grid};
uint32_t grid_1 = {grid.y_grid};
uint32_t grid_2 = {grid.z_grid};
"""
)
return grid_fn(params["meta"])
prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;")
def generate_load_kernel(self, prefix, kernel_var_name, params):
prefix.writeline(f"if ({kernel_var_name} == nullptr) {{")
with prefix.indent():
load_kernel_args = [
cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
cpp_string_literal(params["mangled_name"]),
str(params["shared_mem"]),
"cubin_dir_",
]
prefix.writeline(
f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); "
)
prefix.writeline("}")
class DeferredGpuGridLine(DeferredLineBase):
"""
When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
to be tuned and stored as cubin files, so use a deferred line to backfill those information
"""
def __init__(
self,
kernel_name: str,
grid_var: str,
grid,
autotune_configs,
):
super().__init__("")
self.kernel_name = kernel_name
self.grid_var = grid_var
self.grid = grid
self.autotune_configs = autotune_configs
def __call__(self):
if self.kernel_name.startswith("multi_kernel_"):
# MultiKernel will select one kernel after running the autotune block
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
params = CudaKernelParamCache.get(self.kernel_name)
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params):
triton_meta = params["triton_meta"]
assert len(self.arg_types) == len(params["def_args"]), (
self.arg_types,
params["def_args"],
)
if self.autotune_configs is not None:
# This indicates the Triton kernel is a user-defined one.
grid = None
if len(self.grid) == 1:
grid = self.grid[0]
else:
for i, c in enumerate(self.autotune_configs):
if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
grid = self.grid[i]
break
assert grid is not None
elif isinstance(self.grid, DeferredGpuDefaultGrid):
grid = self.grid()
else:
grid = self.grid
assert len(grid) != 0, "Grid can't be empty"
grid_args_str = ", ".join(
[cexpr(V.graph.sizevars.simplify(item)) for item in grid]
)
return f" Grid {self.grid_var} = Grid({grid_args_str});"
def _new_line(self, line):
return DeferredGpuGridLine(
self.kernel_name, self.grid_var, self.grid, self.autotune_configs
arg_type_loookup = dict(zip(params["def_args"], self.arg_types))
# difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants
call_args = [
name for name in params["call_args"] if name not in triton_meta["constants"]
]
arg_types = [arg_type_loookup[name] for name in call_args]
arg_signatures = [triton_meta["signature"][name] for name in call_args]
call_args_str = wrapper.generate_args_decl(
prefix, call_args, arg_types, arg_signatures
)
prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
launch_kernel_args = [
kernel_var_name,
"grid_0",
"grid_1",
"grid_2",
str(params["num_warps"]),
str(params["shared_mem"]),
"kernel_args_",
"stream_",
]
prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});")
class CppWrapperGpu(CppWrapperCpu):
@ -195,7 +201,7 @@ class CppWrapperGpu(CppWrapperCpu):
self.device_codegen = get_device_op_overrides(self.device)
super().__init__()
self.grid_id = count()
self._load_kernel_cache: dict[Hashable, str] = {}
self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {}
@staticmethod
def create(
@ -293,77 +299,25 @@ class CppWrapperGpu(CppWrapperCpu):
def generate(self, is_inference):
with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
self.prefix.writeline("\n")
if not V.graph.aot_mode:
for kernel in chain(
sorted(self.src_to_kernel.values()),
sorted(
[entry[0] for entry in self.user_defined_kernel_cache.values()]
),
):
self.prefix.writeline(
maybe_hipify_code_wrapper(
f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;"
)
)
self.prefix.writeline("\n")
return super().generate(is_inference)
def generate_user_defined_triton_kernel(
self,
kernel_name: str,
raw_args: list[Any],
grid: list[Any],
configs,
triton_meta,
constexprs,
):
if (
config.triton.autotune_at_compile_time
and kernel_name not in self.kernel_autotune_names
):
# Call PythonWrapperCodegen to create the autotune code block
PythonWrapperCodegen.generate_user_defined_triton_kernel(
self,
kernel_name,
raw_args,
grid,
configs,
triton_meta,
constexprs,
)
if not triton_version_uses_attrs_dict():
# in C++ wrapper, we don't pass constexpr args, as they don't
# get added as parameters to the PTX code compiled from the
# user-defined Triton kernel (only non-constexpr args do)
raw_args = [
raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
]
args = [self.val_to_arg_str(v) for v in raw_args]
arg_types = [
arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
for arg in raw_args
]
# Call self.generate_kernel_call to generate the real kernel call in cpp
self.generate_kernel_call(
kernel_name,
args,
arg_types=arg_types,
raw_args=raw_args,
grid=grid,
gpu=True,
triton=True,
triton_meta=triton_meta,
autotune_configs=configs,
)
def finalize_prefix(self):
"""Define the triton kernels now that autotuning is finished"""
old_prefix = self.prefix # new content should go at start of prefix
self.prefix = IndentedBuffer()
super().finalize_prefix()
for kernel in self._triton_call_wrappers.values():
self.prefix.writeline("\n")
kernel.generate(self)
self.prefix.writeline("\n")
self.prefix.splice(old_prefix)
def generate_tma_descriptor(self, desc):
self.write_tma_descriptor_helpers_once()
# generate data pointer for the source tensor
source = self.generate_args_decl(
code=self,
call_args=[self.val_to_arg_str(desc.tensor)],
arg_types=[desc.tensor.get_dtype()],
arg_signatures=[None],
@ -384,36 +338,9 @@ class CppWrapperGpu(CppWrapperCpu):
args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}"
self.writeline(f"{fn}({args});")
def generate_load_kernel_once(
self,
kernel_name: str,
graph: "GraphLowering", # for per-graph caching
def generate_args_decl(
self, code: Union[IndentedBuffer, Self], call_args, arg_types, arg_signatures
):
cache_key = (kernel_name, graph)
if cache_key in self._load_kernel_cache:
return self._load_kernel_cache[cache_key]
kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
self._load_kernel_cache[cache_key] = kernel_var_name
keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem")
self.writeline(f"if ({kernel_var_name} == nullptr) {{")
deferred_gpu_kernel_line = DeferredGpuKernelLine(
kernel_name,
(
" "
+ kernel_var_name
+ ' = loadKernel("%s", "%s", %s, this->cubin_dir_);'
if V.graph.aot_mode
else " " + kernel_var_name + ' = loadKernel("%s", "%s", %s);'
),
keys,
self.additional_files,
)
self.writeline(deferred_gpu_kernel_line)
self.writeline("}")
return kernel_var_name
def generate_args_decl(self, call_args, arg_types, arg_signatures):
new_args: list[str] = []
# Add more cases for other types as needed
@ -427,26 +354,24 @@ class CppWrapperGpu(CppWrapperCpu):
var_name = f"var_{next(self.arg_var_id)}"
# ignore nvTmaDesc, as host-side TMA descriptors need
# to be passed to the compiled Triton kernel by value
if isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc":
if arg.endswith(".item()"):
# Need to declare a scalar in this case
arg = arg[:-7]
self.codegen_tensor_item(
arg_type,
arg,
var_name,
)
else:
device_ptr_type = self.device_codegen.cpp_device_ptr()
self.writeline(
maybe_hipify_code_wrapper(
f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
)
if isinstance(arg_type, UnwrapUnspecArg) and arg_signature != "nvTmaDesc":
self.codegen_tensor_item(
arg_type.dtype,
arg,
var_name,
indented_buffer=code,
)
elif isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc":
device_ptr_type = self.device_codegen.cpp_device_ptr()
code.writeline(
maybe_hipify_code_wrapper(
f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
)
)
elif arg_type in (sympy.Integer, int):
self.writeline(f"int {var_name} = {cexpr(arg)};")
code.writeline(f"int {var_name} = {cexpr(arg)};")
elif arg_type in (sympy.Float, float):
self.writeline(f"float {var_name} = {cexpr(arg)};")
code.writeline(f"float {var_name} = {cexpr(arg)};")
# For symbolic call arguments, examine the arg signatures from triton meta
# to explicitly cast to the right type
# Reason: `auto` can infer unexpected type against kernel input signature.
@ -455,11 +380,11 @@ class CppWrapperGpu(CppWrapperCpu):
and arg_signature is not None
and arg_signature in signature2dtype.keys()
):
self.writeline(
code.writeline(
f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
)
else:
self.writeline(f"auto {var_name} = {cexpr(arg)};")
code.writeline(f"auto {var_name} = {cexpr(arg)};")
new_args.append(f"&{var_name}")
for arg, arg_type, arg_signature in zip_longest(
@ -478,61 +403,34 @@ class CppWrapperGpu(CppWrapperCpu):
return ", ".join(new_args)
def generate_default_grid(
self,
kernel_name: str,
grid_args: list[Any],
gpu: bool = True,
grid_callable: Optional[Callable[..., Any]] = default_grid_fn,
**grid_extra_kwargs,
):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics. Because its computation needs
to read kernel config after autotune, it is done in a deferred way
using DeferredGpuDefaultGrid.
"""
assert gpu, "CppWrapperGpu.generate_default_grid does not support non-GPU"
return DeferredGpuDefaultGrid(
kernel_name, grid_args, grid_callable, **grid_extra_kwargs
)
def generate_kernel_call(
self,
kernel_name: str,
call_args,
grid=None,
device_index=None,
gpu=True,
*,
device=None,
triton=True,
arg_types=None,
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
Override the default value of argument 'gpu' to True here.
generate_kernel_call can still be called with gpu=False because of
a mix of cpu kernels and gpu kernels.
"""
if not gpu:
device = device or V.graph.get_current_device_or_throw()
if device.type == "cpu":
# Even in CppWrapperGpu, we may see cpp kernels
return CppWrapperCpu.generate_kernel_call(
self,
kernel_name,
call_args,
grid,
device_index,
gpu,
triton,
arg_types,
raw_args,
grid_fn,
triton_meta,
autotune_configs,
grid_extra_kwargs,
device=device,
triton=triton,
arg_types=arg_types,
raw_args=raw_args,
triton_meta=triton_meta,
)
if (
@ -545,110 +443,38 @@ class CppWrapperGpu(CppWrapperCpu):
self,
kernel_name,
call_args,
grid,
device_index,
gpu,
triton,
arg_types,
raw_args,
grid_fn,
triton_meta,
autotune_configs,
grid_extra_kwargs,
device=device,
triton=triton,
arg_types=arg_types,
raw_args=raw_args,
triton_meta=triton_meta,
)
if device_index is None:
current_device = V.graph.get_current_device_or_throw()
device_index = current_device.index
stream = (
"stream"
if V.graph.aot_mode
else self.write_get_raw_stream(device_index, V.graph)
else self.write_get_raw_stream(device.index, V.graph)
)
if triton:
device_index, call_args = self.prepare_triton_kernel_call(
device_index, call_args
call_args, arg_types = self.prepare_triton_wrapper_args(
call_args, arg_types
)
kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
arg_signatures = []
if (
triton_meta is not None
and triton_meta.get("configs")
and triton_meta.get("signature")
):
if triton_version_uses_attrs_dict():
signatures = triton_meta["signature"]
arg_signatures = [
val for val in signatures.values() if val != "constexpr"
]
call_args = [
call_arg
for call_arg, arg_name in zip(call_args, signatures)
if signatures[arg_name] != "constexpr"
]
arg_types = [
arg_type
for arg_type, arg_name in zip(arg_types, signatures)
if signatures[arg_name] != "constexpr"
]
assert len(call_args) == len(arg_signatures), (
f"len of the following lists do not match: {call_args=} {arg_signatures=}"
)
else:
# args with value 1 are added into equal_to_1 and constants
# in triton_meta (in the Python codegen) which makes them
# inlined in the PTX and compiled CUBIN
equal_to_1 = triton_meta["configs"][0].equal_to_1
call_args = [
arg for i, arg in enumerate(call_args) if i not in equal_to_1
]
arg_types = [
t for i, t in enumerate(arg_types) if i not in equal_to_1
]
# extract the arg signatures from triton_meta
arg_signatures = triton_meta["signature"].values()
arg_signatures = [
v for i, v in enumerate(arg_signatures) if i not in equal_to_1
]
call_args_str = self.generate_args_decl(
call_args, arg_types, arg_signatures
)
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
self.writeline(
DeferredGpuGridLine(kernel_name, grid_var, grid, autotune_configs)
)
kernel_var_name = (
f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
)
# add debug printer code for all triton kernel related calls
wrapper_name = f"call_{kernel_name}"
if wrapper_name not in self._triton_call_wrappers:
self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper(
wrapper_name, kernel_name, arg_types
)
call_args.append(stream)
if V.graph.aot_mode:
call_args.append("kernels")
call_args.append("this->cubin_dir_")
debug_printer_manager = V.graph.wrapper_code.debug_printer
debug_printer_manager.set_printer_args(
call_args, kernel_name, arg_types, None
call_args[: len(arg_types)], kernel_name, arg_types, None
)
with debug_printer_manager:
self.writeline(f"if ({grid_var}.is_non_zero()) {{")
self.writeline(
DeferredGpuKernelLine(
kernel_name,
r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
kernel_var_name,
f"{grid_var}.grid_x",
f"{grid_var}.grid_y",
f"{grid_var}.grid_z",
kernel_args_var,
stream,
),
("num_warps", "shared_mem"),
self.additional_files,
),
)
self.writeline("}")
self.writeline(f"{wrapper_name}({', '.join(call_args)});")
else:
casted = []
for arg_type, arg in zip(arg_types, call_args):
@ -659,5 +485,34 @@ class CppWrapperGpu(CppWrapperCpu):
call_args_str = ", ".join(casted)
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")
@staticmethod
def prepare_triton_wrapper_args(
call_args: list[Any], arg_types: list[Any]
) -> tuple[list[Any], list[Any]]:
assert len(call_args) == len(arg_types), (call_args, arg_types)
new_args = []
new_args_types = []
for arg, arg_type in zip(call_args, arg_types):
if isinstance(arg, str):
if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg):
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
arg_type = UnwrapUnspecArg(dtype=arg_type)
new_args.append(arg)
elif isinstance(arg, bool):
new_args.append(str(arg).lower())
elif isinstance(arg, (int, float, SymbolicCallArg)):
new_args.append(str(arg))
else:
new_args.append(cexpr(V.graph.sizevars.simplify(arg)))
new_args_types.append(arg_type)
return new_args, new_args_types
def make_zero_buffer(self, name):
return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));"
@dataclasses.dataclass
class UnwrapUnspecArg:
"""Marker that we need to call .item() on the tensor"""
dtype: torch_dtype

View File

@ -341,7 +341,6 @@ class CUDATemplateKernel(CUDAKernel):
wrapper.generate_kernel_call(
name,
call_args,
gpu=True,
triton=False,
arg_types=arg_types,
)

View File

@ -63,22 +63,6 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
} \\
} while (0);
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
static inline CUfunction loadKernel(
std::string filePath,
const std::string &funcName,

View File

@ -251,7 +251,7 @@ class DebugPrinterManager:
continue
if V.graph.cpp_wrapper:
if arg_signatures is not None and isinstance(
arg_signatures[i], (torch_dtype)
arg_signatures[i], torch_dtype
):
# infer from the arg data type (has torch.dtype) to see if it is a tensor type
V.graph.wrapper_code.writeline(

View File

@ -1633,7 +1633,7 @@ class HalideKernel(SIMDKernel):
wrapper.generate_kernel_call(
name,
call_args,
gpu=False, # grid/stream is handled internally in halide
device=current_device,
triton=False,
)

View File

@ -569,7 +569,7 @@ class MetalKernel(SIMDKernel):
wrapper.generate_kernel_call(
name,
args,
gpu=False, # TODO: Fix me, MPS does not expose streams now
device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now
triton=False,
)

View File

@ -3,7 +3,6 @@ import functools
import logging
import os
import pathlib
from typing import Any
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._ordered_set import OrderedSet
@ -11,11 +10,6 @@ from torch.utils._ordered_set import OrderedSet
from .. import config
from ..codecache import code_hash, CodeCacheFuture, get_path
from ..runtime.benchmarking import benchmarker
from ..runtime.triton_heuristics import (
cooperative_reduction_grid,
grid,
maybe_cooperative_reduction_grid,
)
from ..utils import cache_on_self, IndentedBuffer
from ..virtualized import V
from .common import TensorArg, WorkspaceArg
@ -196,19 +190,6 @@ class MultiKernel:
kernel.args.workspace_args = workspace_args
return workspace_args
def get_grid_fn(self):
fns = OrderedSet(kernel._get_grid_fn() for kernel in self.kernels)
if len(fns) == 1:
return fns.pop()
elif len(fns) == 2:
assert fns == OrderedSet([cooperative_reduction_grid, grid])
V.graph.wrapper_code.add_import_once(
f"from {maybe_cooperative_reduction_grid.__module__} import maybe_cooperative_reduction_grid"
)
return maybe_cooperative_reduction_grid
else:
raise NotImplementedError(fns)
def call_kernel(self, kernel_name):
"""
Collect the union of arguments from all subkernels as the arguments
@ -222,31 +203,21 @@ class MultiKernel:
assert call_args == other_call_args, (call_args, other_call_args)
assert arg_types == other_arg_types
grid: list[Any] = []
if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
# for the second pass of cpp-wrapper codegen, we should call
# the fast kernel directly
kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
# numels for all subkernels should be the same. Use kernels[0] here
self.kernels[0].add_numel_to_call_args_and_grid(
kernel_name, call_args, arg_types, grid
)
self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
for ws in self.kernels[0].args.workspace_args:
V.graph.wrapper_code.generate_workspace_allocation(ws)
grid_fn = self.get_grid_fn()
grid = V.graph.wrapper_code.generate_default_grid(
kernel_name, grid, grid_callable=grid_fn
)
V.graph.wrapper_code.generate_kernel_call(
kernel_name,
call_args,
grid,
arg_types=arg_types,
grid_fn=grid_fn.__name__,
)
for ws in reversed(self.kernels[0].args.workspace_args):

View File

@ -196,12 +196,9 @@ class ROCmTemplateKernel(ROCmKernel):
kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None")
if V.graph.cpp_wrapper:
arg_types.append("uint8_t*")
current_device = V.graph.get_current_device_or_throw()
wrapper.generate_kernel_call(
name,
kernel_args,
device_index=current_device.index,
gpu=True,
triton=False,
arg_types=arg_types,
)

View File

@ -1549,13 +1549,10 @@ class SIMDScheduling(BaseScheduling):
if config.benchmark_kernel:
num_gb = kernel.estimate_kernel_num_bytes() / 1e9
grid_args = V.graph.sizevars.size_hints(kernel.call_sizes)
assert kernel.meta is not None, "meta is None"
grid = kernel.grid_fn(*grid_args, kernel.meta)
src_code = (
f"{kernel.imports_for_benchmark_kernel()}\n"
f"{src_code}\n"
f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}"
f"{kernel.codegen_kernel_benchmark(num_gb).getvalue()}"
)
if only_gen_src_code:

View File

@ -33,6 +33,7 @@ from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache
from ..ops_handler import DefaultHandler
from ..runtime import triton_heuristics
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import (
AutotuneHint,
@ -41,10 +42,6 @@ from ..runtime.hints import (
TRITON_MAX_RSPLIT,
)
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..runtime.triton_heuristics import (
cooperative_reduction_grid,
grid as default_grid_fn,
)
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
cache_on_self,
@ -87,7 +84,6 @@ from .simd import (
IterationRanges,
IterationRangesEntry,
IterationRangesRoot,
pexpr,
SIMDKernel,
SIMDScheduling,
)
@ -98,6 +94,7 @@ from .triton_utils import (
should_unwrap_unspec_arg,
signature_to_meta,
)
from .wrapper import SymbolicCallArg
if TYPE_CHECKING:
@ -3190,7 +3187,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.post_loop_combine.clear()
self.post_loop_store.clear()
def codegen_kernel_benchmark(self, num_gb, grid=None):
def kernel_benchmark_extra_args(self) -> list[str]:
args = []
if self.need_numel_args():
numel_args: list[sympy.Expr] = []
self.add_numel_to_call_args("", numel_args, [])
for arg in numel_args:
if isinstance(arg, int):
args.append(str(arg))
elif isinstance(arg, SymbolicCallArg):
args.append(str(V.graph.sizevars.size_hint(arg.inner_expr)))
elif isinstance(arg, sympy.Expr):
args.append(str(V.graph.sizevars.size_hint(arg)))
else:
raise ValueError(f"Unsupported numel argument type: {type(arg)}")
return args
def codegen_kernel_benchmark(self, num_gb):
result = IndentedBuffer()
_argdefs, call_args, signature, _ = self.args.python_argdefs()
@ -3231,25 +3244,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
f"Don't find the buffer or const tensor for {arg_name}"
)
var_names.append(var_name)
var_names.extend(self.kernel_benchmark_extra_args())
result.writeline(f"return {', '.join(var_names)},")
result.writelines(["\n", "\n", "def call(args):"])
if grid is None:
grid = []
extra_args = []
extra_args_str = None
for tree in self.active_range_trees():
expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
extra_args.append(expr)
if not tree.is_reduction:
grid.append(expr)
if self.need_numel_args():
extra_args_str = ", ".join(map(str, extra_args)) + ", "
else:
extra_args_str = ""
grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})"
else:
grid_arg = f"grid={grid}"
current_device = V.graph.get_current_device_or_throw()
index = current_device.index
with result.indent():
@ -3261,7 +3259,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_raw_stream({index})")
result.writeline(
f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})"
)
# benchmark all configs
@ -3273,7 +3271,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
V.graph.device_ops.set_device(index)
) # no-op to ensure context
result.writeline(
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)"
)
result.writelines(["\n", "\n", "if __name__ == '__main__':"])
@ -3301,7 +3299,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
from torch._dynamo.testing import rand_strided
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)
@ -3488,6 +3485,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
inductor_meta = {
# Triton will not accept an OrderedSet for autotune_hints
"grid_type": self._get_grid_type().__name__,
"autotune_hints": set(self.autotune_hints), # noqa: set_linter
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
"mutated_arg_names": mutated_args,
@ -3639,12 +3637,22 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if tree.prefix == "x" and self.no_x_dim:
code.writeline("XBLOCK: tl.constexpr = 1")
def _get_grid_fn(self):
def _get_grid_type(self) -> type[triton_heuristics.GridExpr]:
n = sum([int(not tree.is_reduction) for tree in self.range_trees])
if self.cooperative_reduction:
return cooperative_reduction_grid
return default_grid_fn
assert n == 1
return triton_heuristics.CooperativeReductionGrid
elif n == 1:
return triton_heuristics.Grid1D
elif n == 2:
if any(map(self.needs_yz_grid_overflow, self.range_trees)):
return triton_heuristics.Grid2DWithYZOverflow
return triton_heuristics.Grid2D
elif n == 3:
return triton_heuristics.Grid3D
raise ValueError(f"Unsupported number of dimensions: {n}")
def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
def add_numel_to_call_args(self, name, call_args, arg_types):
# TODO(jansel): if there are constants, we shouldn't bother passing them as args
for tree in self.range_trees:
if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
@ -3655,31 +3663,21 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if not tree.is_reduction or self.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))
if tree.grid_dim is not None:
grid.append(expr)
def call_kernel(self, name: str, node: Optional[IRNode] = None):
wrapper = V.graph.wrapper_code
wrapper.write_triton_header_once()
_, call_args, _, arg_types = self.args.python_argdefs()
grid: list[Any] = []
self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
current_device = V.graph.get_current_device_or_throw()
self.add_numel_to_call_args(name, call_args, arg_types)
for ws in self.args.workspace_args:
wrapper.generate_workspace_allocation(ws)
grid_fn = self._get_grid_fn()
grid = wrapper.generate_default_grid(name, grid, grid_callable=grid_fn)
wrapper.generate_kernel_call(
name,
call_args,
grid,
current_device.index,
gpu=current_device.type != "cpu",
triton=True,
arg_types=arg_types,
grid_fn=grid_fn.__name__,
triton_meta=self.triton_meta,
)
@ -3738,12 +3736,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
key = f"tl.program_id({entry.grid_dim})"
# y_grid has a limit, so express it in terms of y and z in case of overflow.
# z grid is only exercised when max_tiles == 3 (off by default).
if (
entry.grid_dim == 1
and not entry.has_zdim
and not self.cooperative_reduction
and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
):
if self.needs_yz_grid_overflow(entry):
# For ynumel larger than max_ygrid, we need to use zdim.
# For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z).
# So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset.
@ -3753,6 +3746,14 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
return f"{pid}.to({self.index_dtype})"
return pid
def needs_yz_grid_overflow(self, entry: IterationRangesRoot) -> bool:
return (
entry.grid_dim == 1
and not entry.has_zdim
and not self.cooperative_reduction
and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
)
def max_block(self, prefix: str) -> int:
if self.fixed_config:
return self.fixed_config[f"{prefix.upper()}BLOCK"]

View File

@ -2,7 +2,6 @@ import itertools
import logging
import textwrap
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, cast, Optional, Union
@ -14,7 +13,10 @@ from torch.utils._ordered_set import OrderedSet
from .. import config, metrics
from ..runtime.hints import DeviceProperties
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..runtime.triton_heuristics import (
RoundRobinComboKernelGrid,
SequentialComboKernelGrid,
)
from ..scheduler import BaseSchedulerNode
from ..utils import Placeholder, triton_version_uses_attrs_dict
from ..virtualized import V
@ -283,6 +285,8 @@ class ComboKernel(Kernel):
grid(...): codegen the grid size for launching the combo kernel.
"""
grid_expr = SequentialComboKernelGrid
@classmethod
def codegen_pid_range(
cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
@ -321,42 +325,6 @@ class ComboKernel(Kernel):
else:
code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}")
@classmethod
def grid(
cls,
sub_kernel_numels: list[list[int]],
x_blocks_list: list[Union[str, int]],
dynamic_shape: bool,
) -> tuple[Any, ...]:
xnumel = list(x_blocks_list)
ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]
if dynamic_shape:
ynumel = None if None in ynumel else ynumel
znumel = None if None in znumel else znumel
else:
# TODO: improve 1d/2d mixed cases
ynumel = (
None
if any(e is None for e in cast(list[Any], ynumel))
else max(cast(Iterable[int], ynumel))
)
znumel = (
None
if any(e is None for e in cast(list[Any], znumel))
else max(cast(Iterable[int], znumel))
)
numels = (
(xnumel,)
if not ynumel
else (ynumel, xnumel)
if not znumel
else (znumel, ynumel, xnumel)
)
return numels
class RoundRobinDispatch:
"""
The dispatcher which dispatches the subkernels in a round robin manner:
@ -368,6 +336,8 @@ class ComboKernel(Kernel):
grid(...): codegen the grid size for launching the combo kernel.
"""
grid_expr = RoundRobinComboKernelGrid
@classmethod
def codegen_pid_range(
cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
@ -381,51 +351,6 @@ class ComboKernel(Kernel):
with code.indent():
code.splice(f"pid_offset = pid // {num_kernels}")
@classmethod
def grid(
cls,
sub_kernel_numels: list[list[int]],
x_blocks_list: list[Union[str, int]],
dynamic_shape: bool,
) -> tuple[Any, ...]:
xnumel = x_blocks_list
# set no_x_dim xnumels to 0
xnumel_x_dim = [max(e, 0) for e in xnumel]
ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]
# TODO: support 1d/2d mixed cases
xnumel = (
None
if any(e is None for e in xnumel)
else xnumel
if dynamic_shape
else max(xnumel_x_dim) # type: ignore[type-var, arg-type]
)
ynumel = (
None
if any(e is None for e in ynumel)
else ynumel
if dynamic_shape
else max(ynumel) # type: ignore[type-var, arg-type]
)
znumel = (
None
if any(e is None for e in znumel)
else znumel
if dynamic_shape
else max(znumel) # type: ignore[type-var, arg-type]
)
numels = (
(xnumel,)
if not ynumel
else (ynumel, xnumel)
if not znumel
else (znumel, ynumel, xnumel)
)
return numels
def __init__(
self, enable_autotune: bool = False, mixed_sizes: bool = False
) -> None:
@ -638,7 +563,7 @@ class ComboKernel(Kernel):
# mixed_sizes is used for optimize_mask, so it only allows sequential dispatch
# Not mixed sizes on y dim technically is ok to use round robin as wells.
if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list):
# str in min_x_blocks_list means a dynamic shape
# str in x_numels_list means a dynamic shape
self.dispatch_class = ComboKernel.SequentialDispatch
return
# A negative x_blocks_list element means the kernel is not tunable,
@ -675,7 +600,11 @@ class ComboKernel(Kernel):
}
triton_meta["configs"] = [config_of(signature)]
mutated_args = self.get_mutated_args_sub_kernels()
dispatch = self.dispatch_class
assert dispatch is not None
inductor_meta = {
"grid_type": dispatch.grid_expr.__name__,
"combo_grid_meta": self.combo_grid_meta(),
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
"mutated_arg_names": mutated_args,
**TritonKernel.inductor_meta_common(),
@ -771,8 +700,8 @@ class ComboKernel(Kernel):
self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
return argdefs
def add_numel_to_call_args_and_grid(
self, name: str, call_args: list[Any], arg_types: list[Any], grid: list[Any]
def add_numel_to_call_args(
self, name: str, call_args: list[Any], arg_types: list[Any]
) -> None:
for num, sub_kernel in enumerate(self.sub_kernels):
for i, tree in enumerate(sub_kernel.range_trees):
@ -785,40 +714,20 @@ class ComboKernel(Kernel):
expr = V.graph.wrapper_code.generate_numel_expr(
name, tree, suffix=str(num)
)
if not tree.is_reduction:
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert grid[i][num] == numel_sign + numel_name, (
f"numel args mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))
def add_numel_to_call_args_and_grid_benchmark(
self, extra_args: list[Any], grid: Union[list[Any], tuple[Any, ...]]
) -> None:
def kernel_benchmark_extra_args(self) -> list[str]:
extra_args = []
for num, sub_kernel in enumerate(self.sub_kernels):
for i, tree in enumerate(sub_kernel.range_trees):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue
expr = V.graph.sizevars.size_hint(tree.numel)
if not tree.is_reduction:
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert grid[i][num] == numel_sign + numel_name, (
f"grid mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(expr)
extra_args.append(str(V.graph.sizevars.size_hint(tree.numel)))
return extra_args
def codegen_kernel(self, name: Optional[str] = None) -> str:
# TODO: is it correct to use the first sub kernel's heuristics?
@ -890,12 +799,9 @@ class ComboKernel(Kernel):
return code.getvalue()
def codegen_kernel_benchmark(
self, num_gb: float, grid: Optional[list[Any]] = None
) -> IndentedBuffer:
def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:
result = IndentedBuffer()
_argdefs, call_args, signature, _ = self.args.python_argdefs()
result.writelines(["", "", "def get_args():"])
with result.indent():
name_cnt = itertools.count()
@ -934,38 +840,11 @@ class ComboKernel(Kernel):
f"Don't find the buffer or const tensor for {arg_name}"
)
var_names.append(var_name)
if self.dynamic_shape_args:
var_names.extend(self.kernel_benchmark_extra_args())
result.writeline(f"return {', '.join(var_names)},")
result.writelines(["\n", "\n", "def call(args):"])
if grid is None:
assert self.dispatch_class is not None
dynamic_shape = self.dynamic_shape_args != []
grid_tuple = self.dispatch_class.grid(
self.grids, self.x_numels_list, dynamic_shape
)
extra_args_str = ""
extra_args: list[Any] = []
if dynamic_shape:
self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple)
# convert nested list to list of str
grid_tuple = tuple(
"[" + ", ".join(pexpr(item) for item in e) + ",]"
for e in grid_tuple
)
extra_args_str = ", ".join(map(str, extra_args)) + ", "
min_blocks = None
else:
min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels)
grid_str = ", ".join(pexpr(item) for item in grid_tuple)
grid_extra_kwargs = (
f"num_kernels={len(self.sub_kernels)}, "
f"min_blocks={min_blocks}, "
f"is_sequential={self.dispatch_class is self.SequentialDispatch}"
)
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})"
else:
grid_arg = f"grid={grid}"
index = V.graph.get_current_device_or_throw().index
with result.indent():
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
@ -976,7 +855,7 @@ class ComboKernel(Kernel):
stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_raw_stream({index})")
result.writeline(
f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})"
)
# benchmark all configs
@ -988,7 +867,7 @@ class ComboKernel(Kernel):
V.graph.device_ops.set_device(index)
) # no-op to ensure context
result.writeline(
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)"
)
result.writelines(["\n", "\n", "if __name__ == '__main__':"])
@ -1016,7 +895,6 @@ class ComboKernel(Kernel):
from torch._dynamo.testing import rand_strided
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)
@ -1053,77 +931,48 @@ class ComboKernel(Kernel):
wrapper = V.graph.wrapper_code
assert self.dispatch_class is not None
dynamic_shape = self.dynamic_shape_args != []
grid = list(
self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape)
if self.dynamic_shape_args:
self.add_numel_to_call_args(name, call_args, arg_types)
wrapper.generate_kernel_call(
name,
call_args,
triton=True,
arg_types=arg_types,
)
def combo_grid_meta(self) -> dict[str, Any]:
dynamic_shape = bool(self.dynamic_shape_args)
num_kernels = len(self.sub_kernels)
min_blocks = (
max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None
)
is_sequential = self.dispatch_class is self.SequentialDispatch
if dynamic_shape:
self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
# convert nested list to list of str
# grid = tuple("["+", ".join(pexpr(item) for item in e)+",]" for e in grid)
if not self.enable_autotune and not dynamic_shape:
launch_grid = self.grid_no_autotune(
grid, num_kernels, cast(int, min_blocks), is_sequential
)
V.graph.wrapper_code.generate_kernel_call(
name,
call_args,
grid=launch_grid,
arg_types=arg_types,
grid_fn="",
)
return
# autotuning is enabled
grid = wrapper.generate_default_grid(
name,
list(grid),
grid_callable=grid_combo_kernels,
num_kernels=num_kernels,
min_blocks=min_blocks,
is_sequential=is_sequential,
default_meta=None if self.enable_autotune else self.get_default_meta(),
)
wrapper.generate_kernel_call(
name,
call_args,
grid,
V.graph.get_current_device_or_throw().index,
gpu=True,
triton=True,
arg_types=arg_types,
grid_fn="grid_combo_kernels",
grid_extra_kwargs=(
f"num_kernels={num_kernels}, "
f"min_blocks={min_blocks}, "
f"is_sequential={is_sequential}, "
f"default_meta={None if self.enable_autotune else self.get_default_meta()}"
),
)
def grid_no_autotune(
self,
grid: Union[tuple[Any], list[Any]],
num_kernels: int,
min_blocks: int,
is_sequential: bool,
) -> list[int]:
meta = self.get_default_meta()
grid_func = grid_combo_kernels(
*grid,
num_kernels=num_kernels,
min_blocks=min_blocks,
is_sequential=is_sequential,
)
return grid_func(meta)
def get_default_meta(self) -> dict[str, int]:
if "YBLOCK" in self.block_args:
meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d}
if not self.enable_autotune:
if "YBLOCK" in self.block_args:
default_config = {
"XBLOCK": self.block_size_2d,
"YBLOCK": self.block_size_2d,
}
else:
default_config = {"XBLOCK": self.block_size_1d}
else:
meta = {"XBLOCK": self.block_size_1d}
default_config = None
meta = {
"num_kernels": num_kernels,
"min_blocks": min_blocks,
"default_config": default_config,
}
for num, sub_kernel in enumerate(self.sub_kernels):
meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
for i, tree in enumerate(sub_kernel.range_trees):
if not tree.is_reduction:
numel_name = f"{tree.prefix}numel_{num}"
if numel_name in self.dynamic_shape_args:
meta[numel_name] = None
else:
meta[numel_name] = int(V.graph.sizevars.simplify(tree.numel))
return meta
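For reference, an illustrative `combo_grid_meta()` result for two static-shape 1D sub-kernels with autotuning disabled could look roughly like this (key names follow the code above; the values are made up):
```py
{
    "num_kernels": 2,
    "min_blocks": 8,                     # max(min_x_blocks_list) * num_kernels
    "default_config": {"XBLOCK": 1024},  # only populated when autotuning is disabled
    "no_x_dim_0": False,
    "xnumel_0": 4096,
    "no_x_dim_1": False,
    "xnumel_1": 8192,
}
```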


@ -11,7 +11,7 @@ from torch._inductor.codegen.triton import (
TritonCSEVariable,
TritonKernel,
)
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch._inductor.runtime.triton_heuristics import SplitScanGrid
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv
@ -203,5 +203,5 @@ class TritonSplitScanKernel(TritonKernel):
def _get_heuristic(self):
return "split_scan"
def _get_grid_fn(self):
return split_scan_grid
def _get_grid_type(self) -> type[SplitScanGrid]:
return SplitScanGrid


@ -49,6 +49,7 @@ from ..utils import (
LineContext,
sympy_product,
sympy_str,
sympy_subs,
triton_version_uses_attrs_dict,
)
from ..virtualized import V
@ -61,6 +62,7 @@ from .common import (
WorkspaceArg,
WorkspaceZeroMode,
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
@ -839,14 +841,7 @@ class PythonWrapperCodegen(CodeGen):
import_str = f"""
import triton
import triton.language as tl
from {triton_heuristics.__name__} import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from {triton_heuristics.__name__} import start_graph, end_graph
"""
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.splice(import_str)
@ -1133,44 +1128,6 @@ class PythonWrapperCodegen(CodeGen):
with debug_printer_manager:
self.writeline(f"{kernel}({', '.join(args)})")
def generate_user_defined_triton_kernel(
self,
kernel_name: str,
raw_args: list[Any],
grid: list[Any],
configs,
triton_meta,
constexprs,
):
grid_fn, code = user_defined_kernel_grid_fn_code(
kernel_name, configs, grid, wrapper=self
)
if not (config.triton.autotune_at_compile_time and V.graph.cpp_wrapper):
# When codegen the autotune block only, do no insert Triton kernel
# code into the main block
#
# Must happen after free symbols are already codegened
# Emit the grid wrapper function right before the call
for line in code.split("\n"):
self.writeline(line)
# Explicitly call the Python version of val_to_arg_str
args = [PythonWrapperCodegen.val_to_arg_str(self, v) for v in raw_args]
arg_types = [
arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
for arg in raw_args
]
# Because generate_kernel_call can be overriden by a subclass, explicitly call
# PythonWrapperCodegen.generate_kernel_call here
PythonWrapperCodegen.generate_kernel_call(
self,
kernel_name,
args,
grid_fn=grid_fn,
arg_types=arg_types,
raw_args=raw_args,
)
def _generate_tma_descriptor_call(self, desc, apply_size_hints=False):
dims = desc.dims
block_dims = desc.block_dims
@ -1694,13 +1651,15 @@ class PythonWrapperCodegen(CodeGen):
kwargs,
restore_value_args,
reset_to_zero_args,
grids: list[list[Union[int, sympy.Expr]]],
):
from torch.utils._triton import patch_triton_dtype_repr
patch_triton_dtype_repr()
original_name = kernel.__name__
from ..runtime.triton_heuristics import (
config_to_dict,
FixedGrid,
PrecomputedGrid,
)
from .common import (
ConstexprArg,
KernelArgType,
@ -1708,7 +1667,10 @@ class PythonWrapperCodegen(CodeGen):
TensorArg,
TMADescriptorArg,
)
from .triton import gen_common_triton_imports, TritonKernel
patch_triton_dtype_repr()
original_name = kernel.__name__
signature: list[KernelArgType] = []
constants: dict[str, Any] = {}
arg_indices: list[int] = []
@ -1839,22 +1801,67 @@ class PythonWrapperCodegen(CodeGen):
if reset_to_zero_args:
triton_meta["reset_to_zero"] = tuple(reset_to_zero_args)
if len(grids) == 1:
# compute the grid in the wrapper and pass it in as an arg
inductor_meta: dict[str, Any] = FixedGrid.setup_grid_as_args()
extra_launcher_call_args = [*map(sympy.sympify, grids[0])]
else:
def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr:
if isinstance(expr, sympy.Expr):
symbols = [*expr.free_symbols]
if not symbols:
return expr
symbols.sort(key=str)
for sym in symbols:
if sym in extra_launcher_args:
continue
extra_launcher_args[sym] = sympy.Symbol(
f"_launcher_s{len(extra_launcher_args)}"
)
return sympy_subs(expr, extra_launcher_args)
assert isinstance(expr, int)
return sympy.Integer(expr)
extra_launcher_args: dict[sympy.Symbol, sympy.Symbol] = {}
grids = [[*map(rename_sizes_for_launcher, grid)] for grid in grids]
assert grids and len(grids) == len(configs)
precomputed_grids = []
for grid, cfg in sorted(
zip(grids, configs), key=lambda x: len(x[1].kwargs), reverse=True
):
precomputed_grids.append(
{
"config": config_to_dict(cfg),
"python": [*map(pexpr, grid)],
"cpp": [*map(cexpr, grid)],
}
)
inductor_meta = {
"grid_type": PrecomputedGrid.__name__,
"precomputed_grids": precomputed_grids,
"extra_launcher_args": [*map(str, extra_launcher_args.values())],
}
extra_launcher_call_args = [*extra_launcher_args.keys()]
# Distinguish between different functions using function id
cache_key: list[Any] = [id(kernel.fn)]
cache_key: Any = [id(kernel.fn)]
if len(configs) > 0:
for arg in kwargs.values():
# We need to key on non tensor arg only in autotune mode
if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
cache_key.append(arg)
cache_key.append(str(triton_meta))
cache_key.extend(str(inductor_meta))
cache_key = tuple(cache_key)
if cache_key in self.user_defined_kernel_cache:
return self.user_defined_kernel_cache[cache_key]
return (
*self.user_defined_kernel_cache[cache_key],
extra_launcher_call_args,
)
name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
# Add to the cache for the next use
self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
compile_wrapper = IndentedBuffer()
if config.triton.unique_user_kernel_names:
@ -1862,28 +1869,14 @@ class PythonWrapperCodegen(CodeGen):
else:
compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
from .triton import gen_common_triton_imports, TritonKernel
inductor_meta["kernel_name"] = name
inductor_meta.update(TritonKernel.inductor_meta_common())
compile_wrapper.splice(gen_common_triton_imports())
inductor_meta = {
"kernel_name": name,
**TritonKernel.inductor_meta_common(),
}
configs = [
{
"kwargs": config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
for config in configs
]
compile_wrapper.splice(
f"""
@triton_heuristics.user_autotune(
configs={configs!r},
configs={[*map(config_to_dict, configs)]!r},
inductor_meta={inductor_meta!r},
triton_meta={triton_meta!r},
filename=__file__,
@ -1908,7 +1901,9 @@ class PythonWrapperCodegen(CodeGen):
compile_wrapper.getvalue(),
metadata,
)
return name, triton_meta
# Add to the cache for the next use
self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
return name, triton_meta, extra_launcher_call_args
def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None):
expr = f"{kernel_name}_{tree.prefix}numel"
@ -2021,17 +2016,7 @@ class PythonWrapperCodegen(CodeGen):
"""
)
def generate_default_grid(
self,
kernel_name: str,
grid_args: list[Any],
gpu: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwags,
):
return grid_args
def prepare_triton_kernel_call(self, device_index, call_args):
def prepare_triton_kernel_call(self, call_args):
def wrap_arg(arg):
if isinstance(arg, str):
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
@ -2041,13 +2026,7 @@ class PythonWrapperCodegen(CodeGen):
else:
return pexpr(V.graph.sizevars.simplify(arg))
call_args = [wrap_arg(arg) for arg in call_args]
if device_index is None:
current_device = V.graph.get_current_device_or_throw()
device_index = current_device.index
return device_index, call_args
return [wrap_arg(arg) for arg in call_args]
def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None):
if isinstance(arg_type, torch_dtype):
@ -2144,35 +2123,28 @@ class PythonWrapperCodegen(CodeGen):
self,
kernel_name: str,
call_args,
grid=None,
device_index=None,
gpu=True,
*,
device=None,
triton=True,
arg_types=None,
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
Generates kernel call code.
device: The device the kernel targets; defaults to the graph's current device.
triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language
for GPU devices and C++ for CPU.
"""
if not (triton or gpu):
device = device or V.graph.get_current_device_or_throw()
if not (triton or device.type != "cpu"):
self.writeline(self.wrap_kernel_call(kernel_name, call_args))
return
device_index, call_args_str = self.prepare_triton_kernel_call(
device_index, call_args
)
call_args_str = self.prepare_triton_kernel_call(call_args)
call_args_str = ", ".join(call_args_str)
stream_name = PythonWrapperCodegen.write_get_raw_stream(
self, device_index, V.graph
self, device.index, V.graph
)
if not triton:
stream_ptr = f"c_void_p({stream_name})"
@ -2227,17 +2199,8 @@ class PythonWrapperCodegen(CodeGen):
arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i)
all_args.append(arg_str if key is None else f"{key}={arg_str}")
if grid is None:
grid_str = grid_fn
else:
grid_str = ", ".join(
self.generate_example_arg_value(g, type(g)) for g in grid
)
if grid_extra_kwargs:
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_str = f"{grid_fn}({grid_str})"
self.kernel_autotune_calls.writeline(
f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})"
f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
)
self.kernel_autotune_calls.writeline(
f"del {', '.join(arg for arg in tensor_args.values())}\n",
@ -2247,22 +2210,11 @@ class PythonWrapperCodegen(CodeGen):
# For cpp wrapper, no need to continue codegen for the main body
return
if grid is None:
grid_str = grid_fn
else:
grid_str = ", ".join(
PythonWrapperCodegen._grid_dim_str(self, item) for item in grid
)
if grid_extra_kwargs:
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_str = f"{grid_fn}({grid_str})"
# add debug printer code for triton kernel calls at (jit) inductor level
debug_printer_manager = V.graph.wrapper_code.debug_printer
debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
with debug_printer_manager:
self.writeline(
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
)
self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})")
def writeline(self, line):
self.lines.append(line)


@ -40,25 +40,7 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
return source_codes
def kernel_driver(self) -> str:
source_codes = """
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
"""
return source_codes
return ""
def cpp_stream_type(self) -> str:
return "sycl::queue*"


@ -1143,8 +1143,6 @@ class _InProcessFxCompile(FxCompile):
serialized_extern_kernel_nodes,
)
additional_files = graph.wrapper_code.additional_files
with dynamo_timed(
"AotCodeCompiler.compile", log_pt2_compile_event=True
):
@ -1154,7 +1152,11 @@ class _InProcessFxCompile(FxCompile):
code,
serialized_extern_kernel_nodes,
device_type=graph.device_type,
additional_files=additional_files,
additional_files=[
*dict.fromkeys(
graph.wrapper_code.additional_files
)
],
)
else:
compiled_fn = graph.compile_to_module().call


@ -27,6 +27,7 @@ from ..pattern_matcher import (
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
SymbolicGridFn,
TritonTemplate,
TritonTemplateCaller,
)
@ -38,8 +39,9 @@ B2B_GEMM_PASS = PatternMatcherPass(
)
def b2b_gemm_grid(M, P, meta):
return (ceildiv(M, meta["BLOCK_SIZE_M"]) * ceildiv(P, meta["BLOCK_SIZE_P"]), 1, 1)
@SymbolicGridFn
def b2b_gemm_grid(M, P, meta, *, cdiv):
return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(P, meta["BLOCK_SIZE_P"]), 1, 1)
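Since `SymbolicGridFn` injects `cdiv` for plain-int calls, the grid can be evaluated directly; a hedged example with made-up sizes:
```py
b2b_gemm_grid(1024, 512, {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_P": 64})
# -> (cdiv(1024, 128) * cdiv(512, 64), 1, 1) == (64, 1, 1)
```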
b2b_gemm_left_template = TritonTemplate(


@ -1,7 +1,6 @@
from __future__ import annotations
import contextlib
import copy
import dataclasses
import functools
import itertools
@ -5833,82 +5832,68 @@ class UserDefinedTritonKernel(ExternKernel):
) = self.get_kernel_and_metadata()
# Definition of kernel
new_name, triton_meta = wrapper.define_user_defined_triton_kernel(
kernel, configs, self.kwargs, restore_value_args, reset_to_zero_args
)
raw_args = [
self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
]
# NOTE: raw_args doesn't include autotuned args.
# But, kernel.constexprs includes indices of autotuned args.
# So, let's recalculate constexpr indices wrt to raw_args.
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if kernel.arg_names.index(kwarg) in kernel.constexprs:
constexpr_indices.append(idx)
# Create a copy of triton_meta to avoid modifying the original version.
triton_meta = copy.deepcopy(triton_meta)
if not triton_version_uses_attrs_dict():
"""
Filter out None args.
see https://github.com/pytorch/pytorch/issues/115344
Two cases for a None arg:
1. The arg is already tl.constexpr, so leave it in
2. The arg is not tl.constexpr so we have to remove it
"""
constexpr_indices_set = OrderedSet(constexpr_indices)
REMOVED = object()
raw_args = [
(
(idx, arg)
if (arg is not None)
or (arg is None and idx in constexpr_indices_set)
else (idx, REMOVED)
)
for idx, arg in enumerate(raw_args)
]
removed_none_args = [idx for idx, val in raw_args if val == REMOVED]
raw_args = [val for idx, val in raw_args if val != REMOVED]
# We have to compute the constexpr indices for the new, filtered raw_args
# We also have to adjust equal_to_1.
if removed_none_args:
eq1_indices_set = OrderedSet[int](triton_meta["configs"][0].equal_to_1)
constexpr_indices = []
equal_to_1 = []
index_shift = 0
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
# every time we encounter an idx we removed, adjust by one to account for it
# So for example if we had [None, const X]
# iter 1:
# None was removed, adjust=1
# iter 2:
# X is const at idx=1, but the adjusted idx is 0 now, because None was removed
if idx in removed_none_args:
index_shift += 1
continue
arg_index = kernel.arg_names.index(kwarg)
if arg_index in kernel.constexprs:
constexpr_indices.append(idx - index_shift)
if arg_index in eq1_indices_set:
equal_to_1.append(idx - index_shift)
triton_meta["configs"][0].equal_to_1 = equal_to_1
# Call to kernel
self.codegen_comment(wrapper)
wrapper.generate_user_defined_triton_kernel(
(
new_name,
raw_args,
self.grid,
configs,
triton_meta,
constexpr_indices,
extra_launch_args,
) = wrapper.define_user_defined_triton_kernel(
kernel,
configs,
self.kwargs,
restore_value_args,
reset_to_zero_args,
self.grid,
)
named_args = {
k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
}
constexpr_names = OrderedSet([kernel.arg_names[i] for i in kernel.constexprs])
args: list[Any] = []
arg_types: list[Any] = []
raw_args_filtered: list[Any] = []
for name, arg in itertools.chain(
named_args.items(), zip(itertools.repeat(""), extra_launch_args)
):
raw_args_filtered.append(arg)
if isinstance(arg, IRNode):
args.append(arg.codegen_reference())
arg_types.append(arg.get_dtype())
elif isinstance(arg, (int, float, bool, sympy.Expr)):
args.append(arg)
arg_types.append(type(arg))
elif name in constexpr_names:
# insert a dummy value for constexpr args of unsupported type
# constexprs will end up getting baked into the kernel at compile time
args.append(-1)
arg_types.append(int)
elif arg is None:
"""
Filter out None args.
see https://github.com/pytorch/pytorch/issues/115344
Two cases for a None arg:
1. The arg is already tl.constexpr, so leave it in
2. The arg is not tl.constexpr so we have to remove it
"""
if triton_version_uses_attrs_dict():
args.append(-1)
arg_types.append(int)
else:
raw_args_filtered.pop()
else:
raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")
self.codegen_comment(wrapper)
wrapper.generate_kernel_call(
new_name,
args,
arg_types=arg_types,
raw_args=raw_args_filtered,
triton_meta=triton_meta,
triton=True,
device=self.get_device(),
)
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:


@ -8,10 +8,10 @@ from .. import ir, lowering as L
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
SymbolicGridFn,
TritonTemplate,
)
from ..utils import (
ceildiv as cdiv,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_cpp_bmm_template,
@ -33,7 +33,8 @@ log = logging.getLogger(__name__)
aten = torch.ops.aten
def bmm_grid(b, m, n, meta):
@SymbolicGridFn
def bmm_grid(b, m, n, meta, *, cdiv):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


@ -17,10 +17,10 @@ from ..lowering import (
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
SymbolicGridFn,
TritonTemplate,
)
from ..utils import (
ceildiv,
is_ones,
is_zeros,
pad_listlike,
@ -43,18 +43,20 @@ log = logging.getLogger(__name__)
aten = torch.ops.aten
def conv2d_grid(n, c, h, w, meta):
@SymbolicGridFn
def conv2d_grid(n, c, h, w, meta, *, cdiv):
return (
ceildiv(n * h * w, meta["BLOCK_M"]),
ceildiv(c, meta["BLOCK_N"]),
cdiv(n * h * w, meta["BLOCK_M"]),
cdiv(c, meta["BLOCK_N"]),
meta["GROUPS"],
)
def conv3d_grid(n, c, d, h, w, meta):
@SymbolicGridFn
def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
return (
ceildiv(n * d * h * w, meta["BLOCK_M"]),
ceildiv(c, meta["BLOCK_N"]),
cdiv(n * d * h * w, meta["BLOCK_M"]),
cdiv(c, meta["BLOCK_N"]),
meta["GROUPS"],
)


@ -46,7 +46,12 @@ from ..lowering import (
register_lowering,
to_dtype,
)
from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
from ..select_algorithm import (
autotune_select_algorithm,
realize_inputs,
SymbolicGridFn,
TritonTemplate,
)
log = logging.getLogger(__name__)
@ -91,15 +96,14 @@ def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
return construct_strides(size, fill_order)
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
@SymbolicGridFn
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
"""How is this kernel parallelized?
We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
Each block is responsible for iterating over blocks of keys and values calculating
the final attention output.
"""
import triton
return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
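A concrete (hypothetical) instance of the parallelization described in the docstring:
```py
# batch_size=4, q_heads=8, num_queries=2048, d_model=64, BLOCK_M=128
flex_attention_grid(4, 8, 2048, 64, {"BLOCK_M": 128})
# -> (cdiv(2048, 128), 4 * 8, 1) == (16, 32, 1)
```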
def create_placeholder(


@ -12,7 +12,7 @@ from .. import config, ir
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty, empty_strided, lowerings
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from ..select_algorithm import autotune_select_algorithm, SymbolicGridFn, TritonTemplate
from .flex_attention import (
compute_forward_block_mn,
compute_forward_inner,
@ -30,6 +30,7 @@ aten = torch.ops.aten
prims = torch.ops.prims
@SymbolicGridFn
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
"""How is this kernel parallelized?
We create a grid of (batch_size * kv_heads, SPLIT_KV, 1)


@ -8,7 +8,7 @@ from typing import Any, cast
import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
@ -17,7 +17,6 @@ from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import ChoiceCaller, Layout
from ..runtime.runtime_utils import next_power_of_2
from ..utils import (
ceildiv as cdiv,
get_backend_num_stages,
get_num_sms,
TMA_DESCRIPTOR_SIZE,
@ -442,14 +441,16 @@ def should_fallback_to_aten(choices: list[ChoiceCaller]) -> bool:
return fallback_to_aten
def mm_grid(m, n, meta):
@SymbolicGridFn
def mm_grid(m, n, meta, *, cdiv):
"""
The CUDA grid size for matmul triton templates.
"""
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
def persistent_mm_grid(M: int, N: int, meta: dict[str, Any]):
@SymbolicGridFn
def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
"""Defines the grid for persistent kernels."""
return (
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),


@ -3,9 +3,11 @@ from __future__ import annotations
import builtins
import copy
import dataclasses
import functools
import hashlib
import inspect
import itertools
import logging
import math
import operator
@ -16,7 +18,7 @@ import sys
import threading
import time
from collections import namedtuple
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
import torch
from torch._prims_common import compute_required_storage_length
@ -71,7 +73,7 @@ class NoTritonConfigsError(RuntimeError):
if TYPE_CHECKING:
from collections.abc import Container, Hashable
from collections.abc import Container, Hashable, Sequence
LauncherType = Any
@ -136,7 +138,7 @@ def disable_pointwise_autotuning(inductor_meta):
return not inductor_meta.get("autotune_pointwise", True)
def _dump_launch_params(args, kwargs, launcher, kernel_name):
def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
call_args = []
call_kwargs = {}
for arg in args:
@ -154,14 +156,12 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name):
call_kwargs[k] = v
call_kwargs["num_warps"] = launcher.config.num_warps
call_kwargs["num_stages"] = launcher.config.num_stages
args_str = ""
args_str += ", ".join(call_args)
for k, v in call_kwargs.items():
args_str += f", {k}={v}"
args_str = [*call_args]
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
args_str = ", ".join(args_str)
abs_path = os.path.abspath(sys.argv[0])
with open(f"{abs_path}.launch_params", "a") as f:
f.write(f"{kernel_name} | {args_str}\n")
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
class CachingAutotuner(KernelInterface):
@ -478,6 +478,12 @@ class CachingAutotuner(KernelInterface):
if k in cfg_kwargs:
compile_meta[k] = cfg_kwargs.pop(k)
compile_meta["constants"].update(cfg_kwargs)
for i in self.fn.constexprs:
arg_name = self.fn.arg_names[i]
if arg_name not in compile_meta["constants"] and (
arg_name == "num_warps" or arg_name == "num_stages"
):
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
compile_meta["debug"] = self.inductor_meta.get(
@ -563,7 +569,7 @@ class CachingAutotuner(KernelInterface):
return new_args
return args
def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
def bench(self, launcher, *args, with_profiler=False, **kwargs):
"""Measure the performance of a given launcher"""
# we don't skip configs with spilled registers when auto-tuning custom
# (user-written) Triton kernels, as (i) we don't have any knowledge or
@ -595,7 +601,6 @@ class CachingAutotuner(KernelInterface):
launcher(
*args_with_constexprs,
**cloned_kwargs,
grid=grid,
stream=stream,
)
self.restore_args_from_cpu(cpu_copies)
@ -650,8 +655,8 @@ class CachingAutotuner(KernelInterface):
else:
budget -= size
for i, arg in enumerate(args):
maybe_copy(self.fn.arg_names[i], arg)
for name, arg in zip(self.fn.arg_names, args):
maybe_copy(name, arg)
for name, arg in kwargs.items():
maybe_copy(name, arg)
@ -712,10 +717,10 @@ class CachingAutotuner(KernelInterface):
return arg
cloned_args = [
prepare_arg(self.fn.arg_names[i], arg) for i, arg in enumerate(args)
prepare_arg(name, arg)
for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args)
]
cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
return cloned_args, cloned_kwargs
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
@ -777,12 +782,7 @@ class CachingAutotuner(KernelInterface):
if self.save_cache_hook:
self.save_cache_hook(launcher.config, self.autotune_time_taken_ns)
def save_gpu_kernel(self, grid, stream, launcher):
if callable(grid):
grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
else:
grid_x, grid_y, grid_z = grid
def save_gpu_kernel(self, stream, launcher):
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
assert key is not None, "kernel_name can not be None"
params = {
@ -791,13 +791,6 @@ class CachingAutotuner(KernelInterface):
if hasattr(launcher.bin.metadata, "name")
else launcher.bin.metadata["name"]
),
"grid_x": grid_x,
"grid_y": grid_y,
"grid_z": grid_z,
"x_block": launcher.config.kwargs.get("XBLOCK", 1),
"y_block": launcher.config.kwargs.get("YBLOCK", None),
"z_block": launcher.config.kwargs.get("ZBLOCK", None),
"r_block": launcher.config.kwargs.get("RBLOCK", None),
"num_warps": (
launcher.bin.num_warps
if hasattr(launcher.bin, "num_warps")
@ -810,7 +803,11 @@ class CachingAutotuner(KernelInterface):
),
"stream": stream,
# User defined triton kernels will have arbitrary kwarg names
"meta": launcher.config.kwargs,
"config": config_to_dict(launcher.config),
"inductor_meta": self.inductor_meta,
"triton_meta": self.triton_meta,
"def_args": launcher.def_args,
"call_args": launcher.call_args,
}
from torch._inductor.codecache import CudaKernelParamCache
@ -888,8 +885,15 @@ class CachingAutotuner(KernelInterface):
)
return config2launcher.get(best_config)
def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override]
def run(
self,
*args,
stream,
benchmark_run=False,
**kwargs,
): # type:ignore[override]
if self.triton_interpret:
args, grid = self._interpret_args_grid(args, self.configs[0])
return self.fn[grid](
*args,
**kwargs,
@ -902,35 +906,28 @@ class CachingAutotuner(KernelInterface):
self.precompile()
self.precompile_time_taken_ns = time.time_ns() - start_time
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid, **kwargs)
self.autotune_to_one_config(*args, **kwargs)
if not getattr(
self.launchers[0].config, "found_by_coordesc", False
) and self.inductor_meta.get("coordinate_descent_tuning", False):
self.launchers = [
self.coordinate_descent_tuning(
self.launchers[0], *args, grid=grid, **kwargs
)
self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs)
]
(launcher,) = self.launchers
if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
self.save_gpu_kernel(grid, stream, launcher)
self.save_gpu_kernel(stream, launcher)
args = self._get_args_with_constexprs(args, launcher)
if self.dump_launch_params:
_dump_launch_params(args, kwargs, launcher, self.fn.__name__)
new_args, grid = self._interpret_args_grid(args, launcher.config)
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
# it is faster than entering and exiting a context manager, even if the context
# manager is a nullcontext.
if autograd_profiler._is_profiler_enabled:
# grid can be a tuple of ints or a string.
if isinstance(grid, tuple):
grid_info = str(grid)
else:
grid_info = getattr(grid, "grid_fn_str", "")
kernel_kwargs_str = ",".join(
f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
)
@ -939,7 +936,6 @@ class CachingAutotuner(KernelInterface):
"kernel_file": (self.filename or ""),
"kernel_hash": self.kernel_hash,
"kernel_backend": "triton",
"grid": grid_info,
"stream": stream,
"num_warps": launcher.config.num_warps,
"num_stages": launcher.config.num_stages,
@ -954,17 +950,33 @@ class CachingAutotuner(KernelInterface):
return launcher(
*args,
**kwargs,
grid=grid,
stream=stream,
)
else:
return launcher(
*args,
**kwargs,
grid=grid,
stream=stream,
)
def _interpret_args_grid(
self, args: tuple[Any, ...], cfg: Config
) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow(
dict(
zip(
[
*self.fn.arg_names,
*self.inductor_meta.get("extra_launcher_args", ()),
],
args,
)
)
)
if self.inductor_meta.get("extra_launcher_args"):
args = args[: -len(self.inductor_meta["extra_launcher_args"])]
return args, grid
class _ConstRepr:
def __init__(self, value: str):
@ -1110,10 +1122,11 @@ class TritonCompileResult:
for i, arg in enumerate(fn.arg_names)
if i not in fn.constexprs and arg not in none_args
]
cfg_dict = config_to_dict(cfg)
def_args = [
name
for name in fn.arg_names
if name not in cfg.kwargs and name not in none_args
if name not in cfg_dict and name not in none_args
]
binary_shared = (
@ -1176,9 +1189,7 @@ class TritonCompileResult:
# we want to burn None in to the launch args with zero overhead.
# See https://github.com/pytorch/pytorch/issues/123597
if binary.__class__.launch_enter_hook:
launch_metadata = (
f"bin.launch_metadata(grid, stream, {', '.join(call_args)})"
)
launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})"
else:
launch_metadata = "None"
runner_args = [
@ -1194,18 +1205,20 @@ class TritonCompileResult:
*call_args,
]
exec(
f"""
def launcher({", ".join(def_args)}, grid, stream):
if callable(grid):
grid_0, grid_1, grid_2 = grid(grid_meta)
else:
grid_0, grid_1, grid_2 = grid
runner({", ".join(runner_args)})
return bin
""".lstrip(),
scope,
)
if "extra_launcher_args" in self.inductor_meta:
def_args.extend(self.inductor_meta["extra_launcher_args"])
grid = GridExpr.from_meta(self.inductor_meta, cfg)
# grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)`
lines = [
f"def launcher({', '.join(def_args)}, stream):",
*[f" {line}" for line in grid.prefix],
f" grid_0 = {grid.x_grid}",
f" grid_1 = {grid.y_grid}",
f" grid_2 = {grid.z_grid}",
f" runner({', '.join(runner_args)})",
]
exec("\n".join(lines), scope)
launcher = scope["launcher"]
launcher.config = cfg
@ -1217,6 +1230,8 @@ class TritonCompileResult:
if launcher.store_cubin:
launcher.fn = fn
launcher.bin = binary
launcher.def_args = def_args
launcher.call_args = call_args
return launcher
@ -1306,9 +1321,9 @@ class DebugAutotuner(CachingAutotuner):
super().__init__(*args, **kwargs)
self.cached = None
def run(self, *args, grid, stream, **kwargs):
def run(self, *args, stream, **kwargs):
if not self.with_bandwidth_info:
super().run(*args, grid=grid, stream=stream, **kwargs, benchmark_run=True)
super().run(*args, stream=stream, **kwargs, benchmark_run=True)
return
else:
possible_names = _find_names(self)
@ -1322,16 +1337,14 @@ class DebugAutotuner(CachingAutotuner):
self.precompile()
self.precompile_time_taken_ns = time.time_ns() - start_time
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid, **kwargs)
self.autotune_to_one_config(*args, **kwargs)
(launcher,) = self.launchers
if launcher.store_cubin:
self.save_gpu_kernel(grid, stream, launcher)
self.save_gpu_kernel(stream, launcher)
if self.cached is None:
ms = self.bench(
launcher, *args, grid=grid, with_profiler=self.with_profiler
)
ms = self.bench(launcher, *args, with_profiler=self.with_profiler)
num_in_out_ptrs = len(
[
arg_name
@ -2127,6 +2140,19 @@ def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
return popped
def config_to_dict(config: Config) -> dict[str, Any]:
return {
**config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
def config_from_dict(config: dict[str, Any]) -> Config:
config = {**config}
return Config(config, **_pop_config_kwargs(config))
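A minimal round-trip sketch for these two helpers (assuming `triton.Config` is importable):
```py
import triton

cfg = triton.Config({"XBLOCK": 64}, num_warps=4, num_stages=2)
as_dict = config_to_dict(cfg)  # {"XBLOCK": 64, "num_warps": 4, "num_stages": 2}
assert config_to_dict(config_from_dict(as_dict)) == as_dict
```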
def fixed_config(config, filename, triton_meta, inductor_meta):
"""
Used when the configuration is already decided at compile time
@ -2151,10 +2177,7 @@ def user_autotune(
if len(configs) == 0:
configs = [triton.Config({})]
else:
configs = [
triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c}))
for c in configs
]
configs = [*map(config_from_dict, configs)]
return cached_autotune(
None,
configs,
@ -2180,150 +2203,226 @@ def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
)
def grid(*numels):
"""Helper function to compute triton grids"""
if len(numels) == 1:
xnumel, ynumel, znumel = numels[0], None, None
elif len(numels) == 2:
xnumel, ynumel, znumel = numels[1], numels[0], None
elif len(numels) == 3:
xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
else:
raise AssertionError(f"invalid size for numels {len(numels)}")
@dataclasses.dataclass
class GridExpr:
"""Generate code for grid size expressions in launcher"""
def get_grid_dim(numel, block):
if numel is None:
return 1
if block is None:
inductor_meta: dict[str, Any]
mode: Literal["python", "cpp"] = "python"
prefix: Sequence[str] = ()
x_grid: Union[str, int] = 1
y_grid: Union[str, int] = 1
z_grid: Union[str, int] = 1
def __post_init__(self) -> None:
assert self.mode in ("python", "cpp")
def generate(self, meta: dict[str, int]) -> None:
raise NotImplementedError
def ceildiv(
self, numel: Union[str, int], block: Union[None, int, str]
) -> Union[str, int]:
if block is None or block == 1:
return numel
return ceildiv(numel, block)
if isinstance(numel, int) and isinstance(block, int):
return ceildiv(numel, block) # constant fold
if self.mode == "python":
return f"-(({numel}) // -({block}))"
# trick above doesn't work in C++ due to rounding differences
return f"(({numel} + ({block} - 1)) / ({block}))"
def grid_fn(meta):
x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))
def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]:
"""Codegen for max function with constant folding, constants are represented as int"""
items = self._constant_fold(max, seq)
if len(items) <= 1:
return items[0]
if self.mode == "python":
return f"max({', '.join(map(str, items))})"
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
max_y_grid = get_max_y_grid()
if znumel is None:
div = ceildiv(y_grid, max_y_grid)
y_grid = ceildiv(y_grid, div)
z_grid = div
else:
z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
torch._check(
y_grid <= max_y_grid,
lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
)
def summation(self, seq: list[Union[int, str]]) -> Union[int, str]:
"""Codegen for sum function with constant folding, constants are represented as int"""
items = self._constant_fold(sum, seq)
if len(items) <= 1:
return items[0]
return " + ".join(map(str, items))
return (
x_grid,
y_grid,
z_grid,
def _constant_fold(
self, fn: Callable[[list[int]], int], seq: list[Union[int, str]]
) -> list[Union[int, str]]:
"""Constant fold through a commutative fn where ints are constants"""
items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)]
const_items = [x for x in seq if isinstance(x, int)]
if const_items:
items.append(fn(const_items))
return items
def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
# Grid functions are one per kernel, so name collisions are fine
if self.mode == "python":
return f"{name} = {expr}"
if self.mode == "cpp":
return f"uint32_t {name} = {expr};"
raise AssertionError(f"invalid mode {self.mode}")
@staticmethod
def from_meta(
inductor_meta: dict[str, Any],
cfg: Union[Config, dict[str, int]],
mode: Literal["python", "cpp"] = "python",
) -> GridExpr:
grid_cls = globals()[inductor_meta["grid_type"]]
assert issubclass(grid_cls, GridExpr)
grid = grid_cls(inductor_meta=inductor_meta, mode=mode)
if isinstance(cfg, Config):
cfg = config_to_dict(cfg)
grid.generate(cfg)
return grid
def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]:
scope = {**meta}
for line in self.prefix:
exec(line, scope)
exec(f"grid_0 = {self.x_grid}", scope)
exec(f"grid_1 = {self.y_grid}", scope)
exec(f"grid_2 = {self.z_grid}", scope)
return scope["grid_0"], scope["grid_1"], scope["grid_2"]
class Grid1D(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
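As a rough usage sketch (assuming these classes are imported from `torch._inductor.runtime.triton_heuristics`), `Grid1D` emits a string expression that the generated launcher embeds, and `eval_slow` can evaluate it for debugging:
```py
grid = GridExpr.from_meta({"grid_type": "Grid1D"}, {"XBLOCK": 1024})
print(grid.x_grid)                       # "-((xnumel) // -(1024))"
print(grid.eval_slow({"xnumel": 4097}))  # (5, 1, 1)
```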
class Grid2D(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
class Grid3D(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK"))
class Grid2DWithYZOverflow(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
self.prefix = [
self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))),
self.assign_tmp(
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
),
]
self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
self.z_grid = "y_grid_div_"
class CooperativeReductionGrid(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = str(meta["RSPLIT"])
self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
class SplitScanGrid(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
assert meta.get("XBLOCK", 1) == 1
self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK"))
self.y_grid = "xnumel"
class FixedGrid(GridExpr):
@staticmethod
def setup_grid_as_args() -> dict[str, Any]:
"""Inductor meta so the launcher takes three extra grid arguments"""
return {
"grid_type": FixedGrid.__name__,
"fixed_grid": ["_grid_0", "_grid_1", "_grid_2"],
"extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"],
}
def generate(self, meta: dict[str, int]) -> None:
self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"]
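Sketch of the `FixedGrid` path used by templates, where the three grid values are simply passed as trailing launcher arguments (hedged example):
```py
meta = FixedGrid.setup_grid_as_args()  # launcher gains _grid_0/_grid_1/_grid_2 args
grid = GridExpr.from_meta(meta, {})
print((grid.x_grid, grid.y_grid, grid.z_grid))  # ('_grid_0', '_grid_1', '_grid_2')
print(grid.eval_slow({"_grid_0": 8, "_grid_1": 2, "_grid_2": 1}))  # (8, 2, 1)
```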
class PrecomputedGrid(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
for candidate in self.inductor_meta["precomputed_grids"]:
if all(meta.get(k) == v for k, v in candidate["config"].items()):
self.x_grid, self.y_grid, self.z_grid = candidate[self.mode]
return
raise AssertionError(
f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}"
)
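Sketch of how `PrecomputedGrid` picks the expressions that wrapper codegen attached for the winning config (all values illustrative):
```py
inductor_meta = {
    "grid_type": "PrecomputedGrid",
    "precomputed_grids": [
        {
            "config": {"BLOCK": 64, "num_warps": 4, "num_stages": 2},
            "python": ["-((n) // -(64))", "1", "1"],
            "cpp": ["((n + (64 - 1)) / (64))", "1", "1"],
        },
    ],
}
grid = GridExpr.from_meta(inductor_meta, {"BLOCK": 64, "num_warps": 4, "num_stages": 2})
print(grid.x_grid)  # "-((n) // -(64))"
```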
setattr(grid_fn, "grid_fn_str", f"grid{numels}") # noqa: B010
return grid_fn
def cooperative_reduction_grid(xnumel):
def grid_fn(meta):
return (meta["RSPLIT"], ceildiv(xnumel, meta.get("XBLOCK", 1)), 1)
grid_fn_str = f"cooperative_reduction_grid({xnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def maybe_cooperative_reduction_grid(xnumel):
def grid_fn(meta):
if "RSPLIT" in meta:
return coop_grid(meta)
return normal_grid(meta)
coop_grid = cooperative_reduction_grid(xnumel)
normal_grid = grid(xnumel)
grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def split_scan_grid(xnumel, rnumel):
def grid_fn(meta):
assert meta.get("XBLOCK", 1) == 1
return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def grid_combo_kernels(
*numels, num_kernels, min_blocks, is_sequential, default_meta=None
):
"""min_blocks is the minimal size of the grid x dimension"""
if not is_sequential:
# round robin dispatch
numels_agg = list(numels)
for i in range(len(numels_agg)):
if isinstance(numels_agg[i], (list, tuple)):
numels_agg[i] = max(max(numels_agg[i]), 0) # noqa: PLW3301
kernel_grid_fn = grid(*numels_agg)
if isinstance(numels[-1], (list, tuple)):
min_blocks_d = max(-min(numels[-1]), 0) * num_kernels
else:
min_blocks_d = None
if min_blocks is None:
assert min_blocks_d is not None
min_blocks = min_blocks_d
else:
assert min_blocks_d is None or min_blocks == min_blocks_d, (
f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
class ComboKernelGrid(GridExpr):
def generate(self, meta: dict[str, int]):
combo_meta = self.inductor_meta["combo_grid_meta"]
if combo_meta["default_config"]:
meta = {**combo_meta["default_config"], **meta}
no_x_dims = []
xnumels = []
ynumels = []
znumels = []
for num in range(combo_meta["num_kernels"]):
assert (
combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0
)
else:
# sequential dispatch
seq_numels = list(numels)
# x numels are not used here, just a place holder
seq_numels[-1] = 1024
for i in range(len(seq_numels) - 1):
if isinstance(seq_numels[i], (list, tuple)):
seq_numels[i] = max(seq_numels[i])
no_x_dims.append(combo_meta[f"no_x_dim_{num}"])
xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}")
if f"ynumel_{num}" in combo_meta:
ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}")
if f"znumel_{num}" in combo_meta:
znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}")
kernel_grid_fn = grid(*seq_numels)
self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta)
if combo_meta["min_blocks"]:
self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]])
if ynumels:
self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK"))
if znumels:
self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK"))
def get_grid_dim(numel, block):
if numel is None:
return 1
if block is None:
return numel
return ceildiv(numel, block)
def combo_x_grid(
self,
xnumels: list[Union[int, str]],
no_x_dims: list[bool],
meta: dict[str, int],
) -> Union[str, int]:
raise NotImplementedError
def grid_fn(meta):
assert min_blocks is not None, "min_blocks must be a number"
cuda_grid = list(kernel_grid_fn(meta))
cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
return tuple(cuda_grid)
def seq_grid_fn(meta):
cuda_grid = list(kernel_grid_fn(meta))
# x <= 0 means this kernel's x grid is not tunable (x_no_dim is true)
x_grid = sum(
class SequentialComboKernelGrid(ComboKernelGrid):
def combo_x_grid(
self,
xnumels: list[Union[int, str]],
no_x_dims: list[bool],
meta: dict[str, int],
) -> Union[str, int]:
assert len(xnumels) == len(no_x_dims)
return self.summation(
[
-x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1))
for x in numels[-1]
self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK"))
for x, no_x_dim in zip(xnumels, no_x_dims)
]
)
cuda_grid[0] = x_grid
return tuple(cuda_grid)
def grid_fn_default_meta(meta):
return grid_fn(default_meta)
def seq_grid_fn_default_meta(meta):
return seq_grid_fn(default_meta)
if default_meta is None:
return grid_fn if not is_sequential else seq_grid_fn
else:
return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta
class RoundRobinComboKernelGrid(ComboKernelGrid):
def combo_x_grid(
self,
xnumels: list[Union[int, str]],
no_x_dims: list[bool],
meta: dict[str, int],
) -> str:
assert len(xnumels) == len(no_x_dims)
num_kernels = self.inductor_meta["combo_grid_meta"]["num_kernels"]
exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim]
xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim]
if xnumels_x_dim:
exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK")))
return f"({self.maximum(exprs)}) * {num_kernels}"


@ -3863,7 +3863,7 @@ class Scheduler:
for name in sorted(
self.buffer_names_to_free
- V.graph.removed_buffers
- V.graph.wrapper_code.freed
- V.graph.wrapper_code.freed # type: ignore[has-type]
):
if name in self.name_to_buf:
buf = self.name_to_buf[name]
@ -4122,7 +4122,7 @@ class Scheduler:
V.graph.wrapper_code.define_subgraph_launcher_fn(partition_code)
V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature)
V.graph.wrapper_code.allocated.update(
V.graph.wrapper_code.allocated.update( # type: ignore[has-type]
[node.get_name() for node in signature.output_nodes]
)


@ -31,6 +31,7 @@ from torch._inductor.utils import clear_on_fresh_inductor_cache
from torch.utils._filelock import FileLock
from torch.utils._ordered_set import OrderedSet
from ..utils._sympy.functions import CeilDiv
from . import config, ir
from .autotune_process import (
TensorMeta,
@ -54,12 +55,15 @@ from .codegen.triton import (
TritonScheduling,
)
from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta
from .codegen.wrapper import pexpr
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .ops_handler import StoreMode
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties
from .runtime.triton_heuristics import FixedGrid
from .utils import (
ceildiv,
FakeIndentedBuffer,
get_dtype_size,
is_gpu,
@ -442,6 +446,7 @@ class TritonTemplateKernel(TritonKernel):
inductor_meta = {
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
**TritonKernel.inductor_meta_common(),
**FixedGrid.setup_grid_as_args(),
}
if config.profile_bandwidth or config.benchmark_kernel:
num_gb = self.estimate_kernel_num_bytes() / 1e9
@ -988,46 +993,44 @@ class TritonTemplateKernel(TritonKernel):
wrapper = V.graph.wrapper_code
_, call_args, _, arg_types = self.args.python_argdefs()
# Handle workspace allocation
if self.workspace_arg is not None:
wrapper.generate_workspace_allocation(self.workspace_arg)
if V.graph.cpp_wrapper:
# In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
# if any dynamic dimension is involved. We rely on the Python version
# of the grid function to generate those grid configs, which may contain
# symbolic values. The wrapper will use cexpr to print out C++ code
# appropriately for the grid configs.
grid = self.call_sizes + [self.meta]
wrapper.generate_kernel_call(
name,
call_args,
grid=self.grid_fn(*grid),
# Calling self.grid_fn(*grid) already computes grid as a tuple,
# so we need to explicitly set grid_fn as empty here. Otherwise, the
# generated wrapper code will wrap the tuple as grid(tuple), which can
# cause incorrect grid computation in some corner cases.
grid_fn="",
arg_types=arg_types,
triton_meta=self.triton_meta,
)
grid_args = ()
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
else:
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
meta = wrapper.add_meta_once(self.meta)
grid = self.call_sizes + [meta]
wrapper.generate_kernel_call(
name,
call_args,
grid=grid,
grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
arg_types=arg_types,
triton_meta=self.triton_meta,
gpu="cpu" not in V.graph.device_types,
fn_name = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}"
call_args.append(
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
)
arg_types.append(None)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
call_args.extend(grid_args)
arg_types.extend(map(type, grid_args))
if self.workspace_arg is not None:
wrapper.generate_workspace_allocation(self.workspace_arg)
wrapper.generate_kernel_call(
name,
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)
if self.workspace_arg is not None:
wrapper.generate_workspace_deallocation(self.workspace_arg)
def kernel_benchmark_extra_args(self) -> list[str]:
return [
str(x)
for x in self.grid_fn(
*V.graph.sizevars.size_hints(self.call_sizes), self.meta
)
]
@functools.lru_cache(None)
def _jinja2_env():
@ -1210,8 +1213,7 @@ class TritonTemplate(KernelTemplate):
module_path=mod.__file__,
module_cache_key=mod.key,
kernel_name=kernel_name,
grid=grid,
extra_args=extra_args,
extra_args=[*extra_args, *grid],
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
@ -1328,7 +1330,6 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
self.log_info.update(
{
"backend": "Triton",
"grid": str(self.bmreq.grid),
"num_stages": self.bmreq.num_stages,
"num_warps": self.bmreq.num_warps,
}
@ -2352,5 +2353,35 @@ def realize_inputs(*args):
return [realize_inputs(x) for x in args]
class SymbolicGridFn:
"""
Wrapper around a grid function that allows either int or sympy inputs.
@SymbolicGridFn
def grid(x, meta, *, cdiv):
return cdiv(x, meta["BLOCK_X"])
"""
def __init__(self, fn: Callable[..., tuple[Any, Any, Any]]):
self.fn = fn
self.kwargs_int = {}
self.kwargs_sym = {}
params = inspect.signature(fn).parameters
for name, fn_sym, fn_int in [
("cdiv", CeilDiv, ceildiv),
("min", sympy.Min, min),
("max", sympy.Max, max),
]:
if name in params:
self.kwargs_int[name] = fn_int
self.kwargs_sym[name] = fn_sym
def __call__(self, *args, **kwargs) -> tuple[int, int, int]:
return self.fn(*args, **kwargs, **self.kwargs_int)
def sympy_call(self, *args, **kwargs):
return self.fn(*args, **kwargs, **self.kwargs_sym)
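A short usage sketch of the two call paths (`my_grid` and `BLOCK_M` are hypothetical; assumes `sympy` and the `CeilDiv` helper imported above):
```py
@SymbolicGridFn
def my_grid(m, meta, *, cdiv):
    return (cdiv(m, meta["BLOCK_M"]), 1, 1)

my_grid(100, {"BLOCK_M": 32})                           # (4, 1, 1) via integer ceildiv
my_grid.sympy_call(sympy.Symbol("m"), {"BLOCK_M": 32})  # (CeilDiv(m, 32), 1, 1)
```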
# ensure lowering is imported so that `extern_kernels.*` is populated
from . import lowering # noqa: F401


@ -45,23 +45,23 @@ def rename_kernels(source_code: str) -> str:
def merge_params(original_params: list[str], new_params: list[str]) -> list[str]:
assert len(new_params) >= len(original_params)
for idx in range(len(new_params)):
if new_params[idx] == "T":
new_params[idx] = original_params[idx]
return new_params
def add_launch_params(original: str, kernel_to_params: dict[str, str]) -> str:
def add_launch_params(
original: str, kernel_to_params: dict[str, tuple[str, str]]
) -> str:
# Regex to match the function call in the original string
pattern = r"(\w+)\.run\((.*), grid=(.*\)), [^)]*\)"
pattern = r"(\w+)\.run\((.*)\)"
def replace(match) -> str:
# Extract parts from the regex match
func_name = match.group(1)
params = match.group(2)
grid = match.group(3)
new_params = kernel_to_params[func_name]
new_params, grid = kernel_to_params[func_name]
new_params = merge_params(params.split(", "), new_params.split(", "))
# Format the new function call
@ -103,9 +103,8 @@ def process_file(input_filename: str, output_filename: str) -> str:
launch_params_meta = f.readlines()
split_params = [i.split("|") for i in launch_params_meta]
strip_params = [[a.strip(), b.strip()] for a, b in split_params]
kernel_to_args: dict[str, str] = dict(strip_params)
transformed_code = add_launch_params(transformed_code, kernel_to_args)
kernel_args_grid = {a.strip(): (b.strip(), c.strip()) for a, b, c in split_params}
transformed_code = add_launch_params(transformed_code, kernel_args_grid)
with open(output_filename, "w") as file:
file.write(transformed_code)