[import][inductor] Simplify grid handling (#147583)
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where `grid=` was passed as a callable (a function closure) argument. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
Instead, the grid computation is now generated directly into the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
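As a rough illustration (a minimal sketch, not the exact generated code), the baked-in grid expression is just a ceiling division of `xnumel` by the config's `XBLOCK`; with `XBLOCK = 1024` it reduces to the `(xnumel + 1023) >> 10` form shown above:
```py
def make_launcher(xblock):
    # Specialize the grid computation for one triton.Config at codegen time,
    # so no grid callable or dict lookup happens at launch time.
    def launcher(xnumel):
        # ceil(xnumel / xblock); for xblock == 1024 this matches
        # the "(xnumel + 1023) >> 10" expression in the generated launcher.
        grid_0 = (xnumel + xblock - 1) // xblock
        return (grid_0, 1, 1)

    return launcher


assert make_launcher(1024)(4096) == (4, 1, 1)
assert make_launcher(1024)(4097) == (5, 1, 1)
```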
It also lets us unify grid handling between the Python and C++ wrapper code. Before this, the C++ wrapper didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
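A minimal sketch of that idea (the class below is hypothetical; the diff suggests the real helper is `GridExpr` in `torch._inductor.runtime.triton_heuristics`): a single grid expression renders the same ceiling division as either Python or C++ source, so both wrapper backends can share one code path:
```py
from dataclasses import dataclass


@dataclass
class CeilDivGrid:
    numel_var: str  # e.g. "xnumel"
    block: int      # e.g. XBLOCK from the chosen triton.Config

    def render(self, mode: str) -> str:
        # Same arithmetic, emitted as Python or C++ source text.
        if mode == "python":
            return f"(({self.numel_var} + {self.block - 1}) // {self.block})"
        if mode == "cpp":
            return f"(({self.numel_var} + {self.block - 1}) / {self.block})"
        raise ValueError(f"unknown mode: {mode}")


print(CeilDivGrid("xnumel", 1024).render("cpp"))  # ((xnumel + 1023) / 1024)
```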
Note the attached diff contains some minor fbcode-only changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147583
Approved by: https://github.com/eellison, https://github.com/shunting314
commit b59776d857
parent dae3fbfe97
@@ -4086,10 +4086,9 @@ class AOTInductorTestsTemplate:
# input u0 was defined as int32_t initially, verify for every kernel var args downstream,
# it gets explicitly declared using its data types in the cpp wrapper codegen code.
expected_scalar_args = [
    "int64_t var_1 = u0;",
    "int64_t var_4 = u0;",
    "int64_t var_7 = u0;",
    "int64_t var_12 = u0;",
    "buf3, u0",
    "buf4, u0",
    "buf3, buf4, buf2, u0",
]
# check the new behavior of codegen is expected
result, code = run_and_get_cpp_code(

@@ -54,22 +54,6 @@ class TestCppWrapperHipify(TestCase):
    } \\
} while (0);

namespace {

struct Grid {
    Grid(uint32_t x, uint32_t y, uint32_t z)
        : grid_x(x), grid_y(y), grid_z(z) {}
    uint32_t grid_x;
    uint32_t grid_y;
    uint32_t grid_z;

    bool is_non_zero() {
        return grid_x > 0 && grid_y > 0 && grid_z > 0;
    }
};

} // anonymous namespace

static inline hipFunction_t loadKernel(
    std::string filePath,
    const std::string &funcName,

@ -550,7 +550,7 @@ class CudaReproTests(TestCase):
|
|||
"""
|
||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||
from torch._inductor.runtime.hints import AttrsDescriptorWrapper, HeuristicType
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
|
||||
def autotune(configs, meta):
|
||||
def decorator(fn):
|
||||
|
|
@ -564,6 +564,7 @@ class CudaReproTests(TestCase):
|
|||
reset_to_zero_arg_names=[],
|
||||
optimize_mem=True,
|
||||
heuristic_type=HeuristicType.POINTWISE,
|
||||
inductor_meta={"grid_type": "Grid1D"},
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
|
@ -603,8 +604,8 @@ class CudaReproTests(TestCase):
|
|||
inout2 = inout1.clone()
|
||||
|
||||
stream0 = get_cuda_stream(0)
|
||||
kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
|
||||
kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)
|
||||
kernel.run(inout1, in0, xnumel, stream=stream0)
|
||||
kernel.run(inout2, in0, xnumel, stream=stream0)
|
||||
|
||||
assert same(
|
||||
inout1, inout2, tol=0.001, equal_nan=True
|
||||
|
|
|
|||
|
|
@ -59,11 +59,16 @@ class TestKernelBenchmark(TestCase):
|
|||
def verify_compiled_kernels(self, GB_count=1):
|
||||
compiled_module = self.get_compiled_module()
|
||||
# now run the compiled module in subprocess and check its output
|
||||
bench_out = subprocess.check_output(
|
||||
f"{sys.executable} {compiled_module.__file__} -kc".split(),
|
||||
stderr=subprocess.STDOUT,
|
||||
env={**os.environ, "PYTHONPATH": self.python_path},
|
||||
).decode()
|
||||
try:
|
||||
bench_out = subprocess.check_output(
|
||||
f"{sys.executable} {compiled_module.__file__} -kc".split(),
|
||||
stderr=subprocess.STDOUT,
|
||||
env={**os.environ, "PYTHONPATH": self.python_path},
|
||||
).decode()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Failed when running output code", e)
|
||||
print(e.output.decode())
|
||||
raise e
|
||||
|
||||
# make sure we have the bandwidth information in the output
|
||||
FileCheck().check_count(
|
||||
|
|
@ -112,11 +117,16 @@ class TestKernelBenchmark(TestCase):
|
|||
|
||||
def check_bandwidth(self, compiled_module, num_gb):
|
||||
# now run the compiled module in subprocess and check its output
|
||||
bench_out = subprocess.check_output(
|
||||
f"{sys.executable} {compiled_module.__file__} -k".split(),
|
||||
stderr=subprocess.STDOUT,
|
||||
env={**os.environ, "PYTHONPATH": self.python_path},
|
||||
).decode()
|
||||
try:
|
||||
bench_out = subprocess.check_output(
|
||||
f"{sys.executable} {compiled_module.__file__} -k".split(),
|
||||
stderr=subprocess.STDOUT,
|
||||
env={**os.environ, "PYTHONPATH": self.python_path},
|
||||
).decode()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Failed when running output code", e)
|
||||
print(e.output.decode())
|
||||
raise e
|
||||
|
||||
# make sure we have the bandwidth information in the output
|
||||
FileCheck().check_count(
|
||||
|
|
@ -156,7 +166,7 @@ class TestKernelBenchmark(TestCase):
|
|||
self.verify_compiled_kernels()
|
||||
|
||||
@config.patch(
|
||||
max_autotune=True, max_autotune_gemm_backends="TRITON", force_shape_pad=True
|
||||
max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
|
||||
)
|
||||
@fresh_inductor_cache()
|
||||
def test_mm_triton_kernel_benchmark(self):
|
||||
|
|
@ -175,28 +185,7 @@ class TestKernelBenchmark(TestCase):
|
|||
|
||||
f(a, b)
|
||||
|
||||
GB_count = 3
|
||||
# pad_mm is not enabled on XPU, so there is only one kernel.
|
||||
if GPU_TYPE == "xpu":
|
||||
GB_count = 1
|
||||
self.verify_compiled_kernels(GB_count=GB_count)
|
||||
|
||||
# make sure we correctly generate the grid info
|
||||
compiled_module = self.get_compiled_module()
|
||||
with open(compiled_module.__file__) as f:
|
||||
source_code = f.read()
|
||||
lines = source_code.split("\n")
|
||||
meta = [l for l in lines if "meta0 = {" in l]
|
||||
scope = {}
|
||||
from torch._inductor.kernel.mm_common import mm_grid
|
||||
|
||||
exec(meta[0], scope)
|
||||
grid = mm_grid(M, N, scope["meta0"])
|
||||
FileCheck().check_count(
|
||||
f"grid={grid}",
|
||||
2,
|
||||
exactly=1,
|
||||
).run(source_code)
|
||||
self.verify_compiled_kernels(GB_count=1)
|
||||
|
||||
def test_matmul_bandwidth_computation(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def _get_func_call() -> str:
|
|||
|
||||
|
||||
def _get_kernel_launch() -> str:
|
||||
return "launchKernel(" if config.cpp_wrapper else ".run("
|
||||
return "call_triton_" if config.cpp_wrapper else ".run("
|
||||
|
||||
|
||||
def benchmark_choice(choice, args, out, expected_out, timings):
|
||||
|
|
|
|||
|
|
@ -265,9 +265,6 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
|
|||
self.assertEqual(args["kernel_backend"], "triton", msg=f"event = {e}")
|
||||
|
||||
self.assertTrue("stream" in args, msg=f"event = {e}")
|
||||
self.assertTrue("grid" in args, msg=f"event = {e}")
|
||||
self.assertTrue(args["grid"].startswith("grid"), msg=f"event = {e}")
|
||||
|
||||
self.assertTrue("kernel_file" in args, msg=f"event = {e}")
|
||||
kernel_file = args["kernel_file"]
|
||||
self.assertTrue(os.path.isfile(kernel_file), msg=f"event = {e}")
|
||||
|
|
|
|||
|
|
@ -353,7 +353,6 @@ class TestSelectAlgorithm(TestCase):
|
|||
module_path=module_path,
|
||||
module_cache_key=None,
|
||||
kernel_name=None,
|
||||
grid=None,
|
||||
extra_args=None,
|
||||
num_stages=None,
|
||||
num_warps=None,
|
||||
|
|
|
|||
|
|
@ -3985,7 +3985,7 @@ class CommonTemplate:
|
|||
_, code = run_and_get_code(foo, grouped_conv, input_tensor)
|
||||
# no to channels last permuting before kernel
|
||||
if config.cpp_wrapper:
|
||||
FileCheck().check_not("launchKernel(triton").check("_convolution(").run(
|
||||
FileCheck().check_not(" call_triton").check("_convolution(").run(
|
||||
code[0]
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -3635,8 +3635,10 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
|
|||
output = "\n".join(record.getMessage() for record in log.records)
|
||||
# correct grid example values updated per block size
|
||||
FileCheck().check("Compile-time auto-tuning block:").check(
|
||||
"grid_wrapper_for_op_zeros_0"
|
||||
).check_next("return (256").check_next("return (64").run(output)
|
||||
"PrecomputedGrid"
|
||||
).check("(31 + _launcher_s0) // 32").check("(127 + _launcher_s0) // 128").run(
|
||||
output
|
||||
)
|
||||
|
||||
# Triton 3.2.0 adds the required flags to the Autotuner object for this test
|
||||
# PR: https://github.com/triton-lang/triton/pull/5092
|
||||
|
|
|
|||
|
|
@ -639,7 +639,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
extra_args: Iterable[Any],
|
||||
module_path: str, # the path of the module defining the triton kernel
|
||||
module_cache_key: str,
|
||||
grid: list[int],
|
||||
num_stages: int,
|
||||
num_warps: int,
|
||||
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
|
||||
|
|
@ -648,7 +647,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
|
||||
self.module_path = module_path
|
||||
self.module_cache_key = module_cache_key
|
||||
self.grid = grid
|
||||
self.num_stages = num_stages
|
||||
self.num_warps = num_warps
|
||||
self.matrix_instr_nonkdim = matrix_instr_nonkdim
|
||||
|
|
@ -700,16 +698,15 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
)
|
||||
|
||||
# Handle zero initialization if needed
|
||||
if workspace_arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL:
|
||||
if workspace_arg.zero_mode != WorkspaceZeroMode.UNINITIALIZED:
|
||||
workspace_tensor.zero_()
|
||||
|
||||
# Run the kernel with workspace
|
||||
run_method(
|
||||
*input_tensors,
|
||||
output_tensor,
|
||||
*extra_args,
|
||||
workspace_tensor,
|
||||
grid=self.grid,
|
||||
*extra_args,
|
||||
**warmup_arg,
|
||||
stream=stream,
|
||||
benchmark_run=True,
|
||||
|
|
@ -725,7 +722,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
*input_tensors,
|
||||
output_tensor,
|
||||
*extra_args,
|
||||
grid=self.grid,
|
||||
**warmup_arg,
|
||||
stream=stream,
|
||||
)
|
||||
|
|
@ -735,7 +731,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
*input_tensors,
|
||||
output_tensor,
|
||||
*extra_args,
|
||||
grid=self.grid,
|
||||
**warmup_arg,
|
||||
stream=stream,
|
||||
benchmark_run=True,
|
||||
|
|
|
|||
|
|
@ -1358,7 +1358,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
|
|||
|
||||
@clear_on_fresh_inductor_cache
|
||||
class CudaKernelParamCache:
|
||||
cache: dict[str, dict[str, str]] = {}
|
||||
cache: dict[str, dict[str, Any]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1376,7 +1376,7 @@ class CudaKernelParamCache:
|
|||
cls.cache[key] = params
|
||||
|
||||
@classmethod
|
||||
def get(cls, key: str) -> Optional[dict[str, str]]:
|
||||
def get(cls, key: str) -> Optional[dict[str, Any]]:
|
||||
return cls.cache.get(key, None)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1550,7 +1550,7 @@ class KernelArgs:
|
|||
|
||||
def python_argdefs(
|
||||
self,
|
||||
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[torch.dtype]]:
|
||||
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
|
||||
arg_defs: list[ArgName] = []
|
||||
call_args: list[str] = []
|
||||
arg_types: list[torch.dtype] = []
|
||||
|
|
|
|||
|
|
@ -5204,7 +5204,7 @@ class KernelGroup:
|
|||
def call_kernel(self, wrapper, kernel_name):
|
||||
_, call_args, arg_types = self.args.cpp_argdefs()
|
||||
wrapper.generate_kernel_call(
|
||||
kernel_name, call_args, gpu=False, triton=False, arg_types=arg_types
|
||||
kernel_name, call_args, triton=False, arg_types=arg_types
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -118,9 +118,7 @@ class CppTemplateKernel(CppKernel):
|
|||
def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
|
||||
wrapper = V.graph.wrapper_code
|
||||
_, call_args, arg_types = self.args.cpp_argdefs()
|
||||
wrapper.generate_kernel_call(
|
||||
name, call_args, triton=False, gpu=False, arg_types=arg_types
|
||||
)
|
||||
wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types)
|
||||
|
||||
def dtype(self, node: ir.Buffer) -> str:
|
||||
return DTYPE_TO_CPP[node.get_dtype()]
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ from ..virtualized import V
|
|||
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
||||
from .common import get_device_op_overrides, IndentedBuffer, Kernel
|
||||
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
|
||||
from .triton_utils import should_unwrap_unspec_arg
|
||||
from .wrapper import (
|
||||
EnterSubgraphLine,
|
||||
ExitSubgraphLine,
|
||||
|
|
@ -89,27 +88,20 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
self,
|
||||
kernel_name: str,
|
||||
call_args,
|
||||
grid=None,
|
||||
device_index=None,
|
||||
gpu=False,
|
||||
triton=False,
|
||||
*,
|
||||
device=None,
|
||||
triton=True,
|
||||
arg_types=None,
|
||||
raw_args=None,
|
||||
grid_fn: str = "grid",
|
||||
triton_meta=None,
|
||||
autotune_configs=None,
|
||||
grid_extra_kwargs="",
|
||||
):
|
||||
"""
|
||||
Generates kernel call code.
|
||||
|
||||
gpu: Defines whether the backend is GPU. Otherwise the backend is CPU.
|
||||
|
||||
triton: Defines whether the GPU backend uses Triton for codegen.
|
||||
Otherwise it uses the CUDA language for codegen.
|
||||
Only valid when cuda == True.
|
||||
"""
|
||||
assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
|
||||
assert arg_types is not None and len(call_args) == len(arg_types), (
|
||||
"Mismatch call_args and arg_types in generate_kernel_call"
|
||||
)
|
||||
|
|
@ -853,27 +845,32 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
|
||||
def generate(self, is_inference):
|
||||
with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True):
|
||||
if V.graph.aot_mode and not V.graph.is_const_graph:
|
||||
self.codegen_model_kernels()
|
||||
self.codegen_model_constructor()
|
||||
self.codegen_const_run_driver()
|
||||
self.write_wrapper_decl()
|
||||
return super().generate(is_inference)
|
||||
|
||||
def finalize_prefix(self):
|
||||
cached_dtypes_buffer = IndentedBuffer()
|
||||
prior = self.prefix
|
||||
self.prefix = aot_mode_decls = IndentedBuffer()
|
||||
if V.graph.aot_mode and not V.graph.is_const_graph:
|
||||
aot_mode_decls.writeline("namespace torch::aot_inductor {")
|
||||
self.codegen_model_kernels()
|
||||
self.codegen_model_constructor()
|
||||
self.codegen_const_run_driver()
|
||||
aot_mode_decls.writeline("} // namespace torch::aot_inductor")
|
||||
aot_mode_decls.writeline("using namespace torch::aot_inductor;")
|
||||
|
||||
self.prefix = cache_decls = IndentedBuffer()
|
||||
for dtype in self.used_cached_dtypes:
|
||||
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
|
||||
cache_decls.writeline(f"CACHE_TORCH_DTYPE({dtype});")
|
||||
for device in self.used_cached_devices:
|
||||
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
|
||||
cache_decls.writeline(f"CACHE_TORCH_DEVICE({device});")
|
||||
for layout in self.used_cached_layouts:
|
||||
cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
|
||||
cache_decls.writeline(f"CACHE_TORCH_LAYOUT({layout});")
|
||||
for memory_format in self.used_cached_memory_formats:
|
||||
cached_dtypes_buffer.writeline(
|
||||
f"CACHE_TORCH_MEMORY_FORMAT({memory_format});"
|
||||
)
|
||||
cached_dtypes_buffer.splice(self.prefix)
|
||||
self.prefix = cached_dtypes_buffer
|
||||
cache_decls.writeline(f"CACHE_TORCH_MEMORY_FORMAT({memory_format});")
|
||||
|
||||
self.prefix.splice(aot_mode_decls)
|
||||
self.prefix.splice(prior)
|
||||
|
||||
def define_kernel(
|
||||
self,
|
||||
|
|
@ -1264,24 +1261,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
# it suffices as a type hint for the purposes of producing the correct code for this type.
|
||||
return SymbolicCallArg(expr, tree.numel)
|
||||
|
||||
def prepare_triton_kernel_call(self, device_index, call_args):
|
||||
def wrap_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
||||
return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg
|
||||
elif isinstance(arg, (int, float, bool, SymbolicCallArg)):
|
||||
return str(arg)
|
||||
else:
|
||||
return cexpr(V.graph.sizevars.simplify(arg))
|
||||
|
||||
call_args = [wrap_arg(arg) for arg in call_args]
|
||||
|
||||
if device_index is None:
|
||||
current_device = V.graph.get_current_device_or_throw()
|
||||
device_index = current_device.index
|
||||
|
||||
return device_index, call_args
|
||||
|
||||
def codegen_dynamic_scalar(self, node):
|
||||
(data,) = (t.codegen_reference() for t in node.inputs)
|
||||
self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")
|
||||
|
|
|
|||
|
|
@ -106,27 +106,21 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
self,
|
||||
kernel_name: str,
|
||||
call_args,
|
||||
grid=None,
|
||||
device_index=None,
|
||||
gpu=False,
|
||||
triton=False,
|
||||
*,
|
||||
device=None,
|
||||
triton=True,
|
||||
arg_types=None,
|
||||
raw_args=None,
|
||||
grid_fn: str = "grid",
|
||||
triton_meta=None,
|
||||
autotune_configs=None,
|
||||
grid_extra_kwargs="",
|
||||
):
|
||||
"""
|
||||
Generates kernel call code.
|
||||
|
||||
gpu: Defines whether the backend is GPU. Otherwise the backend is CPU.
|
||||
|
||||
triton: Defines whether the GPU backend uses Triton for codegen.
|
||||
Otherwise it uses the CUDA language for codegen.
|
||||
Only valid when cuda == True.
|
||||
"""
|
||||
assert not gpu, (
|
||||
assert not triton, (
|
||||
"CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
|
||||
)
|
||||
assert arg_types is not None and len(call_args) == len(arg_types), (
|
||||
|
|
|
|||
|
|
@ -1,188 +1,194 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import os
|
||||
from itertools import chain, count, zip_longest
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
from itertools import count, zip_longest
|
||||
from typing import Any, Optional, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import sympy
|
||||
|
||||
from torch import dtype as torch_dtype
|
||||
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
|
||||
from torch._inductor.runtime.runtime_utils import dynamo_timed
|
||||
from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
|
||||
|
||||
from .. import config
|
||||
from ..codecache import CudaKernelParamCache
|
||||
from ..ir import GraphPartitionSignature, IRNode, TensorBox
|
||||
from ..utils import (
|
||||
cache_on_self,
|
||||
DeferredLineBase,
|
||||
get_gpu_type,
|
||||
GPU_ALIGN_BYTES,
|
||||
triton_version_uses_attrs_dict,
|
||||
)
|
||||
from ..ir import GraphPartitionSignature, TensorBox
|
||||
from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer
|
||||
from ..virtualized import V
|
||||
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
||||
from .common import get_device_op_overrides
|
||||
from .cpp_utils import cexpr
|
||||
from .cpp_wrapper_cpu import CppWrapperCpu
|
||||
from .multi_kernel import MultiKernelCall
|
||||
from .triton_utils import should_unwrap_unspec_arg
|
||||
from .wrapper import PythonWrapperCodegen, SymbolicCallArg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Hashable
|
||||
|
||||
from ..graph import GraphLowering
|
||||
_cpp_string_literal_escapes = {
|
||||
"\\": "\\\\",
|
||||
'"': '\\"',
|
||||
"\n": "\\n",
|
||||
"\t": "\\t",
|
||||
"\r": "\\r",
|
||||
}
|
||||
_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]')
|
||||
|
||||
|
||||
class DeferredGpuKernelLine(DeferredLineBase):
|
||||
def cpp_string_literal(s: str) -> str:
|
||||
escaped = _cpp_string_literal_pattern.sub(
|
||||
lambda match: _cpp_string_literal_escapes[match.group(0)], s
|
||||
)
|
||||
return f'"{escaped}"'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeferredTritonCallWrapper:
|
||||
"""
|
||||
When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
|
||||
to be tuned and stored as cubin files, so use a deferred line to backfill those information
|
||||
to be tuned and stored as cubin files, so use a deferred generating the final wrapper around
|
||||
the triton kernel until right before the prefix is written.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
line_template: str,
|
||||
keys: tuple[str, ...],
|
||||
additional_files: list[str],
|
||||
):
|
||||
super().__init__(line_template)
|
||||
assert not isinstance(line_template, DeferredLineBase)
|
||||
self.additional_files = additional_files
|
||||
self.kernel_name = kernel_name
|
||||
self.line_template = line_template
|
||||
self.keys = keys
|
||||
wrapper_name: str
|
||||
kernel_name: str
|
||||
arg_types: list[Any]
|
||||
|
||||
def __call__(self):
|
||||
def generate(self, wrapper: CppWrapperGpu):
|
||||
prefix = wrapper.prefix
|
||||
if self.kernel_name.startswith("multi_kernel_"):
|
||||
# MultiKernel will select one kernel after running the autotune block
|
||||
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
params = CudaKernelParamCache.get(self.kernel_name)
|
||||
assert params is not None, (
|
||||
f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
)
|
||||
for key in self.keys:
|
||||
assert key in params, (
|
||||
f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
|
||||
assert params, f"CudaKernelParamCache not populated for {self.kernel_name}"
|
||||
def_args = params["def_args"]
|
||||
arg_types = self.arg_types
|
||||
inductor_meta = params["inductor_meta"]
|
||||
|
||||
if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types):
|
||||
# extra_launcher_args should already be in def_args
|
||||
assert len(def_args) == len(arg_types) - len(
|
||||
inductor_meta["extra_launcher_args"]
|
||||
)
|
||||
arg_types = arg_types + [SymbolicCallArg] * len(
|
||||
inductor_meta["extra_launcher_args"]
|
||||
)
|
||||
if key == get_cpp_wrapper_cubin_path_name():
|
||||
assert os.path.exists(params[key]), f"{params[key]} does not exist"
|
||||
self.additional_files.append(params[key])
|
||||
|
||||
return self.line_template % tuple(params[key] for key in self.keys)
|
||||
if not V.graph.aot_mode:
|
||||
prefix.writeline(
|
||||
maybe_hipify_code_wrapper(
|
||||
f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;"
|
||||
)
|
||||
)
|
||||
kernel_var_name = self.kernel_name
|
||||
else:
|
||||
kernel_var_name = f"kernels_.{self.kernel_name}"
|
||||
|
||||
def _new_line(self, line):
|
||||
return DeferredGpuKernelLine(
|
||||
self.kernel_name, line, self.keys, self.additional_files
|
||||
# tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types
|
||||
template_types = [
|
||||
f"typename {name}_type_"
|
||||
for name, arg_type in zip(def_args, arg_types)
|
||||
if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg))
|
||||
]
|
||||
if V.graph.aot_mode:
|
||||
template_types.append("typename kernels_type_")
|
||||
if template_types:
|
||||
prefix.writeline(f"template <{', '.join(template_types)}>")
|
||||
prefix.writeline(f"static inline void {self.wrapper_name}(")
|
||||
with prefix.indent():
|
||||
assert len(def_args) == len(arg_types), (def_args, arg_types)
|
||||
for name, arg_type in zip(def_args, arg_types):
|
||||
if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)):
|
||||
prefix.writeline(f"const {name}_type_& {name},")
|
||||
elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)):
|
||||
prefix.writeline(f"int64_t {name},")
|
||||
elif arg_type is float:
|
||||
prefix.writeline(f"float {name},")
|
||||
elif arg_type is bool:
|
||||
prefix.writeline(f"bool {name},")
|
||||
else:
|
||||
raise ValueError(f"Unexpected arg type {arg_type}")
|
||||
prefix.writeline(f"{wrapper.device_codegen.cpp_stream_type()} stream_,")
|
||||
if V.graph.aot_mode:
|
||||
prefix.writeline("kernels_type_& kernels_,")
|
||||
prefix.writeline(
|
||||
"const std::optional<std::string>& cubin_dir_ = std::nullopt"
|
||||
)
|
||||
prefix.writeline("){")
|
||||
with prefix.indent():
|
||||
self.generate_grid(prefix, inductor_meta, params)
|
||||
self.generate_load_kernel(prefix, kernel_var_name, params)
|
||||
self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params)
|
||||
prefix.writeline("}")
|
||||
# Ensure the cubin file is included in the package
|
||||
V.graph.wrapper_code.additional_files.append(
|
||||
params[get_cpp_wrapper_cubin_path_name()]
|
||||
)
|
||||
|
||||
|
||||
class DeferredGpuDefaultGrid:
|
||||
"""
|
||||
A container for the default grid, which may be used by DeferredGpuGridLine
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def generate_grid(
|
||||
self,
|
||||
kernel_name: str,
|
||||
grid,
|
||||
grid_callable: Optional[Callable[..., Any]] = None,
|
||||
**grid_extra_kwargs,
|
||||
prefix: IndentedBuffer,
|
||||
inductor_meta: dict[str, Any],
|
||||
params: dict[str, Any],
|
||||
):
|
||||
self.kernel_name = kernel_name
|
||||
self.grid = grid
|
||||
self.grid_callable = grid_callable
|
||||
self.grid_extra_kwargs = grid_extra_kwargs
|
||||
from ..runtime.triton_heuristics import GridExpr
|
||||
|
||||
def __iter__(self):
|
||||
# DeferredGpuDefaultGrid can be passed to the base class, PythonWrapperCodegen,
|
||||
# to generate the autotune code block, and thus we need this iterator
|
||||
return iter(self.grid)
|
||||
|
||||
def _process_grid(self, grid: Union[list[Any], tuple[Any, ...]]):
|
||||
if isinstance(grid, (list, tuple)):
|
||||
return [self._process_grid(e) for e in grid]
|
||||
else:
|
||||
return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid
|
||||
|
||||
def __call__(self):
|
||||
if self.kernel_name.startswith("multi_kernel_"):
|
||||
# MultiKernel will select one kernel after running the autotune block
|
||||
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
|
||||
grid = self.grid
|
||||
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
|
||||
grid = self._process_grid(grid)
|
||||
assert self.grid_callable is not None, "grid_callable can't be None"
|
||||
if not self.grid_extra_kwargs:
|
||||
grid_fn = self.grid_callable(*grid)
|
||||
else:
|
||||
grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
|
||||
|
||||
params = CudaKernelParamCache.get(self.kernel_name)
|
||||
assert params is not None, (
|
||||
f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp")
|
||||
for line in grid.prefix:
|
||||
prefix.writeline(line)
|
||||
prefix.splice(
|
||||
f"""\
|
||||
uint32_t grid_0 = {grid.x_grid};
|
||||
uint32_t grid_1 = {grid.y_grid};
|
||||
uint32_t grid_2 = {grid.z_grid};
|
||||
"""
|
||||
)
|
||||
return grid_fn(params["meta"])
|
||||
prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;")
|
||||
|
||||
def generate_load_kernel(self, prefix, kernel_var_name, params):
|
||||
prefix.writeline(f"if ({kernel_var_name} == nullptr) {{")
|
||||
with prefix.indent():
|
||||
load_kernel_args = [
|
||||
cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]),
|
||||
cpp_string_literal(params["mangled_name"]),
|
||||
str(params["shared_mem"]),
|
||||
"cubin_dir_",
|
||||
]
|
||||
prefix.writeline(
|
||||
f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); "
|
||||
)
|
||||
prefix.writeline("}")
|
||||
|
||||
class DeferredGpuGridLine(DeferredLineBase):
|
||||
"""
|
||||
When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels
|
||||
to be tuned and stored as cubin files, so use a deferred line to backfill those information
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
grid_var: str,
|
||||
grid,
|
||||
autotune_configs,
|
||||
):
|
||||
super().__init__("")
|
||||
self.kernel_name = kernel_name
|
||||
self.grid_var = grid_var
|
||||
self.grid = grid
|
||||
self.autotune_configs = autotune_configs
|
||||
|
||||
def __call__(self):
|
||||
if self.kernel_name.startswith("multi_kernel_"):
|
||||
# MultiKernel will select one kernel after running the autotune block
|
||||
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
|
||||
params = CudaKernelParamCache.get(self.kernel_name)
|
||||
assert params is not None, (
|
||||
f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params):
|
||||
triton_meta = params["triton_meta"]
|
||||
assert len(self.arg_types) == len(params["def_args"]), (
|
||||
self.arg_types,
|
||||
params["def_args"],
|
||||
)
|
||||
|
||||
if self.autotune_configs is not None:
|
||||
# This indicates the Triton kernel is a user-defined one.
|
||||
grid = None
|
||||
if len(self.grid) == 1:
|
||||
grid = self.grid[0]
|
||||
else:
|
||||
for i, c in enumerate(self.autotune_configs):
|
||||
if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
|
||||
grid = self.grid[i]
|
||||
break
|
||||
assert grid is not None
|
||||
elif isinstance(self.grid, DeferredGpuDefaultGrid):
|
||||
grid = self.grid()
|
||||
else:
|
||||
grid = self.grid
|
||||
|
||||
assert len(grid) != 0, "Grid can't be empty"
|
||||
grid_args_str = ", ".join(
|
||||
[cexpr(V.graph.sizevars.simplify(item)) for item in grid]
|
||||
)
|
||||
return f" Grid {self.grid_var} = Grid({grid_args_str});"
|
||||
|
||||
def _new_line(self, line):
|
||||
return DeferredGpuGridLine(
|
||||
self.kernel_name, self.grid_var, self.grid, self.autotune_configs
|
||||
arg_type_loookup = dict(zip(params["def_args"], self.arg_types))
|
||||
# difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants
|
||||
call_args = [
|
||||
name for name in params["call_args"] if name not in triton_meta["constants"]
|
||||
]
|
||||
arg_types = [arg_type_loookup[name] for name in call_args]
|
||||
arg_signatures = [triton_meta["signature"][name] for name in call_args]
|
||||
call_args_str = wrapper.generate_args_decl(
|
||||
prefix, call_args, arg_types, arg_signatures
|
||||
)
|
||||
prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};")
|
||||
launch_kernel_args = [
|
||||
kernel_var_name,
|
||||
"grid_0",
|
||||
"grid_1",
|
||||
"grid_2",
|
||||
str(params["num_warps"]),
|
||||
str(params["shared_mem"]),
|
||||
"kernel_args_",
|
||||
"stream_",
|
||||
]
|
||||
prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});")
|
||||
|
||||
|
||||
class CppWrapperGpu(CppWrapperCpu):
|
||||
|
|
@ -195,7 +201,7 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
self.device_codegen = get_device_op_overrides(self.device)
|
||||
super().__init__()
|
||||
self.grid_id = count()
|
||||
self._load_kernel_cache: dict[Hashable, str] = {}
|
||||
self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {}
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
|
|
@ -293,77 +299,25 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
|
||||
def generate(self, is_inference):
|
||||
with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
|
||||
self.prefix.writeline("\n")
|
||||
if not V.graph.aot_mode:
|
||||
for kernel in chain(
|
||||
sorted(self.src_to_kernel.values()),
|
||||
sorted(
|
||||
[entry[0] for entry in self.user_defined_kernel_cache.values()]
|
||||
),
|
||||
):
|
||||
self.prefix.writeline(
|
||||
maybe_hipify_code_wrapper(
|
||||
f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;"
|
||||
)
|
||||
)
|
||||
self.prefix.writeline("\n")
|
||||
return super().generate(is_inference)
|
||||
|
||||
def generate_user_defined_triton_kernel(
|
||||
self,
|
||||
kernel_name: str,
|
||||
raw_args: list[Any],
|
||||
grid: list[Any],
|
||||
configs,
|
||||
triton_meta,
|
||||
constexprs,
|
||||
):
|
||||
if (
|
||||
config.triton.autotune_at_compile_time
|
||||
and kernel_name not in self.kernel_autotune_names
|
||||
):
|
||||
# Call PythonWrapperCodegen to create the autotune code block
|
||||
PythonWrapperCodegen.generate_user_defined_triton_kernel(
|
||||
self,
|
||||
kernel_name,
|
||||
raw_args,
|
||||
grid,
|
||||
configs,
|
||||
triton_meta,
|
||||
constexprs,
|
||||
)
|
||||
|
||||
if not triton_version_uses_attrs_dict():
|
||||
# in C++ wrapper, we don't pass constexpr args, as they don't
|
||||
# get added as parameters to the PTX code compiled from the
|
||||
# user-defined Triton kernel (only non-constexpr args do)
|
||||
raw_args = [
|
||||
raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
|
||||
]
|
||||
args = [self.val_to_arg_str(v) for v in raw_args]
|
||||
arg_types = [
|
||||
arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
|
||||
for arg in raw_args
|
||||
]
|
||||
|
||||
# Call self.generate_kernel_call to generate the real kernel call in cpp
|
||||
self.generate_kernel_call(
|
||||
kernel_name,
|
||||
args,
|
||||
arg_types=arg_types,
|
||||
raw_args=raw_args,
|
||||
grid=grid,
|
||||
gpu=True,
|
||||
triton=True,
|
||||
triton_meta=triton_meta,
|
||||
autotune_configs=configs,
|
||||
)
|
||||
def finalize_prefix(self):
|
||||
"""Define the triton kernels now that autotuning is finished"""
|
||||
old_prefix = self.prefix # new content should go at start of prefix
|
||||
self.prefix = IndentedBuffer()
|
||||
super().finalize_prefix()
|
||||
for kernel in self._triton_call_wrappers.values():
|
||||
self.prefix.writeline("\n")
|
||||
kernel.generate(self)
|
||||
self.prefix.writeline("\n")
|
||||
self.prefix.splice(old_prefix)
|
||||
|
||||
def generate_tma_descriptor(self, desc):
|
||||
self.write_tma_descriptor_helpers_once()
|
||||
|
||||
# generate data pointer for the source tensor
|
||||
source = self.generate_args_decl(
|
||||
code=self,
|
||||
call_args=[self.val_to_arg_str(desc.tensor)],
|
||||
arg_types=[desc.tensor.get_dtype()],
|
||||
arg_signatures=[None],
|
||||
|
|
@ -384,36 +338,9 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}"
|
||||
self.writeline(f"{fn}({args});")
|
||||
|
||||
def generate_load_kernel_once(
|
||||
self,
|
||||
kernel_name: str,
|
||||
graph: "GraphLowering", # for per-graph caching
|
||||
def generate_args_decl(
|
||||
self, code: Union[IndentedBuffer, Self], call_args, arg_types, arg_signatures
|
||||
):
|
||||
cache_key = (kernel_name, graph)
|
||||
if cache_key in self._load_kernel_cache:
|
||||
return self._load_kernel_cache[cache_key]
|
||||
kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
|
||||
self._load_kernel_cache[cache_key] = kernel_var_name
|
||||
|
||||
keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem")
|
||||
self.writeline(f"if ({kernel_var_name} == nullptr) {{")
|
||||
deferred_gpu_kernel_line = DeferredGpuKernelLine(
|
||||
kernel_name,
|
||||
(
|
||||
" "
|
||||
+ kernel_var_name
|
||||
+ ' = loadKernel("%s", "%s", %s, this->cubin_dir_);'
|
||||
if V.graph.aot_mode
|
||||
else " " + kernel_var_name + ' = loadKernel("%s", "%s", %s);'
|
||||
),
|
||||
keys,
|
||||
self.additional_files,
|
||||
)
|
||||
self.writeline(deferred_gpu_kernel_line)
|
||||
self.writeline("}")
|
||||
return kernel_var_name
|
||||
|
||||
def generate_args_decl(self, call_args, arg_types, arg_signatures):
|
||||
new_args: list[str] = []
|
||||
|
||||
# Add more cases for other types as needed
|
||||
|
|
@ -427,26 +354,24 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
var_name = f"var_{next(self.arg_var_id)}"
|
||||
# ignore nvTmaDesc, as host-side TMA descriptors need
|
||||
# to be passed to the compiled Triton kernel by value
|
||||
if isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc":
|
||||
if arg.endswith(".item()"):
|
||||
# Need to declare a scalar in this case
|
||||
arg = arg[:-7]
|
||||
self.codegen_tensor_item(
|
||||
arg_type,
|
||||
arg,
|
||||
var_name,
|
||||
)
|
||||
else:
|
||||
device_ptr_type = self.device_codegen.cpp_device_ptr()
|
||||
self.writeline(
|
||||
maybe_hipify_code_wrapper(
|
||||
f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
|
||||
)
|
||||
if isinstance(arg_type, UnwrapUnspecArg) and arg_signature != "nvTmaDesc":
|
||||
self.codegen_tensor_item(
|
||||
arg_type.dtype,
|
||||
arg,
|
||||
var_name,
|
||||
indented_buffer=code,
|
||||
)
|
||||
elif isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc":
|
||||
device_ptr_type = self.device_codegen.cpp_device_ptr()
|
||||
code.writeline(
|
||||
maybe_hipify_code_wrapper(
|
||||
f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());"
|
||||
)
|
||||
)
|
||||
elif arg_type in (sympy.Integer, int):
|
||||
self.writeline(f"int {var_name} = {cexpr(arg)};")
|
||||
code.writeline(f"int {var_name} = {cexpr(arg)};")
|
||||
elif arg_type in (sympy.Float, float):
|
||||
self.writeline(f"float {var_name} = {cexpr(arg)};")
|
||||
code.writeline(f"float {var_name} = {cexpr(arg)};")
|
||||
# For symbolic call arguments, examine the arg signatures from triton meta
|
||||
# to explicitly cast to the right type
|
||||
# Reason: `auto` can infer unexpected type against kernel input signature.
|
||||
|
|
@ -455,11 +380,11 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
and arg_signature is not None
|
||||
and arg_signature in signature2dtype.keys()
|
||||
):
|
||||
self.writeline(
|
||||
code.writeline(
|
||||
f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};"
|
||||
)
|
||||
else:
|
||||
self.writeline(f"auto {var_name} = {cexpr(arg)};")
|
||||
code.writeline(f"auto {var_name} = {cexpr(arg)};")
|
||||
new_args.append(f"&{var_name}")
|
||||
|
||||
for arg, arg_type, arg_signature in zip_longest(
|
||||
|
|
@ -478,61 +403,34 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
|
||||
return ", ".join(new_args)
|
||||
|
||||
def generate_default_grid(
|
||||
self,
|
||||
kernel_name: str,
|
||||
grid_args: list[Any],
|
||||
gpu: bool = True,
|
||||
grid_callable: Optional[Callable[..., Any]] = default_grid_fn,
|
||||
**grid_extra_kwargs,
|
||||
):
|
||||
"""
|
||||
Generate grid configs for launching a CUDA kernel using the grid
|
||||
function from triton_heuristics. Because its computation needs
|
||||
to read kernel config after autotune, it is done in a deferred way
|
||||
using DeferredGpuDefaultGrid.
|
||||
"""
|
||||
assert gpu, "CppWrapperGpu.generate_default_grid does not support non-GPU"
|
||||
return DeferredGpuDefaultGrid(
|
||||
kernel_name, grid_args, grid_callable, **grid_extra_kwargs
|
||||
)
|
||||
|
||||
def generate_kernel_call(
|
||||
self,
|
||||
kernel_name: str,
|
||||
call_args,
|
||||
grid=None,
|
||||
device_index=None,
|
||||
gpu=True,
|
||||
*,
|
||||
device=None,
|
||||
triton=True,
|
||||
arg_types=None,
|
||||
raw_args=None,
|
||||
grid_fn: str = "grid",
|
||||
triton_meta=None,
|
||||
autotune_configs=None,
|
||||
grid_extra_kwargs="",
|
||||
):
|
||||
"""
|
||||
Override the default value of argument 'gpu' to True here.
|
||||
generate_kernel_call can still be called with gpu=False because of
|
||||
a mix of cpu kernels and gpu kernels.
|
||||
"""
|
||||
if not gpu:
|
||||
device = device or V.graph.get_current_device_or_throw()
|
||||
if device.type == "cpu":
|
||||
# Even in CppWrapperGpu, we may see cpp kernels
|
||||
return CppWrapperCpu.generate_kernel_call(
|
||||
self,
|
||||
kernel_name,
|
||||
call_args,
|
||||
grid,
|
||||
device_index,
|
||||
gpu,
|
||||
triton,
|
||||
arg_types,
|
||||
raw_args,
|
||||
grid_fn,
|
||||
triton_meta,
|
||||
autotune_configs,
|
||||
grid_extra_kwargs,
|
||||
device=device,
|
||||
triton=triton,
|
||||
arg_types=arg_types,
|
||||
raw_args=raw_args,
|
||||
triton_meta=triton_meta,
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
@ -545,110 +443,38 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
self,
|
||||
kernel_name,
|
||||
call_args,
|
||||
grid,
|
||||
device_index,
|
||||
gpu,
|
||||
triton,
|
||||
arg_types,
|
||||
raw_args,
|
||||
grid_fn,
|
||||
triton_meta,
|
||||
autotune_configs,
|
||||
grid_extra_kwargs,
|
||||
device=device,
|
||||
triton=triton,
|
||||
arg_types=arg_types,
|
||||
raw_args=raw_args,
|
||||
triton_meta=triton_meta,
|
||||
)
|
||||
|
||||
if device_index is None:
|
||||
current_device = V.graph.get_current_device_or_throw()
|
||||
device_index = current_device.index
|
||||
|
||||
stream = (
|
||||
"stream"
|
||||
if V.graph.aot_mode
|
||||
else self.write_get_raw_stream(device_index, V.graph)
|
||||
else self.write_get_raw_stream(device.index, V.graph)
|
||||
)
|
||||
|
||||
if triton:
|
||||
device_index, call_args = self.prepare_triton_kernel_call(
|
||||
device_index, call_args
|
||||
call_args, arg_types = self.prepare_triton_wrapper_args(
|
||||
call_args, arg_types
|
||||
)
|
||||
kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
|
||||
arg_signatures = []
|
||||
if (
|
||||
triton_meta is not None
|
||||
and triton_meta.get("configs")
|
||||
and triton_meta.get("signature")
|
||||
):
|
||||
if triton_version_uses_attrs_dict():
|
||||
signatures = triton_meta["signature"]
|
||||
arg_signatures = [
|
||||
val for val in signatures.values() if val != "constexpr"
|
||||
]
|
||||
call_args = [
|
||||
call_arg
|
||||
for call_arg, arg_name in zip(call_args, signatures)
|
||||
if signatures[arg_name] != "constexpr"
|
||||
]
|
||||
arg_types = [
|
||||
arg_type
|
||||
for arg_type, arg_name in zip(arg_types, signatures)
|
||||
if signatures[arg_name] != "constexpr"
|
||||
]
|
||||
assert len(call_args) == len(arg_signatures), (
|
||||
f"len of the following lists do not match: {call_args=} {arg_signatures=}"
|
||||
)
|
||||
else:
|
||||
# args with value 1 are added into equal_to_1 and constants
|
||||
# in triton_meta (in the Python codegen) which makes them
|
||||
# inlined in the PTX and compiled CUBIN
|
||||
equal_to_1 = triton_meta["configs"][0].equal_to_1
|
||||
call_args = [
|
||||
arg for i, arg in enumerate(call_args) if i not in equal_to_1
|
||||
]
|
||||
arg_types = [
|
||||
t for i, t in enumerate(arg_types) if i not in equal_to_1
|
||||
]
|
||||
# extract the arg signatures from triton_meta
|
||||
arg_signatures = triton_meta["signature"].values()
|
||||
arg_signatures = [
|
||||
v for i, v in enumerate(arg_signatures) if i not in equal_to_1
|
||||
]
|
||||
call_args_str = self.generate_args_decl(
|
||||
call_args, arg_types, arg_signatures
|
||||
)
|
||||
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
|
||||
self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
|
||||
|
||||
grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
|
||||
self.writeline(
|
||||
DeferredGpuGridLine(kernel_name, grid_var, grid, autotune_configs)
|
||||
)
|
||||
|
||||
kernel_var_name = (
|
||||
f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
|
||||
)
|
||||
# add debug printer code for all triton kernel related calls
|
||||
wrapper_name = f"call_{kernel_name}"
|
||||
if wrapper_name not in self._triton_call_wrappers:
|
||||
self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper(
|
||||
wrapper_name, kernel_name, arg_types
|
||||
)
|
||||
call_args.append(stream)
|
||||
if V.graph.aot_mode:
|
||||
call_args.append("kernels")
|
||||
call_args.append("this->cubin_dir_")
|
||||
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
||||
debug_printer_manager.set_printer_args(
|
||||
call_args, kernel_name, arg_types, None
|
||||
call_args[: len(arg_types)], kernel_name, arg_types, None
|
||||
)
|
||||
with debug_printer_manager:
|
||||
self.writeline(f"if ({grid_var}.is_non_zero()) {{")
|
||||
self.writeline(
|
||||
DeferredGpuKernelLine(
|
||||
kernel_name,
|
||||
r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
|
||||
kernel_var_name,
|
||||
f"{grid_var}.grid_x",
|
||||
f"{grid_var}.grid_y",
|
||||
f"{grid_var}.grid_z",
|
||||
kernel_args_var,
|
||||
stream,
|
||||
),
|
||||
("num_warps", "shared_mem"),
|
||||
self.additional_files,
|
||||
),
|
||||
)
|
||||
self.writeline("}")
|
||||
self.writeline(f"{wrapper_name}({', '.join(call_args)});")
|
||||
else:
|
||||
casted = []
|
||||
for arg_type, arg in zip(arg_types, call_args):
|
||||
|
|
@ -659,5 +485,34 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
call_args_str = ", ".join(casted)
|
||||
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});")
|
||||
|
||||
@staticmethod
|
||||
def prepare_triton_wrapper_args(
|
||||
call_args: list[Any], arg_types: list[Any]
|
||||
) -> tuple[list[Any], list[Any]]:
|
||||
assert len(call_args) == len(arg_types), (call_args, arg_types)
|
||||
new_args = []
|
||||
new_args_types = []
|
||||
for arg, arg_type in zip(call_args, arg_types):
|
||||
if isinstance(arg, str):
|
||||
if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg):
|
||||
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
||||
arg_type = UnwrapUnspecArg(dtype=arg_type)
|
||||
new_args.append(arg)
|
||||
elif isinstance(arg, bool):
|
||||
new_args.append(str(arg).lower())
|
||||
elif isinstance(arg, (int, float, SymbolicCallArg)):
|
||||
new_args.append(str(arg))
|
||||
else:
|
||||
new_args.append(cexpr(V.graph.sizevars.simplify(arg)))
|
||||
new_args_types.append(arg_type)
|
||||
return new_args, new_args_types
|
||||
|
||||
def make_zero_buffer(self, name):
|
||||
return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UnwrapUnspecArg:
|
||||
"""Marker that we need to call .item() on the tensor"""
|
||||
|
||||
dtype: torch_dtype
|
||||
|
|
|
|||
|
|
@ -341,7 +341,6 @@ class CUDATemplateKernel(CUDAKernel):
|
|||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
gpu=True,
|
||||
triton=False,
|
||||
arg_types=arg_types,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -63,22 +63,6 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
|
|||
} \\
|
||||
} while (0);
|
||||
|
||||
namespace {
|
||||
|
||||
struct Grid {
|
||||
Grid(uint32_t x, uint32_t y, uint32_t z)
|
||||
: grid_x(x), grid_y(y), grid_z(z) {}
|
||||
uint32_t grid_x;
|
||||
uint32_t grid_y;
|
||||
uint32_t grid_z;
|
||||
|
||||
bool is_non_zero() {
|
||||
return grid_x > 0 && grid_y > 0 && grid_z > 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
static inline CUfunction loadKernel(
|
||||
std::string filePath,
|
||||
const std::string &funcName,
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ class DebugPrinterManager:
|
|||
continue
|
||||
if V.graph.cpp_wrapper:
|
||||
if arg_signatures is not None and isinstance(
|
||||
arg_signatures[i], (torch_dtype)
|
||||
arg_signatures[i], torch_dtype
|
||||
):
|
||||
# infer from the arg data type (has torch.dtype) to see if it is a tensor type
|
||||
V.graph.wrapper_code.writeline(
|
||||
|
|
|
|||
|
|
@ -1633,7 +1633,7 @@ class HalideKernel(SIMDKernel):
|
|||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
gpu=False, # grid/stream is handled internally in halide
|
||||
device=current_device,
|
||||
triton=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -569,7 +569,7 @@ class MetalKernel(SIMDKernel):
|
|||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
args,
|
||||
gpu=False, # TODO: Fix me, MPS does not expose streams now
|
||||
device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now
|
||||
triton=False,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import functools
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Any
|
||||
|
||||
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
|
@ -11,11 +10,6 @@ from torch.utils._ordered_set import OrderedSet
|
|||
from .. import config
|
||||
from ..codecache import code_hash, CodeCacheFuture, get_path
|
||||
from ..runtime.benchmarking import benchmarker
|
||||
from ..runtime.triton_heuristics import (
|
||||
cooperative_reduction_grid,
|
||||
grid,
|
||||
maybe_cooperative_reduction_grid,
|
||||
)
|
||||
from ..utils import cache_on_self, IndentedBuffer
|
||||
from ..virtualized import V
|
||||
from .common import TensorArg, WorkspaceArg
|
||||
|
|
@ -196,19 +190,6 @@ class MultiKernel:
|
|||
kernel.args.workspace_args = workspace_args
|
||||
return workspace_args
|
||||
|
||||
def get_grid_fn(self):
|
||||
fns = OrderedSet(kernel._get_grid_fn() for kernel in self.kernels)
|
||||
if len(fns) == 1:
|
||||
return fns.pop()
|
||||
elif len(fns) == 2:
|
||||
assert fns == OrderedSet([cooperative_reduction_grid, grid])
|
||||
V.graph.wrapper_code.add_import_once(
|
||||
f"from {maybe_cooperative_reduction_grid.__module__} import maybe_cooperative_reduction_grid"
|
||||
)
|
||||
return maybe_cooperative_reduction_grid
|
||||
else:
|
||||
raise NotImplementedError(fns)
|
||||
|
||||
def call_kernel(self, kernel_name):
|
||||
"""
|
||||
Collect the union of arguments from all subkernels as the arguments
|
||||
|
|
@ -222,31 +203,21 @@ class MultiKernel:
|
|||
assert call_args == other_call_args, (call_args, other_call_args)
|
||||
assert arg_types == other_arg_types
|
||||
|
||||
grid: list[Any] = []
|
||||
|
||||
if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
|
||||
# for the second pass of cpp-wrapper codegen, we should call
|
||||
# the fast kernel directly
|
||||
kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
|
||||
# numels for all subkernels should be the same. Use kernels[0] here
|
||||
self.kernels[0].add_numel_to_call_args_and_grid(
|
||||
kernel_name, call_args, arg_types, grid
|
||||
)
|
||||
self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
|
||||
|
||||
for ws in self.kernels[0].args.workspace_args:
|
||||
V.graph.wrapper_code.generate_workspace_allocation(ws)
|
||||
|
||||
grid_fn = self.get_grid_fn()
|
||||
grid = V.graph.wrapper_code.generate_default_grid(
|
||||
kernel_name, grid, grid_callable=grid_fn
|
||||
)
|
||||
V.graph.wrapper_code.generate_kernel_call(
|
||||
kernel_name,
|
||||
call_args,
|
||||
grid,
|
||||
arg_types=arg_types,
|
||||
grid_fn=grid_fn.__name__,
|
||||
)
|
||||
|
||||
for ws in reversed(self.kernels[0].args.workspace_args):
|
||||
|
|
|
|||
|
|
@ -196,12 +196,9 @@ class ROCmTemplateKernel(ROCmKernel):
|
|||
kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None")
|
||||
if V.graph.cpp_wrapper:
|
||||
arg_types.append("uint8_t*")
|
||||
current_device = V.graph.get_current_device_or_throw()
|
||||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
kernel_args,
|
||||
device_index=current_device.index,
|
||||
gpu=True,
|
||||
triton=False,
|
||||
arg_types=arg_types,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1549,13 +1549,10 @@ class SIMDScheduling(BaseScheduling):
|
|||
|
||||
if config.benchmark_kernel:
|
||||
num_gb = kernel.estimate_kernel_num_bytes() / 1e9
|
||||
grid_args = V.graph.sizevars.size_hints(kernel.call_sizes)
|
||||
assert kernel.meta is not None, "meta is None"
|
||||
grid = kernel.grid_fn(*grid_args, kernel.meta)
|
||||
src_code = (
|
||||
f"{kernel.imports_for_benchmark_kernel()}\n"
|
||||
f"{src_code}\n"
|
||||
f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}"
|
||||
f"{kernel.codegen_kernel_benchmark(num_gb).getvalue()}"
|
||||
)
|
||||
|
||||
if only_gen_src_code:
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from .. import config, ir, metrics
|
|||
from ..async_compile import AsyncCompile
|
||||
from ..codecache import code_hash, get_path, PyCodeCache
|
||||
from ..ops_handler import DefaultHandler
|
||||
from ..runtime import triton_heuristics
|
||||
from ..runtime.benchmarking import benchmarker
|
||||
from ..runtime.hints import (
|
||||
AutotuneHint,
|
||||
|
|
@ -41,10 +42,6 @@ from ..runtime.hints import (
|
|||
TRITON_MAX_RSPLIT,
|
||||
)
|
||||
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
|
||||
from ..runtime.triton_heuristics import (
|
||||
cooperative_reduction_grid,
|
||||
grid as default_grid_fn,
|
||||
)
|
||||
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
|
||||
from ..utils import (
|
||||
cache_on_self,
|
||||
|
|
@ -87,7 +84,6 @@ from .simd import (
|
|||
IterationRanges,
|
||||
IterationRangesEntry,
|
||||
IterationRangesRoot,
|
||||
pexpr,
|
||||
SIMDKernel,
|
||||
SIMDScheduling,
|
||||
)
|
||||
|
|
@ -98,6 +94,7 @@ from .triton_utils import (
|
|||
should_unwrap_unspec_arg,
|
||||
signature_to_meta,
|
||||
)
|
||||
from .wrapper import SymbolicCallArg


if TYPE_CHECKING:

@ -3190,7 +3187,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.post_loop_combine.clear()
self.post_loop_store.clear()

def codegen_kernel_benchmark(self, num_gb, grid=None):
def kernel_benchmark_extra_args(self) -> list[str]:
args = []
if self.need_numel_args():
numel_args: list[sympy.Expr] = []
self.add_numel_to_call_args("", numel_args, [])
for arg in numel_args:
if isinstance(arg, int):
args.append(str(arg))
elif isinstance(arg, SymbolicCallArg):
args.append(str(V.graph.sizevars.size_hint(arg.inner_expr)))
elif isinstance(arg, sympy.Expr):
args.append(str(V.graph.sizevars.size_hint(arg)))
else:
raise ValueError(f"Unsupported numel argument type: {type(arg)}")
return args

def codegen_kernel_benchmark(self, num_gb):
result = IndentedBuffer()
_argdefs, call_args, signature, _ = self.args.python_argdefs()

@ -3231,25 +3244,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
f"Don't find the buffer or const tensor for {arg_name}"
)
var_names.append(var_name)
var_names.extend(self.kernel_benchmark_extra_args())
result.writeline(f"return {', '.join(var_names)},")

result.writelines(["\n", "\n", "def call(args):"])
if grid is None:
grid = []
extra_args = []
extra_args_str = None
for tree in self.active_range_trees():
expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
extra_args.append(expr)
if not tree.is_reduction:
grid.append(expr)
if self.need_numel_args():
extra_args_str = ", ".join(map(str, extra_args)) + ", "
else:
extra_args_str = ""
grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})"
else:
grid_arg = f"grid={grid}"
current_device = V.graph.get_current_device_or_throw()
index = current_device.index
with result.indent():

@ -3261,7 +3259,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_raw_stream({index})")
result.writeline(
f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})"
)

# benchmark all configs

@ -3273,7 +3271,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
V.graph.device_ops.set_device(index)
) # no-op to ensure context
result.writeline(
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)"
)

result.writelines(["\n", "\n", "if __name__ == '__main__':"])

@ -3301,7 +3299,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
from torch._dynamo.testing import rand_strided
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)

@ -3488,6 +3485,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

inductor_meta = {
# Triton will not accept an OrderedSet for autotune_hints
"grid_type": self._get_grid_type().__name__,
"autotune_hints": set(self.autotune_hints), # noqa: set_linter
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
"mutated_arg_names": mutated_args,

@ -3639,12 +3637,22 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if tree.prefix == "x" and self.no_x_dim:
code.writeline("XBLOCK: tl.constexpr = 1")

def _get_grid_fn(self):
def _get_grid_type(self) -> type[triton_heuristics.GridExpr]:
n = sum([int(not tree.is_reduction) for tree in self.range_trees])
if self.cooperative_reduction:
return cooperative_reduction_grid
return default_grid_fn
assert n == 1
return triton_heuristics.CooperativeReductionGrid
elif n == 1:
return triton_heuristics.Grid1D
elif n == 2:
if any(map(self.needs_yz_grid_overflow, self.range_trees)):
return triton_heuristics.Grid2DWithYZOverflow
return triton_heuristics.Grid2D
elif n == 3:
return triton_heuristics.Grid3D
raise ValueError(f"Unsupported number of dimensions: {n}")

def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
def add_numel_to_call_args(self, name, call_args, arg_types):
# TODO(jansel): if there are constants, we shouldn't bother passing them as args
for tree in self.range_trees:
if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):

@ -3655,31 +3663,21 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if not tree.is_reduction or self.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))
if tree.grid_dim is not None:
grid.append(expr)

def call_kernel(self, name: str, node: Optional[IRNode] = None):
wrapper = V.graph.wrapper_code
wrapper.write_triton_header_once()
_, call_args, _, arg_types = self.args.python_argdefs()
grid: list[Any] = []
self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
current_device = V.graph.get_current_device_or_throw()
self.add_numel_to_call_args(name, call_args, arg_types)

for ws in self.args.workspace_args:
wrapper.generate_workspace_allocation(ws)

grid_fn = self._get_grid_fn()
grid = wrapper.generate_default_grid(name, grid, grid_callable=grid_fn)
wrapper.generate_kernel_call(
name,
call_args,
grid,
current_device.index,
gpu=current_device.type != "cpu",
triton=True,
arg_types=arg_types,
grid_fn=grid_fn.__name__,
triton_meta=self.triton_meta,
)

@ -3738,12 +3736,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
key = f"tl.program_id({entry.grid_dim})"
# y_grid has a limit, so express it in terms of y and z in case of overflow.
# z grid is only exercised when max_tiles == 3 (off by default).
if (
entry.grid_dim == 1
and not entry.has_zdim
and not self.cooperative_reduction
and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
):
if self.needs_yz_grid_overflow(entry):
# For ynumel larger than max_ygrid, we need to use zdim.
# For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z).
# So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset.

@ -3753,6 +3746,14 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
return f"{pid}.to({self.index_dtype})"
return pid

def needs_yz_grid_overflow(self, entry: IterationRangesRoot) -> bool:
return (
entry.grid_dim == 1
and not entry.has_zdim
and not self.cooperative_reduction
and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
)

def max_block(self, prefix: str) -> int:
if self.fixed_config:
return self.fixed_config[f"{prefix.upper()}BLOCK"]
@ -2,7 +2,6 @@ import itertools
import logging
import textwrap
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, cast, Optional, Union

@ -14,7 +13,10 @@ from torch.utils._ordered_set import OrderedSet
from .. import config, metrics
from ..runtime.hints import DeviceProperties
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..runtime.triton_heuristics import (
RoundRobinComboKernelGrid,
SequentialComboKernelGrid,
)
from ..scheduler import BaseSchedulerNode
from ..utils import Placeholder, triton_version_uses_attrs_dict
from ..virtualized import V

@ -283,6 +285,8 @@ class ComboKernel(Kernel):
grid(...): codegen the grid size for launching the combo kernel.
"""

grid_expr = SequentialComboKernelGrid

@classmethod
def codegen_pid_range(
cls, kernel: "ComboKernel", num: int, code: IndentedBuffer

@ -321,42 +325,6 @@ class ComboKernel(Kernel):
else:
code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}")

@classmethod
def grid(
cls,
sub_kernel_numels: list[list[int]],
x_blocks_list: list[Union[str, int]],
dynamic_shape: bool,
) -> tuple[Any, ...]:
xnumel = list(x_blocks_list)
ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

if dynamic_shape:
ynumel = None if None in ynumel else ynumel
znumel = None if None in znumel else znumel
else:
# TODO: improve 1d/2d mixed cases
ynumel = (
None
if any(e is None for e in cast(list[Any], ynumel))
else max(cast(Iterable[int], ynumel))
)
znumel = (
None
if any(e is None for e in cast(list[Any], znumel))
else max(cast(Iterable[int], znumel))
)

numels = (
(xnumel,)
if not ynumel
else (ynumel, xnumel)
if not znumel
else (znumel, ynumel, xnumel)
)
return numels

class RoundRobinDispatch:
"""
The dispatcher which dispatches the subkernels in a round robin manner:

@ -368,6 +336,8 @@ class ComboKernel(Kernel):
grid(...): codegen the grid size for launching the combo kernel.
"""

grid_expr = RoundRobinComboKernelGrid

@classmethod
def codegen_pid_range(
cls, kernel: "ComboKernel", num: int, code: IndentedBuffer

@ -381,51 +351,6 @@ class ComboKernel(Kernel):
with code.indent():
code.splice(f"pid_offset = pid // {num_kernels}")

@classmethod
def grid(
cls,
sub_kernel_numels: list[list[int]],
x_blocks_list: list[Union[str, int]],
dynamic_shape: bool,
) -> tuple[Any, ...]:
xnumel = x_blocks_list
# set no_x_dim xnumels to 0
xnumel_x_dim = [max(e, 0) for e in xnumel]
ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

# TODO: support 1d/2d mixed cases
xnumel = (
None
if any(e is None for e in xnumel)
else xnumel
if dynamic_shape
else max(xnumel_x_dim) # type: ignore[type-var, arg-type]
)
ynumel = (
None
if any(e is None for e in ynumel)
else ynumel
if dynamic_shape
else max(ynumel) # type: ignore[type-var, arg-type]
)
znumel = (
None
if any(e is None for e in znumel)
else znumel
if dynamic_shape
else max(znumel) # type: ignore[type-var, arg-type]
)

numels = (
(xnumel,)
if not ynumel
else (ynumel, xnumel)
if not znumel
else (znumel, ynumel, xnumel)
)
return numels

def __init__(
self, enable_autotune: bool = False, mixed_sizes: bool = False
) -> None:

@ -638,7 +563,7 @@ class ComboKernel(Kernel):
# mixed_sizes is used for optimize_mask, so it only allows sequential dispatch
# Not mixed sizes on y dim technically is ok to use round robin as wells.
if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list):
# str in min_x_blocks_list means a dynamic shape
# str in x_numels_list means a dynamic shape
self.dispatch_class = ComboKernel.SequentialDispatch
return
# A negative x_blocks_list element means the kernel is not tunable,

@ -675,7 +600,11 @@ class ComboKernel(Kernel):
}
triton_meta["configs"] = [config_of(signature)]
mutated_args = self.get_mutated_args_sub_kernels()
dispatch = self.dispatch_class
assert dispatch is not None
inductor_meta = {
"grid_type": dispatch.grid_expr.__name__,
"combo_grid_meta": self.combo_grid_meta(),
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
"mutated_arg_names": mutated_args,
**TritonKernel.inductor_meta_common(),

@ -771,8 +700,8 @@ class ComboKernel(Kernel):
self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
return argdefs

def add_numel_to_call_args_and_grid(
self, name: str, call_args: list[Any], arg_types: list[Any], grid: list[Any]
def add_numel_to_call_args(
self, name: str, call_args: list[Any], arg_types: list[Any]
) -> None:
for num, sub_kernel in enumerate(self.sub_kernels):
for i, tree in enumerate(sub_kernel.range_trees):

@ -785,40 +714,20 @@ class ComboKernel(Kernel):
expr = V.graph.wrapper_code.generate_numel_expr(
name, tree, suffix=str(num)
)
if not tree.is_reduction:
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert grid[i][num] == numel_sign + numel_name, (
f"numel args mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr

if not tree.is_reduction or sub_kernel.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))

def add_numel_to_call_args_and_grid_benchmark(
self, extra_args: list[Any], grid: Union[list[Any], tuple[Any, ...]]
) -> None:
def kernel_benchmark_extra_args(self) -> list[str]:
extra_args = []
for num, sub_kernel in enumerate(self.sub_kernels):
for i, tree in enumerate(sub_kernel.range_trees):
numel_name = f"{tree.prefix}numel_{num}"
if numel_name not in self.dynamic_shape_args:
continue
expr = V.graph.sizevars.size_hint(tree.numel)
if not tree.is_reduction:
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert grid[i][num] == numel_sign + numel_name, (
f"grid mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(expr)
extra_args.append(str(V.graph.sizevars.size_hint(tree.numel)))
return extra_args

def codegen_kernel(self, name: Optional[str] = None) -> str:
# TODO: is it correct to use the first sub kernel's heuristics?

@ -890,12 +799,9 @@ class ComboKernel(Kernel):

return code.getvalue()

def codegen_kernel_benchmark(
self, num_gb: float, grid: Optional[list[Any]] = None
) -> IndentedBuffer:
def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:
result = IndentedBuffer()
_argdefs, call_args, signature, _ = self.args.python_argdefs()

result.writelines(["", "", "def get_args():"])
with result.indent():
name_cnt = itertools.count()

@ -934,38 +840,11 @@ class ComboKernel(Kernel):
f"Don't find the buffer or const tensor for {arg_name}"
)
var_names.append(var_name)
if self.dynamic_shape_args:
var_names.extend(self.kernel_benchmark_extra_args())
result.writeline(f"return {', '.join(var_names)},")

result.writelines(["\n", "\n", "def call(args):"])
if grid is None:
assert self.dispatch_class is not None
dynamic_shape = self.dynamic_shape_args != []
grid_tuple = self.dispatch_class.grid(
self.grids, self.x_numels_list, dynamic_shape
)
extra_args_str = ""
extra_args: list[Any] = []
if dynamic_shape:
self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple)
# convert nested list to list of str
grid_tuple = tuple(
"[" + ", ".join(pexpr(item) for item in e) + ",]"
for e in grid_tuple
)
extra_args_str = ", ".join(map(str, extra_args)) + ", "
min_blocks = None
else:
min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels)
grid_str = ", ".join(pexpr(item) for item in grid_tuple)
grid_extra_kwargs = (
f"num_kernels={len(self.sub_kernels)}, "
f"min_blocks={min_blocks}, "
f"is_sequential={self.dispatch_class is self.SequentialDispatch}"
)
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})"
else:
grid_arg = f"grid={grid}"
index = V.graph.get_current_device_or_throw().index
with result.indent():
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")

@ -976,7 +855,7 @@ class ComboKernel(Kernel):
stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_raw_stream({index})")
result.writeline(
f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})"
)

# benchmark all configs

@ -988,7 +867,7 @@ class ComboKernel(Kernel):
V.graph.device_ops.set_device(index)
) # no-op to ensure context
result.writeline(
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)"
)

result.writelines(["\n", "\n", "if __name__ == '__main__':"])

@ -1016,7 +895,6 @@ class ComboKernel(Kernel):
from torch._dynamo.testing import rand_strided
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)

@ -1053,77 +931,48 @@ class ComboKernel(Kernel):

wrapper = V.graph.wrapper_code
assert self.dispatch_class is not None
dynamic_shape = self.dynamic_shape_args != []
grid = list(
self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape)
if self.dynamic_shape_args:
self.add_numel_to_call_args(name, call_args, arg_types)

wrapper.generate_kernel_call(
name,
call_args,
triton=True,
arg_types=arg_types,
)

def combo_grid_meta(self) -> dict[str, Any]:
dynamic_shape = bool(self.dynamic_shape_args)
num_kernels = len(self.sub_kernels)
min_blocks = (
max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None
)
is_sequential = self.dispatch_class is self.SequentialDispatch
if dynamic_shape:
self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
# convert nested list to list of str
# grid = tuple("["+", ".join(pexpr(item) for item in e)+",]" for e in grid)
if not self.enable_autotune and not dynamic_shape:
launch_grid = self.grid_no_autotune(
grid, num_kernels, cast(int, min_blocks), is_sequential
)
V.graph.wrapper_code.generate_kernel_call(
name,
call_args,
grid=launch_grid,
arg_types=arg_types,
grid_fn="",
)
return
# autotuning is enabled
grid = wrapper.generate_default_grid(
name,
list(grid),
grid_callable=grid_combo_kernels,
num_kernels=num_kernels,
min_blocks=min_blocks,
is_sequential=is_sequential,
default_meta=None if self.enable_autotune else self.get_default_meta(),
)
wrapper.generate_kernel_call(
name,
call_args,
grid,
V.graph.get_current_device_or_throw().index,
gpu=True,
triton=True,
arg_types=arg_types,
grid_fn="grid_combo_kernels",
grid_extra_kwargs=(
f"num_kernels={num_kernels}, "
f"min_blocks={min_blocks}, "
f"is_sequential={is_sequential}, "
f"default_meta={None if self.enable_autotune else self.get_default_meta()}"
),
)

def grid_no_autotune(
self,
grid: Union[tuple[Any], list[Any]],
num_kernels: int,
min_blocks: int,
is_sequential: bool,
) -> list[int]:
meta = self.get_default_meta()
grid_func = grid_combo_kernels(
*grid,
num_kernels=num_kernels,
min_blocks=min_blocks,
is_sequential=is_sequential,
)
return grid_func(meta)

def get_default_meta(self) -> dict[str, int]:
if "YBLOCK" in self.block_args:
meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d}
if not self.enable_autotune:
if "YBLOCK" in self.block_args:
default_config = {
"XBLOCK": self.block_size_2d,
"YBLOCK": self.block_size_2d,
}
else:
default_config = {"XBLOCK": self.block_size_1d}
else:
meta = {"XBLOCK": self.block_size_1d}
default_config = None

meta = {
"num_kernels": num_kernels,
"min_blocks": min_blocks,
"default_config": default_config,
}

for num, sub_kernel in enumerate(self.sub_kernels):
meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim
for i, tree in enumerate(sub_kernel.range_trees):
if not tree.is_reduction:
numel_name = f"{tree.prefix}numel_{num}"
if numel_name in self.dynamic_shape_args:
meta[numel_name] = None
else:
meta[numel_name] = int(V.graph.sizevars.simplify(tree.numel))

return meta
@ -11,7 +11,7 @@ from torch._inductor.codegen.triton import (
TritonCSEVariable,
TritonKernel,
)
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch._inductor.runtime.triton_heuristics import SplitScanGrid
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv

@ -203,5 +203,5 @@ class TritonSplitScanKernel(TritonKernel):
def _get_heuristic(self):
return "split_scan"

def _get_grid_fn(self):
return split_scan_grid
def _get_grid_type(self) -> type[SplitScanGrid]:
return SplitScanGrid
@ -49,6 +49,7 @@ from ..utils import (
LineContext,
sympy_product,
sympy_str,
sympy_subs,
triton_version_uses_attrs_dict,
)
from ..virtualized import V

@ -61,6 +62,7 @@ from .common import (
WorkspaceArg,
WorkspaceZeroMode,
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta

@ -839,14 +841,7 @@ class PythonWrapperCodegen(CodeGen):
import_str = f"""
import triton
import triton.language as tl
from {triton_heuristics.__name__} import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from {triton_heuristics.__name__} import start_graph, end_graph
"""
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.splice(import_str)

@ -1133,44 +1128,6 @@ class PythonWrapperCodegen(CodeGen):
with debug_printer_manager:
self.writeline(f"{kernel}({', '.join(args)})")

def generate_user_defined_triton_kernel(
self,
kernel_name: str,
raw_args: list[Any],
grid: list[Any],
configs,
triton_meta,
constexprs,
):
grid_fn, code = user_defined_kernel_grid_fn_code(
kernel_name, configs, grid, wrapper=self
)
if not (config.triton.autotune_at_compile_time and V.graph.cpp_wrapper):
# When codegen the autotune block only, do no insert Triton kernel
# code into the main block
#
# Must happen after free symbols are already codegened
# Emit the grid wrapper function right before the call
for line in code.split("\n"):
self.writeline(line)

# Explicitly call the Python version of val_to_arg_str
args = [PythonWrapperCodegen.val_to_arg_str(self, v) for v in raw_args]
arg_types = [
arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
for arg in raw_args
]
# Because generate_kernel_call can be overriden by a subclass, explicitly call
# PythonWrapperCodegen.generate_kernel_call here
PythonWrapperCodegen.generate_kernel_call(
self,
kernel_name,
args,
grid_fn=grid_fn,
arg_types=arg_types,
raw_args=raw_args,
)

def _generate_tma_descriptor_call(self, desc, apply_size_hints=False):
dims = desc.dims
block_dims = desc.block_dims

@ -1694,13 +1651,15 @@ class PythonWrapperCodegen(CodeGen):
kwargs,
restore_value_args,
reset_to_zero_args,
grids: list[list[Union[int, sympy.Expr]]],
):
from torch.utils._triton import patch_triton_dtype_repr

patch_triton_dtype_repr()

original_name = kernel.__name__

from ..runtime.triton_heuristics import (
config_to_dict,
FixedGrid,
PrecomputedGrid,
)
from .common import (
ConstexprArg,
KernelArgType,

@ -1708,7 +1667,10 @@ class PythonWrapperCodegen(CodeGen):
TensorArg,
TMADescriptorArg,
)
from .triton import gen_common_triton_imports, TritonKernel

patch_triton_dtype_repr()
original_name = kernel.__name__
signature: list[KernelArgType] = []
constants: dict[str, Any] = {}
arg_indices: list[int] = []

@ -1839,22 +1801,67 @@ class PythonWrapperCodegen(CodeGen):
if reset_to_zero_args:
triton_meta["reset_to_zero"] = tuple(reset_to_zero_args)

if len(grids) == 1:
# compute the grid in the wrapper and pass it in as an arg
inductor_meta: dict[str, Any] = FixedGrid.setup_grid_as_args()
extra_launcher_call_args = [*map(sympy.sympify, grids[0])]
else:

def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr:
if isinstance(expr, sympy.Expr):
symbols = [*expr.free_symbols]
if not symbols:
return expr
symbols.sort(key=str)
for sym in symbols:
if sym in extra_launcher_args:
continue
extra_launcher_args[sym] = sympy.Symbol(
f"_launcher_s{len(extra_launcher_args)}"
)
return sympy_subs(expr, extra_launcher_args)
assert isinstance(expr, int)
return sympy.Integer(expr)

extra_launcher_args: dict[sympy.Symbol, sympy.Symbol] = {}
grids = [[*map(rename_sizes_for_launcher, grid)] for grid in grids]

assert grids and len(grids) == len(configs)
precomputed_grids = []
for grid, cfg in sorted(
zip(grids, configs), key=lambda x: len(x[1].kwargs), reverse=True
):
precomputed_grids.append(
{
"config": config_to_dict(cfg),
"python": [*map(pexpr, grid)],
"cpp": [*map(cexpr, grid)],
}
)
inductor_meta = {
"grid_type": PrecomputedGrid.__name__,
"precomputed_grids": precomputed_grids,
"extra_launcher_args": [*map(str, extra_launcher_args.values())],
}
extra_launcher_call_args = [*extra_launcher_args.keys()]

# Distinguish between different functions using function id
cache_key: list[Any] = [id(kernel.fn)]
cache_key: Any = [id(kernel.fn)]
if len(configs) > 0:
for arg in kwargs.values():
# We need to key on non tensor arg only in autotune mode
if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
cache_key.append(arg)
cache_key.append(str(triton_meta))
cache_key.extend(str(inductor_meta))
cache_key = tuple(cache_key)

if cache_key in self.user_defined_kernel_cache:
return self.user_defined_kernel_cache[cache_key]
return (
*self.user_defined_kernel_cache[cache_key],
extra_launcher_call_args,
)

name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
# Add to the cache for the next use
self.user_defined_kernel_cache[cache_key] = (name, triton_meta)

compile_wrapper = IndentedBuffer()
if config.triton.unique_user_kernel_names:

@ -1862,28 +1869,14 @@ class PythonWrapperCodegen(CodeGen):
else:
compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")

from .triton import gen_common_triton_imports, TritonKernel
inductor_meta["kernel_name"] = name
inductor_meta.update(TritonKernel.inductor_meta_common())

compile_wrapper.splice(gen_common_triton_imports())

inductor_meta = {
"kernel_name": name,
**TritonKernel.inductor_meta_common(),
}

configs = [
{
"kwargs": config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
for config in configs
]

compile_wrapper.splice(
f"""
@triton_heuristics.user_autotune(
configs={configs!r},
configs={[*map(config_to_dict, configs)]!r},
inductor_meta={inductor_meta!r},
triton_meta={triton_meta!r},
filename=__file__,

@ -1908,7 +1901,9 @@ class PythonWrapperCodegen(CodeGen):
compile_wrapper.getvalue(),
metadata,
)
return name, triton_meta
# Add to the cache for the next use
self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
return name, triton_meta, extra_launcher_call_args

def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None):
expr = f"{kernel_name}_{tree.prefix}numel"

@ -2021,17 +2016,7 @@ class PythonWrapperCodegen(CodeGen):
"""
)

def generate_default_grid(
self,
kernel_name: str,
grid_args: list[Any],
gpu: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwags,
):
return grid_args

def prepare_triton_kernel_call(self, device_index, call_args):
def prepare_triton_kernel_call(self, call_args):
def wrap_arg(arg):
if isinstance(arg, str):
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar

@ -2041,13 +2026,7 @@ class PythonWrapperCodegen(CodeGen):
else:
return pexpr(V.graph.sizevars.simplify(arg))

call_args = [wrap_arg(arg) for arg in call_args]

if device_index is None:
current_device = V.graph.get_current_device_or_throw()
device_index = current_device.index

return device_index, call_args
return [wrap_arg(arg) for arg in call_args]

def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None):
if isinstance(arg_type, torch_dtype):

@ -2144,35 +2123,28 @@ class PythonWrapperCodegen(CodeGen):
self,
kernel_name: str,
call_args,
grid=None,
device_index=None,
gpu=True,
*,
device=None,
triton=True,
arg_types=None,
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
Generates kernel call code.

gpu: Defines whether the backend is GPU. Otherwise the backend is CPU.

triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True,
and C++ when gpu=False.
"""
if not (triton or gpu):
device = device or V.graph.get_current_device_or_throw()
if not (triton or device.type != "cpu"):
self.writeline(self.wrap_kernel_call(kernel_name, call_args))
return

device_index, call_args_str = self.prepare_triton_kernel_call(
device_index, call_args
)
call_args_str = self.prepare_triton_kernel_call(call_args)
call_args_str = ", ".join(call_args_str)
stream_name = PythonWrapperCodegen.write_get_raw_stream(
self, device_index, V.graph
self, device.index, V.graph
)
if not triton:
stream_ptr = f"c_void_p({stream_name})"

@ -2227,17 +2199,8 @@ class PythonWrapperCodegen(CodeGen):
arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i)
all_args.append(arg_str if key is None else f"{key}={arg_str}")

if grid is None:
grid_str = grid_fn
else:
grid_str = ", ".join(
self.generate_example_arg_value(g, type(g)) for g in grid
)
if grid_extra_kwargs:
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_str = f"{grid_fn}({grid_str})"
self.kernel_autotune_calls.writeline(
f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})"
f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
)
self.kernel_autotune_calls.writeline(
f"del {', '.join(arg for arg in tensor_args.values())}\n",

@ -2247,22 +2210,11 @@ class PythonWrapperCodegen(CodeGen):
# For cpp wrapper, no need to continue codegen for the main body
return

if grid is None:
grid_str = grid_fn
else:
grid_str = ", ".join(
PythonWrapperCodegen._grid_dim_str(self, item) for item in grid
)
if grid_extra_kwargs:
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_str = f"{grid_fn}({grid_str})"
# add debug printer code for triton kernel calls at (jit) inductor level
debug_printer_manager = V.graph.wrapper_code.debug_printer
debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
with debug_printer_manager:
self.writeline(
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
)
self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})")

def writeline(self, line):
self.lines.append(line)
@ -40,25 +40,7 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
return source_codes

def kernel_driver(self) -> str:
source_codes = """
namespace {

struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;

bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};

} // anonymous namespace

"""
return source_codes
return ""

def cpp_stream_type(self) -> str:
return "sycl::queue*"
@ -1143,8 +1143,6 @@ class _InProcessFxCompile(FxCompile):
serialized_extern_kernel_nodes,
)

additional_files = graph.wrapper_code.additional_files

with dynamo_timed(
"AotCodeCompiler.compile", log_pt2_compile_event=True
):

@ -1154,7 +1152,11 @@ class _InProcessFxCompile(FxCompile):
code,
serialized_extern_kernel_nodes,
device_type=graph.device_type,
additional_files=additional_files,
additional_files=[
*dict.fromkeys(
graph.wrapper_code.additional_files
)
],
)
else:
compiled_fn = graph.compile_to_module().call
@ -27,6 +27,7 @@ from ..pattern_matcher import (
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
SymbolicGridFn,
TritonTemplate,
TritonTemplateCaller,
)

@ -38,8 +39,9 @@ B2B_GEMM_PASS = PatternMatcherPass(
)


def b2b_gemm_grid(M, P, meta):
return (ceildiv(M, meta["BLOCK_SIZE_M"]) * ceildiv(P, meta["BLOCK_SIZE_P"]), 1, 1)
@SymbolicGridFn
def b2b_gemm_grid(M, P, meta, *, cdiv):
return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(P, meta["BLOCK_SIZE_P"]), 1, 1)


b2b_gemm_left_template = TritonTemplate(
@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import copy
import dataclasses
import functools
import itertools

@ -5833,82 +5832,68 @@ class UserDefinedTritonKernel(ExternKernel):
) = self.get_kernel_and_metadata()

# Definition of kernel
new_name, triton_meta = wrapper.define_user_defined_triton_kernel(
kernel, configs, self.kwargs, restore_value_args, reset_to_zero_args
)
raw_args = [
self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
]

# NOTE: raw_args doesn't include autotuned args.
# But, kernel.constexprs includes indices of autotuned args.
# So, let's recalculate constexpr indices wrt to raw_args.
constexpr_indices = []
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
if kernel.arg_names.index(kwarg) in kernel.constexprs:
constexpr_indices.append(idx)

# Create a copy of triton_meta to avoid modifying the original version.
triton_meta = copy.deepcopy(triton_meta)
if not triton_version_uses_attrs_dict():
"""
Filter out None args.

see https://github.com/pytorch/pytorch/issues/115344

Two cases for a None arg:
1. The arg is already tl.constexpr, so leave it in
2. The arg is not tl.constexpr so we have to remove it
"""

constexpr_indices_set = OrderedSet(constexpr_indices)
REMOVED = object()
raw_args = [
(
(idx, arg)
if (arg is not None)
or (arg is None and idx in constexpr_indices_set)
else (idx, REMOVED)
)
for idx, arg in enumerate(raw_args)
]
removed_none_args = [idx for idx, val in raw_args if val == REMOVED]
raw_args = [val for idx, val in raw_args if val != REMOVED]

# We have to compute the constexpr indices for the new, filtered raw_args
# We also have to adjust equal_to_1.
if removed_none_args:
eq1_indices_set = OrderedSet[int](triton_meta["configs"][0].equal_to_1)
constexpr_indices = []
equal_to_1 = []
index_shift = 0
for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
# every time we encounter an idx we removed, adjust by one to account for it
# So for example if we had [None, const X]
# iter 1:
# None was removed, adjust=1
# iter 2:
# X is const at idx=1, but the adjusted idx is 0 now, because None was removed
if idx in removed_none_args:
index_shift += 1
continue
arg_index = kernel.arg_names.index(kwarg)
if arg_index in kernel.constexprs:
constexpr_indices.append(idx - index_shift)
if arg_index in eq1_indices_set:
equal_to_1.append(idx - index_shift)

triton_meta["configs"][0].equal_to_1 = equal_to_1

# Call to kernel
self.codegen_comment(wrapper)
wrapper.generate_user_defined_triton_kernel(
(
new_name,
raw_args,
self.grid,
configs,
triton_meta,
constexpr_indices,
extra_launch_args,
) = wrapper.define_user_defined_triton_kernel(
kernel,
configs,
self.kwargs,
restore_value_args,
reset_to_zero_args,
self.grid,
)
named_args = {
k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
}
constexpr_names = OrderedSet([kernel.arg_names[i] for i in kernel.constexprs])

args: list[Any] = []
arg_types: list[Any] = []
raw_args_filtered: list[Any] = []
for name, arg in itertools.chain(
named_args.items(), zip(itertools.repeat(""), extra_launch_args)
):
raw_args_filtered.append(arg)
if isinstance(arg, IRNode):
args.append(arg.codegen_reference())
arg_types.append(arg.get_dtype())
elif isinstance(arg, (int, float, bool, sympy.Expr)):
args.append(arg)
arg_types.append(type(arg))
elif name in constexpr_names:
# insert a dummy value for constexpr args of unsupported type
# constexprs will end up getting baked into the kernel at compile time
args.append(-1)
arg_types.append(int)
elif arg is None:
"""
Filter out None args.

see https://github.com/pytorch/pytorch/issues/115344

Two cases for a None arg:
1. The arg is already tl.constexpr, so leave it in
2. The arg is not tl.constexpr so we have to remove it
"""
if triton_version_uses_attrs_dict():
args.append(-1)
arg_types.append(int)
else:
raw_args_filtered.pop()
else:
raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")

self.codegen_comment(wrapper)
wrapper.generate_kernel_call(
new_name,
args,
arg_types=arg_types,
raw_args=raw_args_filtered,
triton_meta=triton_meta,
triton=True,
device=self.get_device(),
)

def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
@ -8,10 +8,10 @@ from .. import ir, lowering as L
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
SymbolicGridFn,
TritonTemplate,
)
from ..utils import (
ceildiv as cdiv,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_cpp_bmm_template,

@ -33,7 +33,8 @@ log = logging.getLogger(__name__)
aten = torch.ops.aten


def bmm_grid(b, m, n, meta):
@SymbolicGridFn
def bmm_grid(b, m, n, meta, *, cdiv):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)

@ -17,10 +17,10 @@ from ..lowering import (
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
SymbolicGridFn,
TritonTemplate,
)
from ..utils import (
ceildiv,
is_ones,
is_zeros,
pad_listlike,

@ -43,18 +43,20 @@ log = logging.getLogger(__name__)
aten = torch.ops.aten


def conv2d_grid(n, c, h, w, meta):
@SymbolicGridFn
def conv2d_grid(n, c, h, w, meta, *, cdiv):
return (
ceildiv(n * h * w, meta["BLOCK_M"]),
ceildiv(c, meta["BLOCK_N"]),
cdiv(n * h * w, meta["BLOCK_M"]),
cdiv(c, meta["BLOCK_N"]),
meta["GROUPS"],
)


def conv3d_grid(n, c, d, h, w, meta):
@SymbolicGridFn
def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
return (
ceildiv(n * d * h * w, meta["BLOCK_M"]),
ceildiv(c, meta["BLOCK_N"]),
cdiv(n * d * h * w, meta["BLOCK_M"]),
cdiv(c, meta["BLOCK_N"]),
meta["GROUPS"],
)

@ -46,7 +46,12 @@ from ..lowering import (
register_lowering,
to_dtype,
)
from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
from ..select_algorithm import (
autotune_select_algorithm,
realize_inputs,
SymbolicGridFn,
TritonTemplate,
)


log = logging.getLogger(__name__)

@ -91,15 +96,14 @@ def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]):
return construct_strides(size, fill_order)


def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
@SymbolicGridFn
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv):
"""How is this kernel parallelized?
We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
Each block is responsible for iterating over blocks of keys and values calculating
the final attention output.
"""
import triton

return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)


def create_placeholder(
@ -12,7 +12,7 @@ from .. import config, ir
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty, empty_strided, lowerings
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from ..select_algorithm import autotune_select_algorithm, SymbolicGridFn, TritonTemplate
from .flex_attention import (
compute_forward_block_mn,
compute_forward_inner,

@ -30,6 +30,7 @@ aten = torch.ops.aten
prims = torch.ops.prims


@SymbolicGridFn
def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
"""How is this kernel parallelized?
We create a grid of (batch_size * kv_heads, SPLIT_KV, 1)
@ -8,7 +8,7 @@ from typing import Any, cast
import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet

@ -17,7 +17,6 @@ from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import ChoiceCaller, Layout
from ..runtime.runtime_utils import next_power_of_2
from ..utils import (
ceildiv as cdiv,
get_backend_num_stages,
get_num_sms,
TMA_DESCRIPTOR_SIZE,

@ -442,14 +441,16 @@ def should_fallback_to_aten(choices: list[ChoiceCaller]) -> bool:
return fallback_to_aten


def mm_grid(m, n, meta):
@SymbolicGridFn
def mm_grid(m, n, meta, *, cdiv):
"""
The CUDA grid size for matmul triton templates.
"""
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def persistent_mm_grid(M: int, N: int, meta: dict[str, Any]):
@SymbolicGridFn
def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
"""Defines the grid for persistent kernels."""
return (
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
@ -3,9 +3,11 @@ from __future__ import annotations
|
|||
|
||||
import builtins
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
|
|
@ -16,7 +18,7 @@ import sys
|
|||
import threading
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch._prims_common import compute_required_storage_length
|
||||
|
|
@ -71,7 +73,7 @@ class NoTritonConfigsError(RuntimeError):
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Container, Hashable
|
||||
from collections.abc import Container, Hashable, Sequence
|
||||
|
||||
LauncherType = Any
|
||||
|
||||
|
|
@ -136,7 +138,7 @@ def disable_pointwise_autotuning(inductor_meta):
|
|||
return not inductor_meta.get("autotune_pointwise", True)
|
||||
|
||||
|
||||
def _dump_launch_params(args, kwargs, launcher, kernel_name):
|
||||
def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
|
||||
call_args = []
|
||||
call_kwargs = {}
|
||||
for arg in args:
|
||||
|
|
@ -154,14 +156,12 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name):
|
|||
call_kwargs[k] = v
|
||||
call_kwargs["num_warps"] = launcher.config.num_warps
|
||||
call_kwargs["num_stages"] = launcher.config.num_stages
|
||||
args_str = ""
|
||||
args_str += ", ".join(call_args)
|
||||
for k, v in call_kwargs.items():
|
||||
args_str += f", {k}={v}"
|
||||
|
||||
args_str = [*call_args]
|
||||
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
|
||||
args_str = ", ".join(args_str)
|
||||
abs_path = os.path.abspath(sys.argv[0])
|
||||
with open(f"{abs_path}.launch_params", "a") as f:
|
||||
f.write(f"{kernel_name} | {args_str}\n")
|
||||
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
|
||||
|
||||
|
||||
class CachingAutotuner(KernelInterface):
|
||||
|
|
@ -478,6 +478,12 @@ class CachingAutotuner(KernelInterface):
|
|||
if k in cfg_kwargs:
|
||||
compile_meta[k] = cfg_kwargs.pop(k)
|
||||
compile_meta["constants"].update(cfg_kwargs)
|
||||
for i in self.fn.constexprs:
|
||||
arg_name = self.fn.arg_names[i]
|
||||
if arg_name not in compile_meta["constants"] and (
|
||||
arg_name == "num_warps" or arg_name == "num_stages"
|
||||
):
|
||||
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
|
||||
compile_meta["num_warps"] = cfg.num_warps
|
||||
compile_meta["num_stages"] = cfg.num_stages
|
||||
compile_meta["debug"] = self.inductor_meta.get(
|
||||
|
|
@ -563,7 +569,7 @@ class CachingAutotuner(KernelInterface):
|
|||
return new_args
|
||||
return args
|
||||
|
||||
def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
|
||||
def bench(self, launcher, *args, with_profiler=False, **kwargs):
|
||||
"""Measure the performance of a given launcher"""
|
||||
# we don't skip configs with spilled registers when auto-tuning custom
|
||||
# (user-written) Triton kernels, as (i) we don't have any knowledge or
|
||||
|
|
@ -595,7 +601,6 @@ class CachingAutotuner(KernelInterface):
|
|||
launcher(
|
||||
*args_with_constexprs,
|
||||
**cloned_kwargs,
|
||||
grid=grid,
|
||||
stream=stream,
|
||||
)
|
||||
self.restore_args_from_cpu(cpu_copies)
|
||||
|
|
@ -650,8 +655,8 @@ class CachingAutotuner(KernelInterface):
|
|||
else:
|
||||
budget -= size
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
maybe_copy(self.fn.arg_names[i], arg)
|
||||
for name, arg in zip(self.fn.arg_names, args):
|
||||
maybe_copy(name, arg)
|
||||
|
||||
for name, arg in kwargs.items():
|
||||
maybe_copy(name, arg)
|
||||
|
|
@ -712,10 +717,10 @@ class CachingAutotuner(KernelInterface):
|
|||
return arg
|
||||
|
||||
cloned_args = [
|
||||
prepare_arg(self.fn.arg_names[i], arg) for i, arg in enumerate(args)
|
||||
prepare_arg(name, arg)
|
||||
for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args)
|
||||
]
|
||||
cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
|
||||
|
||||
return cloned_args, cloned_kwargs
|
||||
|
||||
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
|
||||
|
|
@ -777,12 +782,7 @@ class CachingAutotuner(KernelInterface):
|
|||
if self.save_cache_hook:
|
||||
self.save_cache_hook(launcher.config, self.autotune_time_taken_ns)
|
||||
|
||||
def save_gpu_kernel(self, grid, stream, launcher):
|
||||
if callable(grid):
|
||||
grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
|
||||
else:
|
||||
grid_x, grid_y, grid_z = grid
|
||||
|
||||
def save_gpu_kernel(self, stream, launcher):
|
||||
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
|
||||
assert key is not None, "kernel_name can not be None"
|
||||
params = {
|
||||
|
|
@ -791,13 +791,6 @@ class CachingAutotuner(KernelInterface):
|
|||
if hasattr(launcher.bin.metadata, "name")
|
||||
else launcher.bin.metadata["name"]
|
||||
),
|
||||
"grid_x": grid_x,
|
||||
"grid_y": grid_y,
|
||||
"grid_z": grid_z,
|
||||
"x_block": launcher.config.kwargs.get("XBLOCK", 1),
|
||||
"y_block": launcher.config.kwargs.get("YBLOCK", None),
|
||||
"z_block": launcher.config.kwargs.get("ZBLOCK", None),
|
||||
"r_block": launcher.config.kwargs.get("RBLOCK", None),
|
||||
"num_warps": (
|
||||
launcher.bin.num_warps
|
||||
if hasattr(launcher.bin, "num_warps")
|
||||
|
|
@ -810,7 +803,11 @@ class CachingAutotuner(KernelInterface):
|
|||
),
|
||||
"stream": stream,
|
||||
# User defined triton kernels will have arbitrary kwarg names
|
||||
"meta": launcher.config.kwargs,
|
||||
"config": config_to_dict(launcher.config),
|
||||
"inductor_meta": self.inductor_meta,
|
||||
"triton_meta": self.triton_meta,
|
||||
"def_args": launcher.def_args,
|
||||
"call_args": launcher.call_args,
|
||||
}
|
||||
from torch._inductor.codecache import CudaKernelParamCache
|
||||
|
||||
|
|
@ -888,8 +885,15 @@ class CachingAutotuner(KernelInterface):
|
|||
)
|
||||
return config2launcher.get(best_config)
|
||||
|
||||
def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override]
|
||||
def run(
|
||||
self,
|
||||
*args,
|
||||
stream,
|
||||
benchmark_run=False,
|
||||
**kwargs,
|
||||
): # type:ignore[override]
|
||||
if self.triton_interpret:
|
||||
args, grid = self._interpret_args_grid(args, self.configs[0])
|
||||
return self.fn[grid](
|
||||
*args,
|
||||
**kwargs,
|
||||
|
|
@ -902,35 +906,28 @@ class CachingAutotuner(KernelInterface):
|
|||
self.precompile()
|
||||
self.precompile_time_taken_ns = time.time_ns() - start_time
|
||||
if len(self.launchers) > 1:
|
||||
self.autotune_to_one_config(*args, grid=grid, **kwargs)
|
||||
self.autotune_to_one_config(*args, **kwargs)
|
||||
|
||||
if not getattr(
|
||||
self.launchers[0].config, "found_by_coordesc", False
|
||||
) and self.inductor_meta.get("coordinate_descent_tuning", False):
|
||||
self.launchers = [
|
||||
self.coordinate_descent_tuning(
|
||||
self.launchers[0], *args, grid=grid, **kwargs
|
||||
)
|
||||
self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs)
|
||||
]
|
||||
|
||||
(launcher,) = self.launchers
|
||||
if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
|
||||
self.save_gpu_kernel(grid, stream, launcher)
|
||||
self.save_gpu_kernel(stream, launcher)
|
||||
|
||||
args = self._get_args_with_constexprs(args, launcher)
|
||||
|
||||
if self.dump_launch_params:
|
||||
_dump_launch_params(args, kwargs, launcher, self.fn.__name__)
|
||||
new_args, grid = self._interpret_args_grid(args, launcher.config)
|
||||
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
|
||||
|
||||
# it is faster than entering and exiting a context manager, even if the context
|
||||
# manager is a nullcontext.
|
||||
if autograd_profiler._is_profiler_enabled:
|
||||
# grid can be a tuple of ints or a string.
|
||||
if isinstance(grid, tuple):
|
||||
grid_info = str(grid)
|
||||
else:
|
||||
grid_info = getattr(grid, "grid_fn_str", "")
|
||||
|
||||
kernel_kwargs_str = ",".join(
|
||||
f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
|
||||
)
|
||||
|
|
@ -939,7 +936,6 @@ class CachingAutotuner(KernelInterface):
|
|||
"kernel_file": (self.filename or ""),
|
||||
"kernel_hash": self.kernel_hash,
|
||||
"kernel_backend": "triton",
|
||||
"grid": grid_info,
|
||||
"stream": stream,
|
||||
"num_warps": launcher.config.num_warps,
|
||||
"num_stages": launcher.config.num_stages,
|
||||
|
|
@ -954,17 +950,33 @@ class CachingAutotuner(KernelInterface):
|
|||
return launcher(
|
||||
*args,
|
||||
**kwargs,
|
||||
grid=grid,
|
||||
stream=stream,
|
||||
)
|
||||
else:
|
||||
return launcher(
|
||||
*args,
|
||||
**kwargs,
|
||||
grid=grid,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def _interpret_args_grid(
|
||||
self, args: tuple[Any, ...], cfg: Config
|
||||
) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
|
||||
grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow(
|
||||
dict(
|
||||
zip(
|
||||
[
|
||||
*self.fn.arg_names,
|
||||
*self.inductor_meta.get("extra_launcher_args", ()),
|
||||
],
|
||||
args,
|
||||
)
|
||||
)
|
||||
)
|
||||
if self.inductor_meta.get("extra_launcher_args"):
|
||||
args = args[: -len(self.inductor_meta["extra_launcher_args"])]
|
||||
return args, grid
|
||||
|
||||
|
||||
class _ConstRepr:
|
||||
def __init__(self, value: str):
|
||||
|
|
@ -1110,10 +1122,11 @@ class TritonCompileResult:
|
|||
for i, arg in enumerate(fn.arg_names)
|
||||
if i not in fn.constexprs and arg not in none_args
|
||||
]
|
||||
cfg_dict = config_to_dict(cfg)
|
||||
def_args = [
|
||||
name
|
||||
for name in fn.arg_names
|
||||
if name not in cfg.kwargs and name not in none_args
|
||||
if name not in cfg_dict and name not in none_args
|
||||
]
|
||||
|
||||
binary_shared = (
|
||||
|
|
@ -1176,9 +1189,7 @@ class TritonCompileResult:
|
|||
# we want to burn None in to the launch args with zero overhead.
|
||||
# See https://github.com/pytorch/pytorch/issues/123597
|
||||
if binary.__class__.launch_enter_hook:
|
||||
launch_metadata = (
|
||||
f"bin.launch_metadata(grid, stream, {', '.join(call_args)})"
|
||||
)
|
||||
launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})"
|
||||
else:
|
||||
launch_metadata = "None"
|
||||
runner_args = [
|
||||
|
|
@ -1194,18 +1205,20 @@ class TritonCompileResult:
|
|||
*call_args,
|
||||
]
|
||||
|
||||
exec(
|
||||
f"""
|
||||
def launcher({", ".join(def_args)}, grid, stream):
|
||||
if callable(grid):
|
||||
grid_0, grid_1, grid_2 = grid(grid_meta)
|
||||
else:
|
||||
grid_0, grid_1, grid_2 = grid
|
||||
runner({", ".join(runner_args)})
|
||||
return bin
|
||||
""".lstrip(),
|
||||
scope,
|
||||
)
|
||||
if "extra_launcher_args" in self.inductor_meta:
|
||||
def_args.extend(self.inductor_meta["extra_launcher_args"])
|
||||
|
||||
grid = GridExpr.from_meta(self.inductor_meta, cfg)
|
||||
# grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)`
|
||||
lines = [
|
||||
f"def launcher({', '.join(def_args)}, stream):",
|
||||
*[f" {line}" for line in grid.prefix],
|
||||
f" grid_0 = {grid.x_grid}",
|
||||
f" grid_1 = {grid.y_grid}",
|
||||
f" grid_2 = {grid.z_grid}",
|
||||
f" runner({', '.join(runner_args)})",
|
||||
]
|
||||
exec("\n".join(lines), scope)
|
||||
|
||||
launcher = scope["launcher"]
|
||||
launcher.config = cfg
|
||||
|
|
@ -1217,6 +1230,8 @@ class TritonCompileResult:
|
|||
if launcher.store_cubin:
|
||||
launcher.fn = fn
|
||||
launcher.bin = binary
|
||||
launcher.def_args = def_args
|
||||
launcher.call_args = call_args
|
||||
return launcher
|
||||
|
||||
|
||||
|
|
@ -1306,9 +1321,9 @@ class DebugAutotuner(CachingAutotuner):
|
|||
super().__init__(*args, **kwargs)
|
||||
self.cached = None
|
||||
|
||||
def run(self, *args, grid, stream, **kwargs):
|
||||
def run(self, *args, stream, **kwargs):
|
||||
if not self.with_bandwidth_info:
|
||||
super().run(*args, grid=grid, stream=stream, **kwargs, benchmark_run=True)
|
||||
super().run(*args, stream=stream, **kwargs, benchmark_run=True)
|
||||
return
|
||||
else:
|
||||
possible_names = _find_names(self)
|
||||
|
|
@ -1322,16 +1337,14 @@ class DebugAutotuner(CachingAutotuner):
|
|||
self.precompile()
|
||||
self.precompile_time_taken_ns = time.time_ns() - start_time
|
||||
if len(self.launchers) > 1:
|
||||
self.autotune_to_one_config(*args, grid=grid, **kwargs)
|
||||
self.autotune_to_one_config(*args, **kwargs)
|
||||
(launcher,) = self.launchers
|
||||
|
||||
if launcher.store_cubin:
|
||||
self.save_gpu_kernel(grid, stream, launcher)
|
||||
self.save_gpu_kernel(stream, launcher)
|
||||
|
||||
if self.cached is None:
|
||||
ms = self.bench(
|
||||
launcher, *args, grid=grid, with_profiler=self.with_profiler
|
||||
)
|
||||
ms = self.bench(launcher, *args, with_profiler=self.with_profiler)
|
||||
num_in_out_ptrs = len(
|
||||
[
|
||||
arg_name
|
||||
|
|
@ -2127,6 +2140,19 @@ def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
|
|||
return popped
|
||||
|
||||
|
||||
def config_to_dict(config: Config) -> dict[str, Any]:
|
||||
return {
|
||||
**config.kwargs,
|
||||
"num_warps": config.num_warps,
|
||||
"num_stages": config.num_stages,
|
||||
}
|
||||
|
||||
|
||||
def config_from_dict(config: dict[str, Any]) -> Config:
|
||||
config = {**config}
|
||||
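# e.g. (illustrative) {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1}
#      -> Config({"XBLOCK": 1024}, num_warps=4, num_stages=1)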
return Config(config, **_pop_config_kwargs(config))
|
||||
|
||||
|
||||
def fixed_config(config, filename, triton_meta, inductor_meta):
|
||||
"""
|
||||
Used when the configuration is already decided at compile time
|
||||
|
|
@ -2151,10 +2177,7 @@ def user_autotune(
|
|||
if len(configs) == 0:
|
||||
configs = [triton.Config({})]
|
||||
else:
|
||||
configs = [
|
||||
triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c}))
|
||||
for c in configs
|
||||
]
|
||||
configs = [*map(config_from_dict, configs)]
|
||||
return cached_autotune(
|
||||
None,
|
||||
configs,
|
||||
|
|
@ -2180,150 +2203,226 @@ def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
|||
)
|
||||
|
||||
|
||||
def grid(*numels):
|
||||
"""Helper function to compute triton grids"""
|
||||
if len(numels) == 1:
|
||||
xnumel, ynumel, znumel = numels[0], None, None
|
||||
elif len(numels) == 2:
|
||||
xnumel, ynumel, znumel = numels[1], numels[0], None
|
||||
elif len(numels) == 3:
|
||||
xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
|
||||
else:
|
||||
raise AssertionError(f"invalid size for numels {len(numels)}")
|
||||
@dataclasses.dataclass
|
||||
class GridExpr:
|
||||
"""Generate code for grid size expressions in launcher"""
|
||||
|
||||
def get_grid_dim(numel, block):
|
||||
if numel is None:
|
||||
return 1
|
||||
if block is None:
|
||||
inductor_meta: dict[str, Any]
|
||||
mode: Literal["python", "cpp"] = "python"
|
||||
prefix: Sequence[str] = ()
|
||||
x_grid: Union[str, int] = 1
|
||||
y_grid: Union[str, int] = 1
|
||||
z_grid: Union[str, int] = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.mode in ("python", "cpp")
|
||||
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def ceildiv(
|
||||
self, numel: Union[str, int], block: Union[None, int, str]
|
||||
) -> Union[str, int]:
|
||||
if block is None or block == 1:
|
||||
return numel
|
||||
return ceildiv(numel, block)
|
||||
if isinstance(numel, int) and isinstance(block, int):
|
||||
return ceildiv(numel, block) # constant fold
|
||||
if self.mode == "python":
|
||||
return f"-(({numel}) // -({block}))"
|
||||
# the trick above doesn't work in C++ because integer division truncates toward zero instead of flooring
|
||||
return f"(({numel} + ({block} - 1)) / ({block}))"
|
||||
|
||||
def grid_fn(meta):
|
||||
x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
|
||||
y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))
|
||||
def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]:
|
||||
"""Codegen for max function with constant folding, constants are represented as int"""
|
||||
items = self._constant_fold(max, seq)
|
||||
if len(items) <= 1:
|
||||
return items[0]
|
||||
if self.mode == "python":
|
||||
return f"max({', '.join(map(str, items))})"
|
||||
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
|
||||
|
||||
max_y_grid = get_max_y_grid()
|
||||
if znumel is None:
|
||||
div = ceildiv(y_grid, max_y_grid)
|
||||
y_grid = ceildiv(y_grid, div)
|
||||
z_grid = div
|
||||
else:
|
||||
z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
|
||||
torch._check(
|
||||
y_grid <= max_y_grid,
|
||||
lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
|
||||
)
|
||||
def summation(self, seq: list[Union[int, str]]) -> Union[int, str]:
|
||||
"""Codegen for sum function with constant folding, constants are represented as int"""
|
||||
items = self._constant_fold(sum, seq)
|
||||
if len(items) <= 1:
|
||||
return items[0]
|
||||
return " + ".join(map(str, items))
|
||||
|
||||
return (
|
||||
x_grid,
|
||||
y_grid,
|
||||
z_grid,
|
||||
def _constant_fold(
|
||||
self, fn: Callable[[list[int]], int], seq: list[Union[int, str]]
|
||||
) -> list[Union[int, str]]:
|
||||
"""Constant fold through a commutative fn where ints are constants"""
|
||||
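# e.g. fn=max, seq=["xnumel", 3, 7] -> ["xnumel", 7]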
items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)]
|
||||
const_items = [x for x in seq if isinstance(x, int)]
|
||||
if const_items:
|
||||
items.append(fn(const_items))
|
||||
return items
|
||||
|
||||
def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
|
||||
# Grid functions are one per kernel, so name collisions are fine
|
||||
if self.mode == "python":
|
||||
return f"{name} = {expr}"
|
||||
if self.mode == "cpp":
|
||||
return f"uint32_t {name} = {expr};"
|
||||
raise AssertionError(f"invalid mode {self.mode}")
|
||||
|
||||
@staticmethod
|
||||
def from_meta(
|
||||
inductor_meta: dict[str, Any],
|
||||
cfg: Union[Config, dict[str, int]],
|
||||
mode: Literal["python", "cpp"] = "python",
|
||||
) -> GridExpr:
|
||||
grid_cls = globals()[inductor_meta["grid_type"]]
|
||||
assert issubclass(grid_cls, GridExpr)
|
||||
grid = grid_cls(inductor_meta=inductor_meta, mode=mode)
|
||||
if isinstance(cfg, Config):
|
||||
cfg = config_to_dict(cfg)
|
||||
grid.generate(cfg)
|
||||
return grid
|
||||
|
||||
def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]:
|
||||
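# e.g. a Grid1D expression evaluated with {"xnumel": 4096, "XBLOCK": 1024}
# (illustrative values) yields (4, 1, 1)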
scope = {**meta}
|
||||
for line in self.prefix:
|
||||
exec(line, scope)
|
||||
exec(f"grid_0 = {self.x_grid}", scope)
|
||||
exec(f"grid_1 = {self.y_grid}", scope)
|
||||
exec(f"grid_2 = {self.z_grid}", scope)
|
||||
return scope["grid_0"], scope["grid_1"], scope["grid_2"]
|
||||
|
||||
|
||||
class Grid1D(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
||||
|
||||
|
||||
class Grid2D(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
||||
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
|
||||
|
||||
|
||||
class Grid3D(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
||||
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
|
||||
self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK"))
|
||||
|
||||
|
||||
class Grid2DWithYZOverflow(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
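# when the y grid would exceed the device's maximum, split it across the
# z dimension (via y_grid_raw_ / y_grid_div_ below)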
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
||||
self.prefix = [
|
||||
self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))),
|
||||
self.assign_tmp(
|
||||
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
|
||||
),
|
||||
]
|
||||
self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
|
||||
self.z_grid = "y_grid_div_"
|
||||
|
||||
|
||||
class CooperativeReductionGrid(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
self.x_grid = str(meta["RSPLIT"])
|
||||
self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
||||
|
||||
|
||||
class SplitScanGrid(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
assert meta.get("XBLOCK", 1) == 1
|
||||
self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK"))
|
||||
self.y_grid = "xnumel"
|
||||
|
||||
|
||||
class FixedGrid(GridExpr):
|
||||
@staticmethod
|
||||
def setup_grid_as_args() -> dict[str, Any]:
|
||||
"""Inductor meta so the launcher takes three extra grid arguments"""
|
||||
return {
|
||||
"grid_type": FixedGrid.__name__,
|
||||
"fixed_grid": ["_grid_0", "_grid_1", "_grid_2"],
|
||||
"extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"],
|
||||
}
|
||||
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
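# with setup_grid_as_args() the grid dims are simply the trailing
# (_grid_0, _grid_1, _grid_2) launcher arguments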
self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"]
|
||||
|
||||
|
||||
class PrecomputedGrid(GridExpr):
|
||||
def generate(self, meta: dict[str, int]) -> None:
|
||||
for candidate in self.inductor_meta["precomputed_grids"]:
|
||||
if all(meta.get(k) == v for k, v in candidate["config"].items()):
|
||||
self.x_grid, self.y_grid, self.z_grid = candidate[self.mode]
|
||||
return
|
||||
raise AssertionError(
|
||||
f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}"
|
||||
)
|
||||
|
||||
setattr(grid_fn, "grid_fn_str", f"grid{numels}") # noqa: B010
|
||||
|
||||
return grid_fn
|
||||
|
||||
|
||||
def cooperative_reduction_grid(xnumel):
|
||||
def grid_fn(meta):
|
||||
return (meta["RSPLIT"], ceildiv(xnumel, meta.get("XBLOCK", 1)), 1)
|
||||
|
||||
grid_fn_str = f"cooperative_reduction_grid({xnumel})"
|
||||
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
|
||||
return grid_fn
|
||||
|
||||
|
||||
def maybe_cooperative_reduction_grid(xnumel):
|
||||
def grid_fn(meta):
|
||||
if "RSPLIT" in meta:
|
||||
return coop_grid(meta)
|
||||
return normal_grid(meta)
|
||||
|
||||
coop_grid = cooperative_reduction_grid(xnumel)
|
||||
normal_grid = grid(xnumel)
|
||||
grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
|
||||
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
|
||||
return grid_fn
|
||||
|
||||
|
||||
def split_scan_grid(xnumel, rnumel):
|
||||
def grid_fn(meta):
|
||||
assert meta.get("XBLOCK", 1) == 1
|
||||
return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
|
||||
|
||||
grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
|
||||
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
|
||||
|
||||
return grid_fn
|
||||
|
||||
|
||||
def grid_combo_kernels(
|
||||
*numels, num_kernels, min_blocks, is_sequential, default_meta=None
|
||||
):
|
||||
"""min_blocks is the minimal size of the grid x dimension"""
|
||||
if not is_sequential:
|
||||
# round robin dispatch
|
||||
numels_agg = list(numels)
|
||||
for i in range(len(numels_agg)):
|
||||
if isinstance(numels_agg[i], (list, tuple)):
|
||||
numels_agg[i] = max(max(numels_agg[i]), 0) # noqa: PLW3301
|
||||
kernel_grid_fn = grid(*numels_agg)
|
||||
|
||||
if isinstance(numels[-1], (list, tuple)):
|
||||
min_blocks_d = max(-min(numels[-1]), 0) * num_kernels
|
||||
else:
|
||||
min_blocks_d = None
|
||||
if min_blocks is None:
|
||||
assert min_blocks_d is not None
|
||||
min_blocks = min_blocks_d
|
||||
else:
|
||||
assert min_blocks_d is None or min_blocks == min_blocks_d, (
|
||||
f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
|
||||
class ComboKernelGrid(GridExpr):
|
||||
def generate(self, meta: dict[str, int]):
|
||||
combo_meta = self.inductor_meta["combo_grid_meta"]
|
||||
if combo_meta["default_config"]:
|
||||
meta = {**combo_meta["default_config"], **meta}
|
||||
no_x_dims = []
|
||||
xnumels = []
|
||||
ynumels = []
|
||||
znumels = []
|
||||
for num in range(combo_meta["num_kernels"]):
|
||||
assert (
|
||||
combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0
|
||||
)
|
||||
else:
|
||||
# sequential dispatch
|
||||
seq_numels = list(numels)
|
||||
# x numels are not used here, just a placeholder
|
||||
seq_numels[-1] = 1024
|
||||
for i in range(len(seq_numels) - 1):
|
||||
if isinstance(seq_numels[i], (list, tuple)):
|
||||
seq_numels[i] = max(seq_numels[i])
|
||||
no_x_dims.append(combo_meta[f"no_x_dim_{num}"])
|
||||
xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}")
|
||||
if f"ynumel_{num}" in combo_meta:
|
||||
ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}")
|
||||
if f"znumel_{num}" in combo_meta:
|
||||
znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}")
|
||||
|
||||
kernel_grid_fn = grid(*seq_numels)
|
||||
self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta)
|
||||
if combo_meta["min_blocks"]:
|
||||
self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]])
|
||||
if ynumels:
|
||||
self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK"))
|
||||
if znumels:
|
||||
self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK"))
|
||||
|
||||
def get_grid_dim(numel, block):
|
||||
if numel is None:
|
||||
return 1
|
||||
if block is None:
|
||||
return numel
|
||||
return ceildiv(numel, block)
|
||||
def combo_x_grid(
|
||||
self,
|
||||
xnumels: list[Union[int, str]],
|
||||
no_x_dims: list[bool],
|
||||
meta: dict[str, int],
|
||||
) -> Union[str, int]:
|
||||
raise NotImplementedError
|
||||
|
||||
def grid_fn(meta):
|
||||
assert min_blocks is not None, "min_blocks must be a number"
|
||||
cuda_grid = list(kernel_grid_fn(meta))
|
||||
cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
|
||||
return tuple(cuda_grid)
|
||||
|
||||
def seq_grid_fn(meta):
|
||||
cuda_grid = list(kernel_grid_fn(meta))
|
||||
# x <= 0 means this kernel's x grid is not tunable (x_no_dim is true)
|
||||
x_grid = sum(
|
||||
class SequentialComboKernelGrid(ComboKernelGrid):
|
||||
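# sequential dispatch: the combined x grid is the sum of each sub-kernel's x grid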
def combo_x_grid(
|
||||
self,
|
||||
xnumels: list[Union[int, str]],
|
||||
no_x_dims: list[bool],
|
||||
meta: dict[str, int],
|
||||
) -> Union[str, int]:
|
||||
assert len(xnumels) == len(no_x_dims)
|
||||
return self.summation(
|
||||
[
|
||||
-x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1))
|
||||
for x in numels[-1]
|
||||
self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK"))
|
||||
for x, no_x_dim in zip(xnumels, no_x_dims)
|
||||
]
|
||||
)
|
||||
cuda_grid[0] = x_grid
|
||||
return tuple(cuda_grid)
|
||||
|
||||
def grid_fn_default_meta(meta):
|
||||
return grid_fn(default_meta)
|
||||
|
||||
def seq_grid_fn_default_meta(meta):
|
||||
return seq_grid_fn(default_meta)
|
||||
|
||||
if default_meta is None:
|
||||
return grid_fn if not is_sequential else seq_grid_fn
|
||||
else:
|
||||
return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta
|
||||
class RoundRobinComboKernelGrid(ComboKernelGrid):
|
||||
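# round-robin dispatch: the combined x grid is the max over sub-kernels, times num_kernels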
def combo_x_grid(
|
||||
self,
|
||||
xnumels: list[Union[int, str]],
|
||||
no_x_dims: list[bool],
|
||||
meta: dict[str, int],
|
||||
) -> str:
|
||||
assert len(xnumels) == len(no_x_dims)
|
||||
num_kernels = self.inductor_meta["combo_grid_meta"]["num_kernels"]
|
||||
exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim]
|
||||
xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim]
|
||||
if xnumels_x_dim:
|
||||
exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK")))
|
||||
return f"({self.maximum(exprs)}) * {num_kernels}"
|
||||
|
|
|
|||
|
|
@ -3863,7 +3863,7 @@ class Scheduler:
|
|||
for name in sorted(
|
||||
self.buffer_names_to_free
|
||||
- V.graph.removed_buffers
|
||||
- V.graph.wrapper_code.freed
|
||||
- V.graph.wrapper_code.freed # type: ignore[has-type]
|
||||
):
|
||||
if name in self.name_to_buf:
|
||||
buf = self.name_to_buf[name]
|
||||
|
|
@ -4122,7 +4122,7 @@ class Scheduler:
|
|||
V.graph.wrapper_code.define_subgraph_launcher_fn(partition_code)
|
||||
|
||||
V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature)
|
||||
V.graph.wrapper_code.allocated.update(
|
||||
V.graph.wrapper_code.allocated.update( # type: ignore[has-type]
|
||||
[node.get_name() for node in signature.output_nodes]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from torch._inductor.utils import clear_on_fresh_inductor_cache
|
|||
from torch.utils._filelock import FileLock
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ..utils._sympy.functions import CeilDiv
|
||||
from . import config, ir
|
||||
from .autotune_process import (
|
||||
TensorMeta,
|
||||
|
|
@ -54,12 +55,15 @@ from .codegen.triton import (
|
|||
TritonScheduling,
|
||||
)
|
||||
from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta
|
||||
from .codegen.wrapper import pexpr
|
||||
from .exc import CUDACompileError
|
||||
from .ir import ChoiceCaller, PrimitiveInfoType
|
||||
from .ops_handler import StoreMode
|
||||
from .runtime.benchmarking import benchmarker
|
||||
from .runtime.hints import DeviceProperties
|
||||
from .runtime.triton_heuristics import FixedGrid
|
||||
from .utils import (
|
||||
ceildiv,
|
||||
FakeIndentedBuffer,
|
||||
get_dtype_size,
|
||||
is_gpu,
|
||||
|
|
@ -442,6 +446,7 @@ class TritonTemplateKernel(TritonKernel):
|
|||
inductor_meta = {
|
||||
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
|
||||
**TritonKernel.inductor_meta_common(),
|
||||
**FixedGrid.setup_grid_as_args(),
|
||||
}
|
||||
if config.profile_bandwidth or config.benchmark_kernel:
|
||||
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
||||
|
|
@ -988,46 +993,44 @@ class TritonTemplateKernel(TritonKernel):
|
|||
wrapper = V.graph.wrapper_code
|
||||
_, call_args, _, arg_types = self.args.python_argdefs()
|
||||
|
||||
# Handle workspace allocation
|
||||
if self.workspace_arg is not None:
|
||||
wrapper.generate_workspace_allocation(self.workspace_arg)
|
||||
|
||||
if V.graph.cpp_wrapper:
|
||||
# In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
|
||||
# if any dynamic dimension is involved. We rely on the Python version
|
||||
# of the grid function to generate those grid configs, which may contain
|
||||
# symbolic values. The wrapper will use cexpr to print out C++ code
|
||||
# appropriately for the grid configs.
|
||||
grid = self.call_sizes + [self.meta]
|
||||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
grid=self.grid_fn(*grid),
|
||||
# Calling self.grid_fn(*grid) already computes grid as a tuple,
|
||||
# so we need to explicitly set grid_fn as empty here. Otherwise, the
|
||||
# generated wrapper code will wrap the tuple as grid(tuple), which can
|
||||
# cause incorrect grid computation in some corner cases.
|
||||
grid_fn="",
|
||||
arg_types=arg_types,
|
||||
triton_meta=self.triton_meta,
|
||||
)
|
||||
grid_args = ()
|
||||
if isinstance(self.grid_fn, SymbolicGridFn):
|
||||
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
|
||||
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
|
||||
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
|
||||
else:
|
||||
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
|
||||
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
|
||||
meta = wrapper.add_meta_once(self.meta)
|
||||
grid = self.call_sizes + [meta]
|
||||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
grid=grid,
|
||||
grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
|
||||
arg_types=arg_types,
|
||||
triton_meta=self.triton_meta,
|
||||
gpu="cpu" not in V.graph.device_types,
|
||||
fn_name = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}"
|
||||
call_args.append(
|
||||
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
|
||||
)
|
||||
arg_types.append(None)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
call_args.extend(grid_args)
|
||||
arg_types.extend(map(type, grid_args))
|
||||
|
||||
if self.workspace_arg is not None:
|
||||
wrapper.generate_workspace_allocation(self.workspace_arg)
|
||||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
arg_types=arg_types,
|
||||
triton_meta=self.triton_meta,
|
||||
triton=True,
|
||||
)
|
||||
if self.workspace_arg is not None:
|
||||
wrapper.generate_workspace_deallocation(self.workspace_arg)
|
||||
|
||||
def kernel_benchmark_extra_args(self) -> list[str]:
|
||||
return [
|
||||
str(x)
|
||||
for x in self.grid_fn(
|
||||
*V.graph.sizevars.size_hints(self.call_sizes), self.meta
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _jinja2_env():
|
||||
|
|
@ -1210,8 +1213,7 @@ class TritonTemplate(KernelTemplate):
|
|||
module_path=mod.__file__,
|
||||
module_cache_key=mod.key,
|
||||
kernel_name=kernel_name,
|
||||
grid=grid,
|
||||
extra_args=extra_args,
|
||||
extra_args=[*extra_args, *grid],
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
|
||||
|
|
@ -1328,7 +1330,6 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
|
|||
self.log_info.update(
|
||||
{
|
||||
"backend": "Triton",
|
||||
"grid": str(self.bmreq.grid),
|
||||
"num_stages": self.bmreq.num_stages,
|
||||
"num_warps": self.bmreq.num_warps,
|
||||
}
|
||||
|
|
@ -2352,5 +2353,35 @@ def realize_inputs(*args):
|
|||
return [realize_inputs(x) for x in args]
|
||||
|
||||
|
||||
class SymbolicGridFn:
|
||||
"""
|
||||
Wrapper around a grid function that allows either int or sympy inputs.
|
||||
|
||||
@SymbolicGridFn
|
||||
def grid(x, meta, *, cdiv):
|
||||
return cdiv(x, meta["BLOCK_X"])
|
||||
"""
|
||||
|
||||
def __init__(self, fn: Callable[..., tuple[Any, Any, Any]]):
|
||||
self.fn = fn
|
||||
self.kwargs_int = {}
|
||||
self.kwargs_sym = {}
|
||||
params = inspect.signature(fn).parameters
|
||||
for name, fn_sym, fn_int in [
|
||||
("cdiv", CeilDiv, ceildiv),
|
||||
("min", sympy.Min, min),
|
||||
("max", sympy.Max, max),
|
||||
]:
|
||||
if name in params:
|
||||
self.kwargs_int[name] = fn_int
|
||||
self.kwargs_sym[name] = fn_sym
|
||||
|
||||
def __call__(self, *args, **kwargs) -> tuple[int, int, int]:
|
||||
return self.fn(*args, **kwargs, **self.kwargs_int)
|
||||
|
||||
def sympy_call(self, *args, **kwargs):
|
||||
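# same wrapped fn, but cdiv/min/max are bound to their sympy counterparts so the
# returned grid may contain symbolic expressions (required for the cpp_wrapper path)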
return self.fn(*args, **kwargs, **self.kwargs_sym)
|
||||
|
||||
|
||||
# ensure lowering is imported so that `extern_kernels.*` is populated
|
||||
from . import lowering # noqa: F401
|
||||
|
|
|
|||
|
|
@ -45,23 +45,23 @@ def rename_kernels(source_code: str) -> str:
|
|||
|
||||
|
||||
def merge_params(original_params: list[str], new_params: list[str]) -> list[str]:
|
||||
assert len(new_params) >= len(original_params)
|
||||
for idx in range(len(new_params)):
|
||||
if new_params[idx] == "T":
|
||||
new_params[idx] = original_params[idx]
|
||||
return new_params
|
||||
|
||||
|
||||
def add_launch_params(original: str, kernel_to_params: dict[str, str]) -> str:
|
||||
def add_launch_params(
|
||||
original: str, kernel_to_params: dict[str, tuple[str, str]]
|
||||
) -> str:
|
||||
# Regex to match the function call in the original string
|
||||
pattern = r"(\w+)\.run\((.*), grid=(.*\)), [^)]*\)"
|
||||
pattern = r"(\w+)\.run\((.*)\)"
|
||||
|
||||
def replace(match) -> str:
|
||||
# Extract parts from the regex match
|
||||
func_name = match.group(1)
|
||||
params = match.group(2)
|
||||
grid = match.group(3)
|
||||
new_params = kernel_to_params[func_name]
|
||||
new_params, grid = kernel_to_params[func_name]
|
||||
new_params = merge_params(params.split(", "), new_params.split(", "))
|
||||
|
||||
# Format the new function call
|
||||
|
|
@ -103,9 +103,8 @@ def process_file(input_filename: str, output_filename: str) -> str:
|
|||
launch_params_meta = f.readlines()
|
||||
|
||||
split_params = [i.split("|") for i in launch_params_meta]
|
||||
strip_params = [[a.strip(), b.strip()] for a, b in split_params]
|
||||
kernel_to_args: dict[str, str] = dict(strip_params)
|
||||
transformed_code = add_launch_params(transformed_code, kernel_to_args)
|
||||
kernel_args_grid = {a.strip(): (b.strip(), c.strip()) for a, b, c in split_params}
|
||||
transformed_code = add_launch_params(transformed_code, kernel_args_grid)
|
||||
|
||||
with open(output_filename, "w") as file:
|
||||
file.write(transformed_code)
|
||||
|
|
|
|||