diff --git a/test/inductor/test_cutedsl_template.py b/test/inductor/test_cutedsl_template.py new file mode 100644 index 00000000000..4e9fcd13287 --- /dev/null +++ b/test/inductor/test_cutedsl_template.py @@ -0,0 +1,319 @@ +# Owner(s): ["module: inductor"] +import unittest +from unittest.mock import MagicMock, patch + +import torch +from torch._inductor.test_case import TestCase + + +try: + import cutlass # noqa: F401 + import cutlass.cute as cute # noqa: F401 + + HAS_CUTLASS = True +except ImportError: + HAS_CUTLASS = False + +if HAS_CUTLASS: + from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel + from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate + from torch._inductor.select_algorithm import PartialRender + +CUTEDSL_ADD_TEMPLATE = r""" +{{gen_defines()}} + +@cute.kernel +def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + + thread_idx = bidx * bdim + tidx + m, n = gA.shape + + if thread_idx < m * n: + mi = thread_idx // n + ni = thread_idx % n + + if mi < m and ni < n: + gC[mi, ni] = gA[mi, ni] + gB[mi, ni] + +@cute.jit +def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream): + {{gen_defines()}} + m, n = mA.shape + total_threads = m * n + num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK + + kernel = {{kernel_name}}_kernel(mA, mB, mC) + kernel.launch( + grid=[num_blocks, 1, 1], + block=[THREADS_PER_BLOCK, 1, 1], + stream=stream + ) + +{{def_kernel("input_a", "input_b", "output_c")}} + cute_a = from_dlpack(input_a) + cute_b = from_dlpack(input_b) + cute_c = from_dlpack(output_c) + + {{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream)) + return output_c +""" + + +@unittest.skipUnless(HAS_CUTLASS, "requires cutlass") +class TestCuteDSLTemplate(TestCase): + """Test cases for CuteDSL template functionality.""" + + def test_gen_imports(self): + kernel = CuteDSLTemplateKernel( + kernel_name="test_kernel", + input_nodes=[], + output_node=None, + ) + + imports = kernel.gen_imports() + + self.assertIn("import torch", imports) + self.assertIn("import cutlass", imports) + self.assertIn("import cutlass.cute as cute", imports) + self.assertIn("from cutlass.cute.runtime import from_dlpack", imports) + self.assertIsInstance(imports, str) + + lines = imports.strip().split("\n") + self.assertEqual(len(lines), 5) + + def test_render_includes_imports(self): + template_source = """@cute.kernel +def {{kernel_name}}_kernel(): + pass + +{{def_kernel("input", "output")}} + return output""" + + mock_template = MagicMock() + mock_template.render = MagicMock(return_value=template_source) + + kernel = CuteDSLTemplateKernel( + kernel_name="test_kernel", + input_nodes=[], + output_node=None, + ) + + result = kernel.render(mock_template) + self.assertIsInstance(result, PartialRender) + + rendered_code = result._code + + # The imports might have leading whitespace, so strip it + rendered_code_stripped = rendered_code.lstrip() + + self.assertTrue( + rendered_code_stripped.startswith("import torch"), + f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}", + ) + self.assertIn("import cutlass", rendered_code) + self.assertIn("import cutlass.cute as cute", rendered_code) + self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code) + self.assertIn("@cute.kernel", rendered_code) + + def test_template_env_contains_hooks(self): + 
kernel = CuteDSLTemplateKernel( + kernel_name="test_kernel", + input_nodes=[], + output_node=None, + ) + + captured_env = {} + + def mock_render(**kwargs): + captured_env.update(kwargs) + return "rendered" + + mock_template = MagicMock() + mock_template.render = mock_render + + kernel.render(mock_template) + + self.assertIn("def_kernel", captured_env) + self.assertIn("kernel_name", captured_env) + self.assertTrue(callable(captured_env["def_kernel"])) + + def test_multiple_templates_unique_names(self): + # Clean registry first + test_name = f"unique_test_{id(self)}" + if test_name in CuteDSLTemplate.all_templates: + del CuteDSLTemplate.all_templates[test_name] + + _ = CuteDSLTemplate( + name=test_name, + source="template1", + ) + + with self.assertRaises(AssertionError): + _ = CuteDSLTemplate( + name=test_name, + source="template2", + ) + + def test_indented_buffer_usage(self): + kernel = CuteDSLTemplateKernel( + kernel_name="test_kernel", + input_nodes=[], + output_node=None, + ) + + imports = kernel.gen_imports() + + lines = imports.strip().split("\n") + for line in lines: + if line: + self.assertFalse( + line.startswith(" "), f"Line should not be indented: '{line}'" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_cutedsl_add_e2e(self): + """End-to-end test with CuteDSL template including code generation verification.""" + from torch._inductor.ir import TensorBox + from torch._inductor.lowering import lowerings + from torch._inductor.utils import run_and_get_code + + template = CuteDSLTemplate( + name="test_add_e2e", + source=CUTEDSL_ADD_TEMPLATE, + ) + + def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox: + choices = [] + error = template.maybe_append_choice( + choices, + input_nodes=[a, b], + layout=a.get_layout(), + THREADS_PER_BLOCK=256, + ) + + if error or not choices: + default_lowering = lowerings[torch.ops.aten.add.Tensor] + return default_lowering(a, b) + + # Use the single choice directly (no autotuning) + return choices[0].output_node() + + with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}): + # Test function + def test_add(x, y): + return x + y + + device = "cuda" + x = torch.randn(128, 4, device=device, dtype=torch.float32) + y = torch.randn(128, 4, device=device, dtype=torch.float32) + + # Compile and get generated code + compiled_fn = torch.compile(test_add, backend="inductor") + result, (code,) = run_and_get_code(compiled_fn, x, y) + + # Verify CuteDSL code is present + self.assertIn( + "cute", code.lower(), "CuteDSL code should be in generated code" + ) + # Verify parameter generation worked + self.assertIn( + "THREADS_PER_BLOCK", code, "Parameter should be in generated code" + ) + + # Verify correctness + expected = x + y + self.assertTrue(torch.allclose(result, expected, atol=1e-5)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_cutedsl_add_e2e_autotune(self): + """E2E test with multiple CuteDSL template variants for autotuning.""" + from torch._inductor.ir import TensorBox + from torch._inductor.lowering import lowerings + from torch._inductor.select_algorithm import autotune_select_algorithm + + template = CuteDSLTemplate( + name="test_add_autotune", + source=CUTEDSL_ADD_TEMPLATE, + ) + + def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox: + choices = [] + + # Add multiple variants with different thread counts for autotuning + thread_variants = [128, 256, 512] + for threads in thread_variants: + error = template.maybe_append_choice( + 
choices, + input_nodes=[a, b], + layout=a.get_layout(), + THREADS_PER_BLOCK=threads, + ) + if error: + # Skip this variant if it fails + continue + + if not choices: + default_lowering = lowerings[torch.ops.aten.add.Tensor] + return default_lowering(a, b) + + # Use autotuning to select the best variant + return autotune_select_algorithm( + "cutedsl_add_autotune", + choices, + [a, b], + a.get_layout(), + ) + + with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}): + # Test function + def test_add(x, y): + return x + y + + device = "cuda" + x = torch.randn(128, 128, device=device, dtype=torch.float32) + y = torch.randn(128, 128, device=device, dtype=torch.float32) + + # Compile and run + compiled_fn = torch.compile(test_add, backend="inductor") + result = compiled_fn(x, y) + + # Verify correctness + expected = x + y + self.assertTrue(torch.allclose(result, expected, atol=1e-5)) + + def test_gen_defines(self): + """Test that gen_defines correctly generates CuteDSL parameter definitions.""" + kernel = CuteDSLTemplateKernel( + kernel_name="test_kernel", + input_nodes=[], + output_node=None, + ) + + # Test integer parameters + params = kernel.gen_defines( + THREADS_PER_BLOCK=256, + BLOCK_SIZE=128, + ENABLE_FEATURE=True, + ) + + expected_lines = [ + "THREADS_PER_BLOCK: cutlass.Constexpr = 256", + "BLOCK_SIZE: cutlass.Constexpr = 128", + "ENABLE_FEATURE: cutlass.Constexpr = True", + ] + + for expected_line in expected_lines: + self.assertIn(expected_line, params) + + # Test float parameters + params_float = kernel.gen_defines(SCALE_FACTOR=1.5) + self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float) + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + run_tests() diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index b2383830692..09bf4b1c9e2 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -569,6 +569,45 @@ class AsyncCompile: ) return LambdaFuture(get_result) + def cutedsl(self, kernel_name: str, source_code: str): + """ + Compile CuteDSL (CUTLASS Python DSL) kernels. + + Args: + kernel_name: Name of the kernel to be defined + source_code: Source code of the CuteDSL kernel, as a string + + Note: + CuteDSL currently requires source files to do its compilation, there we + use the PyCodeCache to write the source code to a file and load it. + """ + from torch._inductor.codegen.cutedsl.cutedsl_kernel import ( + CuteDSLKernelWrapper, + MAIN_SUFFIX, + ) + + kernel_code_log.info("CuteDSL Kernel:\n%s", source_code) + + def task(): + key, path = torch._inductor.codecache.PyCodeCache.write(source_code) + mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) + + # Find our special entry point named function + main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find CuteDSL main kernel function '{main_func_name}'. 
Available callables: {available}" + ) + + return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path) + + if get_compile_threads() <= 1: + return task() + else: + future = self.submit(task) + return LambdaFuture(lambda: future.result()) + def wait(self, scope: dict[str, Any]) -> None: if get_compile_threads() > 1: with dynamo_timed( diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index d662b787d64..a504b54f132 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -44,7 +44,7 @@ from torch.utils._ordered_set import OrderedSet if TYPE_CHECKING: from types import ModuleType - from torch._inductor.select_algorithm import TritonTemplateCaller + from torch._inductor.select_algorithm import PartialRender, TritonTemplateCaller from . import config from .runtime.benchmarking import benchmarker @@ -876,6 +876,55 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest): return f"{self.kernel_name=}" +class CuteDSLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + """Benchmark request for CuteDSL (CUTLASS Python DSL) kernels.""" + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: tuple[Any, ...], + source_code: PartialRender, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + + finalized_code = source_code.finalize_all() + self.module_cache_key, self.module_path = PyCodeCache.write(finalized_code) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + """ + Create a function to run the CuteDSL kernel with the given input and output tensors. + Similar to TritonBenchmarkRequest.make_run_fn but for CuteDSL kernels. + """ + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + + # Logic replicated async_compile + from .codegen.cutedsl.cutedsl_kernel import MAIN_SUFFIX + + main_func_name = f"{self.kernel_name}_{MAIN_SUFFIX}" + + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find CuteDSL main kernel function '{main_func_name}'. 
Available callables: {available}" + ) + + kernel_func = getattr(mod, main_func_name) + + def run_kernel(): + device_interface = get_interface_for_device("cuda") + stream = device_interface.get_raw_stream(out.device.index) + return kernel_func(*input_tensors, out, stream=stream) + + return run_kernel + + def cleanup_run_fn(self) -> None: + """Clean up any resources used by the kernel.""" + + @functools.cache def get_tuning_process_pool() -> TuningProcessPool: pool = TuningProcessPool() diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 0aee8760282..cb497284d52 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -11,6 +11,7 @@ from ..scheduler import ( SchedulerNode, ) from .cuda.cuda_cpp_scheduling import CUDACPPScheduling +from .cutedsl.cutedsl_scheduling import CuteDSLScheduling from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling from .triton import TritonScheduling @@ -44,6 +45,7 @@ class CUDACombinedScheduling(BaseScheduling): self._triton_scheduling = TritonScheduling(scheduler) self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) + self._cutedsl_scheduling = CuteDSLScheduling(scheduler) def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: return self._triton_scheduling.get_backend_features(device) @@ -53,6 +55,8 @@ class CUDACombinedScheduling(BaseScheduling): return self._cuda_cpp_scheduling if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): return self._rocm_cpp_scheduling + if self._cutedsl_scheduling.is_cutedsl_template(node): + return self._cutedsl_scheduling return self._triton_scheduling def can_fuse_vertical( @@ -64,6 +68,11 @@ class CUDACombinedScheduling(BaseScheduling): node1 ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2): return False + # CuteDSL doesn't support vertical fusion currently + elif self._cutedsl_scheduling.is_cutedsl_template( + node1 + ) or self._cutedsl_scheduling.is_cutedsl_template(node2): + return False return self._triton_scheduling.can_fuse_vertical(node1, node2) def can_fuse_horizontal( @@ -74,6 +83,10 @@ class CUDACombinedScheduling(BaseScheduling): return self._cuda_cpp_scheduling.can_fuse_horizontal( node1, node2 ) # always False at the moment + if self._cutedsl_scheduling.is_cutedsl_template(node): + return self._cutedsl_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment return self._triton_scheduling.can_fuse_horizontal(node1, node2) def group_fn( @@ -98,6 +111,13 @@ class CUDACombinedScheduling(BaseScheduling): return self._rocm_cpp_scheduling.codegen_template( template_node, epilogue_nodes, prologue_nodes ) + elif self._cutedsl_scheduling.is_cutedsl_template(template_node): + # TODO remove this when we add epilogue support + assert not epilogue_nodes + assert not prologue_nodes + return self._cutedsl_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) else: return self._triton_scheduling.codegen_template( template_node, epilogue_nodes, prologue_nodes diff --git a/torch/_inductor/codegen/cutedsl/README.md b/torch/_inductor/codegen/cutedsl/README.md new file mode 100644 index 00000000000..3b0deedafc3 --- /dev/null +++ b/torch/_inductor/codegen/cutedsl/README.md @@ -0,0 +1,101 @@ +# CuteDSL Template System + +## Quick Start + +Writing a CuteDSL template: + +```python +from torch._inductor.codegen.cutedsl import CuteDSLTemplate + 
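+# {{kernel_name}} and {{def_kernel(...)}} below are Jinja hooks expanded by the
+# template kernel at render time (see "Writing Templates" below).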
+template_source = """ +@cute.kernel +def {{kernel_name}}_kernel(A, B, C): + # Your CUTLASS kernel logic here + pass + +{{def_kernel("A", "B", "C")}} + # Call the kernel + {{kernel_name}}_kernel(A, B, C) + return C +""" + +my_template = CuteDSLTemplate( + name="my_gemm", + source=template_source, +) +``` + +## Architecture + +- **[CuteDSLTemplate](cutedsl_template.py#L39)**: Template definition and registration. Generates ChoiceCallers for autotuning. +- **[CuteDSLTemplateKernel](cutedsl_kernel.py#L61)**: Handles code generation, provides template hooks (`def_kernel`), manages args. +- **[CuteDSLScheduling](cutedsl_scheduling.py#L28)**: Integrates with Inductor's scheduler, handles kernel compilation via [`async_compile.cutedsl()`](../../async_compile.py#L756). +- **[CuteDSLTemplateBuffer](../../ir.py)**: IR node representing a CuteDSL template operation in the graph. + +### Compilation Process + +CuteDSL requires source files for compilation (cannot compile from strings directly). The process: + +1. **[CuteDSLScheduling](cutedsl_scheduling.py#L59)** generates the kernel code string and calls [`async_compile.cutedsl()`](../../async_compile.py#L756) +2. **[async_compile.cutedsl()](../../async_compile.py#L756)** uses [`PyCodeCache.write()`](../../codecache.py) to write source to a temporary `.py` file +3. **[PyCodeCache](../../codecache.py)** loads the module from disk, enabling CUTLASS compilation +4. The compiled kernel is wrapped in **[CuteDSLKernelWrapper](cutedsl_kernel.py#L22)** to provide a `.run()` interface +5. The generated Python file is cached via PyCodeCache, but CUTLASS compilation runs every time (no kernel-level caching yet) + +**Debug tip**: Use `TORCH_LOGS="kernel_code"` to see the generated kernel source and file path during compilation. + +## Writing Templates + +Templates use Jinja2 syntax with these available hooks: + +- `{{kernel_name}}` - Unique kernel identifier +- `{{def_kernel(args...)}}` - Generates kernel function signature and argument handling +- `{{input_nodes}}` - List of input buffers +- `{{output_node}}` - Output buffer +- `{{gen_defines()}}` - Generates autotunable parameter definitions with proper CuteDSL typing + +## Autotunable Parameters + +CuteDSL templates support autotunable parameters similar to Triton's `tl.constexpr` system: + +```python +template_source = r""" +{{gen_defines()}} + +@cute.kernel +def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): + threads_per_block = THREADS_PER_BLOCK # Uses autotuned value + block_size = BLOCK_SIZE + # ... kernel implementation +""" + +# Pass parameters when generating template choices +template.maybe_append_choice( + choices, + input_nodes=[a, b], + layout=layout, + THREADS_PER_BLOCK=256, # cutlass.Constexpr = 256 + BLOCK_SIZE=128, # cutlass.Constexpr = 128 + SCALE_FACTOR=1.5, # cutlass.Constexpr = 1.5 +) +``` + +Templates must: +1. Define a `@cute.kernel` decorated function +2. Use `{{def_kernel()}}` to create the entry point +3. Return the output tensor +4. Use `{{gen_defines()}}` for autotunable parameters + +See [test_cutedsl_template.py](../../../../test/inductor/test_cutedsl_template.py) for complete examples. 
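+
+Wiring a template into a lowering follows the pattern below. This is a minimal sketch
+adapted from the e2e test in `test_cutedsl_template.py`; `add_template` (a previously
+constructed `CuteDSLTemplate`), the `THREADS_PER_BLOCK` parameter, and the
+`aten.add.Tensor` target are illustrative:
+
+```python
+import torch
+from torch._inductor.ir import TensorBox
+from torch._inductor.lowering import lowerings
+
+
+# `add_template` is assumed to be a CuteDSLTemplate registered earlier (see Quick Start).
+def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
+    choices = []
+    error = add_template.maybe_append_choice(
+        choices,
+        input_nodes=[a, b],
+        layout=a.get_layout(),
+        THREADS_PER_BLOCK=256,  # forwarded to {{gen_defines()}} in the template
+    )
+    if error or not choices:
+        # Fall back to the stock lowering if choice generation failed
+        return lowerings[torch.ops.aten.add.Tensor](a, b)
+    # Single choice: use it directly; pass several choices to
+    # autotune_select_algorithm to autotune between variants instead.
+    return choices[0].output_node()
+```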
+ +## Current Limitations / TODOs + +- **No fusion support**: `can_fuse_vertical` and `can_fuse_horizontal` return False +- **Subgraph management**: Bodies and masks not fully implemented +- **File-based compilation**: Requires writing to disk (uses PyCodeCache) +- **Missing epilogue/prologue**: No support for fused operations yet +- **Fixed kernel suffix**: Uses hardcoded "_main" suffix +- **No CUTLASS kernel caching**: Only PyCodeCache works; CUTLASS compilation runs every time (major perf issue) + + +Note: Requires CUTLASS Python package (`pip install nvidia-cutlass`) \ No newline at end of file diff --git a/torch/_inductor/codegen/cutedsl/__init__.py b/torch/_inductor/codegen/cutedsl/__init__.py new file mode 100644 index 00000000000..f12fa963fd6 --- /dev/null +++ b/torch/_inductor/codegen/cutedsl/__init__.py @@ -0,0 +1,8 @@ +# mypy: allow-untyped-defs +from .cutedsl_template import CuteDSLTemplate, CuteDSLTemplateCaller + + +__all__ = [ + "CuteDSLTemplate", + "CuteDSLTemplateCaller", +] diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py new file mode 100644 index 00000000000..ca6af6690e6 --- /dev/null +++ b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -0,0 +1,222 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import logging +from typing import Any, Callable, Optional + +import torch +from torch._inductor.codegen.common import IndentedBuffer, Kernel +from torch._inductor.ir import Buffer +from torch._inductor.select_algorithm import PartialRender +from torch._inductor.utils import OrderedSet +from torch._inductor.virtualized import V + + +# TODO setting the 'main' kernel w/ this suffix. We have 3 should probably just auto generate this +MAIN_SUFFIX = "main" + +log = logging.getLogger(__name__) +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class CuteDSLKernelWrapper: + """Wrapper to provide .run() interface for CuteDSL kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("CuteDSL kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the CuteDSL kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +@dataclasses.dataclass +class CuteDSLSubgraphInfo: + """Minimal subgraph info for CuteDSL kernels.""" + + body: IndentedBuffer + template_mask: Optional[str] = None + template_out: Optional[str] = None + + def to_dict(self): + return { + field.name: getattr(self, field.name) for field in dataclasses.fields(self) + } + + +class CuteDSLTemplateKernel(Kernel): + """ + Template kernel implementation for CuteDSL (CUTLASS Python DSL). + Handles code generation and argument management for CuteDSL CUDA kernels. + Provides CuteDSL-specific functionality for tensor conversion and kernel configuration. 
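+    Rendering produces a PartialRender whose hooks (e.g. the def_kernel entry point)
+    are finalized before the source is written out and compiled.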
+ """ + + def __init__( + self, + kernel_name: str, + input_nodes: list[Buffer], + output_node: Buffer, + ) -> None: + # Call parent Kernel constructor + super().__init__() + self.kernel_name = kernel_name + self.input_nodes = input_nodes + self.output_node = output_node + + # TODO Subgraph management for template processing + self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {} + + # Template attributes + self.body: IndentedBuffer = IndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None + self.template_indices: Optional[list[Any]] = None + self.render_hooks: dict[str, Any] = {} + + # TODO Additional attributes needed by template system + self.prologue_fused_inputs: OrderedSet[str] = OrderedSet() + self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet() + self.named_input_nodes: dict[str, Buffer] = {} + + # Create named input nodes mapping + for i, input_node in enumerate(input_nodes): + node_name = getattr(input_node, "name", f"input_{i}") + self.named_input_nodes[node_name] = input_node + + def gen_imports(self) -> str: + """Generate common imports for CuteDSL templates.""" + imports = IndentedBuffer() + imports.splice( + """ + import torch + import cutlass + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + import cuda.bindings.driver as cuda + """ + ) + return imports.getvalue() + + def gen_defines(self, **kwargs) -> str: + """Generate CuteDSL parameter definitions from kwargs, similar to Triton's gen_defines.""" + params = IndentedBuffer() + for name, val in kwargs.items(): + params.writeline(f"{name}: cutlass.Constexpr = {val}") + return params.getvalue() + + def render(self, template, **kwargs): + """Render the kernel using the template, returning PartialRender object with hooks.""" + # Available {{}} hooks for jinja rendering + template_env = { + "def_kernel": self.def_kernel, + "gen_defines": lambda: self.gen_defines(**kwargs), + } + + # Render the template with the environment and provided kwargs + rendered_code = template.render( + kernel_name=self.kernel_name, + input_nodes=self.input_nodes, + output_node=self.output_node, + **template_env, + **kwargs, + ) + + # Always prepend the common imports + imports = self.gen_imports() + full_code = imports + rendered_code + + return PartialRender(full_code, self.render_hooks) + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + """Set the active subgraph body for template processing.""" + assert all( + hasattr(self, field.name) + for field in dataclasses.fields(CuteDSLSubgraphInfo) + ) + old_state = { + key.name: getattr(self, key.name) + for key in dataclasses.fields(CuteDSLSubgraphInfo) + } + + if body_name not in self.subgraph_bodies: + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + body=IndentedBuffer(), + template_mask=None, + template_out=None, + ) + + subgraph = self.subgraph_bodies[body_name] + for key, value in subgraph.to_dict().items(): + setattr(self, key, value) + + try: + yield + finally: + # Save current state back to subgraph + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + **{ + key.name: getattr(self, key.name) + for key in dataclasses.fields(CuteDSLSubgraphInfo) + } + ) + # Restore old state + for key, value in old_state.items(): + setattr(self, key, value) + + @contextlib.contextmanager + def create_subgraph_body(self, body_name: str): + """Create a new subgraph body for template processing.""" + assert body_name not in self.subgraph_bodies, ( + f"Subgraph body '{body_name}' 
already exists" + ) + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + body=IndentedBuffer(), + template_mask=None, + template_out=None, + ) + with self.set_subgraph_body(body_name): + yield + + def def_kernel(self, *argnames): + """Define kernel function signature for CuteDSL templates.""" + # Populate all the kernel args + for i, input_node in enumerate(self.input_nodes): + self.args.input(input_node.get_name()) + + if self.output_node: + self.args.output(self.output_node.get_name()) + + def hook(): + code = IndentedBuffer() + code.writeline(f"# Kernel function signature: {self.kernel_name}") + params = list(argnames) + ["stream"] + code.writeline( + f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(params)}):" + ) + return code.getvalue() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def call_kernel(self, name: str, node=None): + """Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel.""" + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + # TODO triton should really be swapped w/ `python` + wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types) diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py b/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py new file mode 100644 index 00000000000..427b6fe5f1d --- /dev/null +++ b/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ... import config +from ...codecache import code_hash, get_path +from ...ir import CuteDSLTemplateBuffer +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, +) +from ...select_algorithm import PartialRender +from ...utils import get_fused_kernel_name, get_kernel_metadata +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class CuteDSLScheduling(BaseScheduling): + """ + Scheduling implementation for CuteDSL (CUTLASS Python DSL) kernels. + This class is intended to be used in combination with other schedulers, + and delegated to by CUDACombinedScheduling. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + @staticmethod + def is_cutedsl_template(node: BaseSchedulerNode) -> bool: + """Check if a node is a CuteDSL template.""" + return isinstance(node, SchedulerNode) and isinstance( + node.node, CuteDSLTemplateBuffer + ) + + def is_cutedsl_fused_template(self, node: BaseSchedulerNode) -> bool: + """Check if a node is a fused CuteDSL template.""" + return isinstance(node, FusedSchedulerNode) and self.is_cutedsl_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + TODO CuteDSL doesn't support vertical fusion yet. + This could be extended in the future for epilogue fusion. + """ + return False + + def define_kernel(self, src_code_str: str, node_schedule) -> str: + """Produce the kernel string + Args: + src_code_str: The finalized kernel code string + node_schedule: List of nodes in the schedule + + Note: + This is a little weird since async_compile.cutedsl() has to write the string to + a file in order to cute compile it. Feels bad to have two... 
+ """ + wrapper = V.graph.wrapper_code + + # Use the string as the key for caching + if src_code_str in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code_str] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + kernel_hash = hashlib.sha256(src_code_str.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"cutedsl_{kernel_hash}" + else: + kernel_name = f"cutedsl_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code_str] = kernel_name + src_code_str = src_code_str.replace( + str(Placeholder.KERNEL_NAME), kernel_name + ) + + _, _, kernel_path = get_path(code_hash(src_code_str), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.cutedsl({kernel_name!r}, r'''") + compile_wrapper.splice(src_code_str, strip=True) + compile_wrapper.writeline("''')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CuteDSL template. Currently doesn't support fusion. + """ + assert self.is_cutedsl_template(template_node), ( + "Template node passed to CuteDSLScheduling.codegen_template must be a " + "SchedulerNode that wraps a CuteDSLTemplateBuffer" + ) + # TODO remove when supported + assert not epilogue_nodes, "CuteDSL doesn't support epilogue fusion yet" + assert not prologue_nodes, "CuteDSL doesn't support prologue fusion yet" + + template_node = cast(SchedulerNode, template_node) + ctb: CuteDSLTemplateBuffer = cast(CuteDSLTemplateBuffer, template_node.node) + + kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] + template_node.mark_run() + src_code = render() + # Finalize PartialRender if needed + if isinstance(src_code, PartialRender): + src_code_str = src_code.finalize_all() + else: + src_code_str = src_code + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code_str, node_schedule) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_template.py b/torch/_inductor/codegen/cutedsl/cutedsl_template.py new file mode 100644 index 00000000000..1ce0528348c --- /dev/null +++ b/torch/_inductor/codegen/cutedsl/cutedsl_template.py @@ -0,0 +1,178 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from typing import Any, Optional, Union + +from torch._inductor.ir import ShapeAsConstantBuffer +from torch._inductor.utils import Placeholder +from torch._logging import getArtifactLogger + +from ...autotune_process import CuteDSLBenchmarkRequest, TensorMeta +from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, Layout, TensorBox +from ..common import KernelTemplate +from .cutedsl_kernel import CuteDSLTemplateKernel + + +log = getArtifactLogger(__name__, "output_code") + + +class CuteDSLTemplate(KernelTemplate): + """Template for generating CuteDSL (CUTLASS Python DSL) kernels.""" + + kernel_type: type[Any] = CuteDSLTemplateKernel + index_counter = itertools.count() + all_templates: 
dict[str, "CuteDSLTemplate"] = {} + + def __init__( + self, + name: str, + source: str, + subgraph_fn: Optional[Any] = None, + mask_fn: Optional[Any] = None, + ) -> None: + super().__init__(name) + self.source = source + self.subgraph_fn = subgraph_fn + self.mask_fn = mask_fn + self.template = CuteDSLTemplate._template_from_string(source) + assert name not in self.all_templates, f"duplicate template name, {name}" + CuteDSLTemplate.all_templates[name] = self + + @staticmethod + @functools.lru_cache(None) + def _template_from_string(source: str) -> Any: + return KernelTemplate._template_from_string(source) + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + """ + try: + choices.append(self.generate(**kwargs)) + return None + except NotImplementedError as e: + log.debug("CuteDSL template choice generation failed: %s", e) + return e + except Exception as e: + log.debug("CuteDSL template choice generation error: %s", e) + return NotImplementedError(f"CuteDSL template failed: {e}") + + def generate(self, **kwargs: Any) -> ChoiceCaller: + """Generate the CuteDSL kernel caller.""" + input_nodes = kwargs.pop("input_nodes") + layout = kwargs.pop("layout") + + kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}" + + if self.template is None: + raise RuntimeError("Template compilation failed (Jinja2 required)") + + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + + kernel = self.kernel_type( + kernel_name=kernel_name, + input_nodes=input_nodes, + output_node=self.output_node, + ) + + code = kernel.render(self.template, **kwargs) + + log.debug("Generated CuteDSL Code:\n%s", code) + + bmreq = CuteDSLBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=tuple(), + source_code=code, + ) + + def make_kernel_render(out_node, hint_override: Optional[int] = None): + render_kernel = self.kernel_type( + kernel_name=str(Placeholder.KERNEL_NAME), + input_nodes=input_nodes, + output_node=out_node, + ) + + def render(): + return render_kernel.render(self.template, **kwargs) + + return render_kernel, render + + return CuteDSLTemplateCaller( + name=kernel_name, + input_nodes=input_nodes, + layout=layout, + make_kernel_render=make_kernel_render, + bmreq=bmreq, + template=self, + ) + + +class CuteDSLTemplateCaller(ChoiceCaller): + """Caller for CuteDSL templates that integrates with the autotuning system.""" + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Any, + bmreq: CuteDSLBenchmarkRequest, + template: "CuteDSLTemplate", + ): + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + description=f"CuteDSL template {name}", + ) + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + + def __str__(self) -> str: + return f"CuteDSLTemplateCaller({self.name})" + + def benchmark(self, *args, out) -> float: + """Benchmark the kernel execution.""" + return self.bmreq.benchmark(*args, out=out) + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + """Create the output node for this template choice.""" + return TensorBox.create( + CuteDSLTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, 
+ template=self.template, + ) + ) + + def call_name(self) -> str: + """Return the kernel call name.""" + return self.name + + def to_callable(self) -> Any: + """Return callable that can execute this kernel.""" + return self.make_kernel_render + + def hash_key(self) -> str: + """Return unique hash key for this choice.""" + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def info_dict(self) -> dict[str, Any]: + """Return information about this kernel.""" + return { + "name": self.name, + "backend": "CuteDSL", + "template": self.template.name, + } diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0490023584b..e1e2ef23eeb 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5132,6 +5132,37 @@ class CppTemplateBuffer(TemplateBuffer): return super().get_layout() +class CuteDSLTemplateBuffer(TemplateBuffer): + """ + Buffer for CuteDSL (CUTLASS Python DSL) template kernels. + Similar to other template buffers but specialized for CuteDSL operations. + """ + + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[_P, _T], + template: Any, + mutated_inputs: Optional[Iterable[IRNode]] = None, + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.mutated_inputs = mutated_inputs + self.outputs: list[Buffer] = [self] + + if mutated_inputs is not None: + assert isinstance(self.inputs[0], IRNode), type(self.inputs[0]) + device = self.inputs[0].get_device() + self.outputs += [ + MutationOutput(NoneLayout(device=device), buf, self) + for buf in mutated_inputs + ] + + def get_outputs(self) -> list[Buffer]: + return self.outputs + + def is_node_sequence( nodes: Sequence[Union[IRNode, Sequence[IRNode]]], ) -> TypeIs[Sequence[IRNode]]: