[Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458)

Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below.

![wrapper_ir](https://github.com/user-attachments/assets/a61db21b-caf3-45d2-bfdb-91066ae4ba6b)

The IR currently supports the following ops:
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)

There are two main motivations for this refactor:
1. Unlike free-form C++ and and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
 1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
 2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overriden by the various Python and C++ backends to generate the final wrapper code.

The main difficulty in implementing this is that we need to be careful that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines` which mainly contains wrapper IR, but can also contain raw Python or C++ lines in some situations. Then, we convert the wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it's difficult to ensure that code is written to the correct buffer. The easiest solution for this was to implement a context manager overriding the `writeline` method to write to  `self.wrapper_call` after memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass, and `self.wrapper_call` in the second. This obviated the need to pass `code` or `writeline` variables all the way through the call stack, which would have touched most of the existing macros.

# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage.

The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` only contains Wrapper IR lines, and no free-form code. While this would not be true of all programs today, the tests suggests that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

These two goals are only partially realized by this PR. These are several important steps which still undergo direct Python/C++ codegen in core files:
 - User-defined Triton kernels.
 - Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
 -  Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
 -  Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once.Some Python and codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases during `PythonWrapperCodegen`. First, we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions have a lot of call args. It's also possible to clean this up by passing Wrapper IR ops e.g. `KernelCallLine` into helper functions like `_generate_kernel_call_helper`, since they store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison
This commit is contained in:
Blaine Burton Rister 2025-04-12 01:15:15 +00:00 committed by PyTorch MergeBot
parent 575f348965
commit fe7f425de7
10 changed files with 375 additions and 120 deletions

View File

@ -72,7 +72,7 @@ class TestMemoryPlanning(TestCase):
result, code = run_and_get_cpp_code(compiled, *args)
FileCheck().check(
"aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_4, int_array_5, &tmp_tensor_handle_0)"
"aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_2, int_array_3, &tmp_tensor_handle_0)"
).check_next("auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_0);").check(
"auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_1);"
).run(
@ -97,17 +97,17 @@ class TestMemoryPlanning(TestCase):
)
FileCheck().check(
"int64_t int_array_2[] = {24L + align(12L*s77), };"
).check_next("int64_t int_array_3[] = {1L, };").check_next(
"int64_t int_array_0[] = {24L + align(12L*s77), };"
).check_next("int64_t int_array_1[] = {1L, };").check_next(
"AtenTensorHandle pool1_handle;"
).check_next(
"aoti_torch_empty_strided(1, int_array_2, int_array_3,"
"aoti_torch_empty_strided(1, int_array_0, int_array_1,"
).check_next(
"RAIIAtenTensorHandle pool1(pool1_handle);"
).check_next(
"int64_t int_array_4[] = {s77, 3L};"
"int64_t int_array_2[] = {s77, 3L};"
).check_next(
"int64_t int_array_5[] = {3L, 1L};"
"int64_t int_array_3[] = {3L, 1L};"
).check_next(
"AtenTensorHandle tmp_tensor_handle_0;"
).check_next(

View File

@ -102,7 +102,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}"
)
def generate_kernel_call(
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,
@ -113,6 +113,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
"""
@ -908,7 +909,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
self.prefix.splice(aot_mode_decls)
self.prefix.splice(prior)
def define_kernel(
def _define_kernel_helper(
self,
kernel_name: str,
kernel_body: str,
@ -1160,7 +1161,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
if not is_inplace:
self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
def generate_extern_kernel_alloc(self, extern_kernel, args):
def _generate_extern_kernel_alloc_helper(self, extern_kernel, args):
if getattr(extern_kernel, "outputs", None):
# ir.ExternKernelAlloc may have outputs if it returns a tuple
self.generate_c_shim_fallback_kernel(extern_kernel, args)
@ -1209,10 +1210,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
for raii_handle in output_raii_handles:
self.writeline(raii_handle)
def generate_fallback_kernel(self, fallback_kernel, args):
self.generate_c_shim_fallback_kernel(fallback_kernel, args)
def generate_extern_kernel_out(
def _generate_extern_kernel_out_helper(
self,
kernel: str,
out: str,
@ -1652,7 +1650,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));"
)
def codegen_multi_output(self, name, value):
def codegen_multi_output(self, node: ir.MultiOutput):
# in the abi_compatible mode, outputs are retrieved by passing
# output pointers, so we skip its codegen here.
pass

View File

@ -87,7 +87,22 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
numel = buf.get_numel()
self.prefix.writeline(f"assert_numel({name}, {numel});")
def generate_kernel_call(
def generate_extern_kernel_alloc(self, *args, **kwargs):
# Disable stack allocation for extern kernels.
self.allow_stack_allocation = False
super().generate_extern_kernel_alloc(*args, **kwargs)
def generate_extern_kernel_out(self, *args, **kwargs):
# Disable stack allocation for extern kernels.
self.allow_stack_allocation = False
super().generate_extern_kernel_out(*args, **kwargs)
def generate_fallback_kernel(self, *args, **kwargs):
# Disable stack allocation for extern kernels.
self.allow_stack_allocation = False
super().generate_fallback_kernel(*args, **kwargs)
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,
@ -98,6 +113,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
"""

View File

@ -240,7 +240,7 @@ class CppWrapperGpu(CppWrapperCpu):
def write_tma_descriptor_helpers_once(self):
self.header.splice(self.device_codegen.tma_descriptor_helpers())
def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
name = f"stream{device_idx}"
self.writeline(
maybe_hipify_code_wrapper(
@ -294,7 +294,7 @@ class CppWrapperGpu(CppWrapperCpu):
super().codegen_inputs()
def define_kernel(
def _define_kernel_helper(
self,
kernel_name: str,
kernel_body: str,
@ -306,11 +306,11 @@ class CppWrapperGpu(CppWrapperCpu):
self._kernel_name_to_body[kernel_name] = kernel_body
if config.triton.autotune_at_compile_time:
# Call PythonWrapperCodegen to create the autotune code block
PythonWrapperCodegen.define_kernel(
PythonWrapperCodegen._define_kernel_helper(
self, kernel_name, kernel_body, metadata, gpu, cpp_definition
)
else:
return CppWrapperCpu.define_kernel(
return CppWrapperCpu._define_kernel_helper(
self, kernel_name, kernel_body, metadata, gpu, cpp_definition
)
@ -445,7 +445,7 @@ class CppWrapperGpu(CppWrapperCpu):
return ", ".join(new_args)
def generate_kernel_call(
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,
@ -456,6 +456,7 @@ class CppWrapperGpu(CppWrapperCpu):
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
"""
@ -466,7 +467,7 @@ class CppWrapperGpu(CppWrapperCpu):
device = device or V.graph.get_current_device_or_throw()
if device.type == "cpu":
# Even in CppWrapperGpu, we may see cpp kernels
return CppWrapperCpu.generate_kernel_call(
return CppWrapperCpu._generate_kernel_call_helper(
self,
kernel_name,
call_args,
@ -484,7 +485,7 @@ class CppWrapperGpu(CppWrapperCpu):
and kernel_name not in self.kernel_autotune_names
):
# Call PythonWrapperCodegen to create the autotune code block
PythonWrapperCodegen.generate_kernel_call(
PythonWrapperCodegen._generate_kernel_call_helper(
self,
kernel_name,
call_args,
@ -500,7 +501,7 @@ class CppWrapperGpu(CppWrapperCpu):
stream = (
"stream"
if V.graph.aot_mode
else self.write_get_raw_stream(device.index, V.graph)
else self.write_get_raw_stream(device.index, graph_name)
)
if triton:

View File

@ -5,7 +5,7 @@ import functools
import logging
import os
from enum import Enum
from typing import Optional
from typing import Callable, Optional
import torch
from torch import dtype as torch_dtype
@ -57,6 +57,7 @@ class DebugPrinterManager:
self,
debug_printer_level,
use_array_ref: bool,
writeline: Optional[Callable[..., None]] = None,
args_to_print_or_save: Optional[list[str]] = None,
kernel_name: str = "",
kernel=None,

View File

@ -1637,7 +1637,9 @@ class HalideKernel(SIMDKernel):
call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None]
current_device = V.graph.get_current_device_or_throw()
if current_device.type == "cuda":
stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph)
stream_name = wrapper.write_get_raw_stream(
current_device.index, V.graph.name
)
call_args.append(stream_name)
wrapper.generate_kernel_call(
name,

View File

@ -87,6 +87,7 @@ class MultiKernelState:
def __init__(self):
self.subkernel_to_kernel_name = {}
self.kernel_defs = IndentedBuffer()
def define_kernel(self, kernels):
"""
@ -116,7 +117,7 @@ class MultiKernelState:
# the second pass of cpp-wrapper.
return multi_kernel_name
buf = IndentedBuffer()
buf = self.kernel_defs
buf.writeline("")
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
@ -126,12 +127,10 @@ class MultiKernelState:
buf.writeline(f"{name},")
buf.writeline("])")
wrapper = V.graph.wrapper_code
if config.triton.autotune_at_compile_time:
wrapper.kernel_autotune_defs.splice(buf)
wrapper.src_to_kernel["\n".join(kernel_names)] = multi_kernel_name
else:
wrapper.header.splice(buf)
V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
multi_kernel_name
)
return multi_kernel_name

View File

@ -4089,7 +4089,7 @@ class TritonScheduling(SIMDScheduling):
wrapper = V.graph.wrapper_code
origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
if origins:
wrapper.writeline(origins)
wrapper.make_comment(origins)
if config.debug_fusion:
from torch._inductor.scheduler import (
@ -4107,7 +4107,7 @@ class TritonScheduling(SIMDScheduling):
for n in node_schedule
if isinstance(n, BaseSchedulerNode)
]
wrapper.writeline(
wrapper.make_comment(
f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
)

View File

@ -75,6 +75,8 @@ if TYPE_CHECKING:
from ..graph import GraphLowering
log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint
@ -362,6 +364,14 @@ class EnterSubgraphLine(WrapperLine):
code.do_indent()
@dataclasses.dataclass
class CommentLine(WrapperLine):
line: LineContext
def codegen(self, code: IndentedBuffer) -> None:
code.writeline(self.line)
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
wrapper: PythonWrapperCodegen
@ -415,6 +425,102 @@ class ExitDeviceContextManagerLine(WrapperLine):
code.do_unindent()
@dataclasses.dataclass
class ExternKernelAllocLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.ExternKernelAlloc
def codegen(self, code: IndentedBuffer) -> None:
node = self.node
args = [*node.codegen_args(), *node.codegen_kwargs()]
self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
@dataclasses.dataclass
class ExternKernelOutLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.ExternKernelOut
def codegen(self, code: IndentedBuffer) -> None:
node = self.node
args = [*node.codegen_args(), *node.codegen_kwargs(skip_out=True)]
kernel_name = node.get_kernel_name()
if (
V.graph.cpp_wrapper
and node.cpp_kernel_name == "torch::inductor::_mm_plus_mm"
):
# For https://github.com/pytorch/pytorch/issues/128474
kernel_name = "aoti_torch__mm_plus_mm_out"
else:
kernel_name = node.get_kernel_name()
device = d.type if (d := node.get_device()) else V.graph.device_type
self.wrapper._generate_extern_kernel_out_helper(
kernel_name,
node.codegen_reference(),
node.output_view.codegen_reference() if node.output_view else None,
args,
device,
)
@dataclasses.dataclass
class FreeLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: Union[BufferLike, ir.TorchBindObject]
def codegen(self, code: IndentedBuffer) -> None:
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(self.wrapper.make_buffer_free(self.node))
@dataclasses.dataclass
class KernelCallLine(WrapperLine):
wrapper: PythonWrapperCodegen
kernel_name: str
call_args: tuple[Any, ...]
raw_keys: tuple[Any, ...]
raw_args: tuple[Any, ...]
arg_types: list[str]
triton: bool
triton_meta: dict[str, Any]
device: torch.device
graph_name: str
original_fxnode_name: str
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._generate_kernel_call_helper(
self.kernel_name,
self.call_args,
triton=self.triton,
arg_types=self.arg_types,
raw_keys=self.raw_keys,
raw_args=self.raw_args,
triton_meta=self.triton_meta,
device=self.device,
graph_name=self.graph_name,
original_fxnode_name=self.original_fxnode_name,
)
@dataclasses.dataclass
class KernelDefinitionLine(WrapperLine):
wrapper: PythonWrapperCodegen
kernel_name: str
kernel_body: str
metadata: Optional[str] = None
gpu: bool = True
cpp_definition: Optional[str] = None
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._define_kernel_helper(
self.kernel_name,
self.kernel_body,
metadata=self.metadata,
gpu=self.gpu,
cpp_definition=self.cpp_definition,
)
@dataclasses.dataclass
class MemoryPlanningLine(WrapperLine):
wrapper: PythonWrapperCodegen
@ -494,6 +600,23 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
code.writeline(self.wrapper.make_buffer_free(self.node))
@dataclasses.dataclass
class ReinterpretLine(MemoryPlanningLine):
node: BufferLike
reused_as: BufferLike
layout: ir.Layout
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
return self
def codegen(self, code: IndentedBuffer) -> None:
assert isinstance(self.layout, ir.NonOwningLayout)
assert isinstance(self.layout.view, ir.ReinterpretView)
self.wrapper.codegen_deferred_allocation(
self.reused_as.get_name(), self.layout.view
)
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
node: BufferLike
@ -598,7 +721,49 @@ class CommBufferFreeLine(CommBufferLine):
code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
@dataclasses.dataclass
class MultiOutputLine(WrapperLine):
"""
Given a MultiOutputLayout buffer, indexes actual buffer(s) from the result.
"""
wrapper: PythonWrapperCodegen
result_name: str
arg_name: str
indices: Sequence[Any]
def codegen(self, code: IndentedBuffer) -> None:
def codegen_list_tuple_access(basename, indices): # type: ignore[no-untyped-def]
if len(indices) > 0:
itype, i = indices[0]
if issubclass(itype, list):
return codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
elif issubclass(itype, tuple):
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = self.wrapper.codegen_tuple_access(
basename, self.result_name, str(i)
)
return codegen_list_tuple_access(tuple_access, indices[1:])
elif issubclass(itype, dict):
return codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
else:
raise AssertionError("non supported index type: ", itype)
else:
return basename
value = codegen_list_tuple_access(self.arg_name, self.indices)
code.writeline(
f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
)
@dataclasses.dataclass
class OutputLine(WrapperLine):
buffers: tuple[BufferLike, ...]
BufferName = str
Line = Union[MemoryPlanningLine, LineContext]
class PythonWrapperCodegen(CodeGen):
@ -609,6 +774,9 @@ class PythonWrapperCodegen(CodeGen):
def __init__(self):
super().__init__()
self._names_iter: Iterator[int] = count()
self.args_to_buffers: dict[
str, Union[None, ir.TensorBox, ir.Buffer, ir.TorchBindObject]
] = {}
self.imports = IndentedBuffer()
self.header = IndentedBuffer()
self.prefix = IndentedBuffer()
@ -627,7 +795,7 @@ class PythonWrapperCodegen(CodeGen):
# pre-existing kernel for it
self.src_to_kernel: dict[str, str] = {}
self.kernel_numel_expr: OrderedSet[tuple[str, GraphLowering]] = OrderedSet()
self.lines: list[Union[MemoryPlanningLine, LineContext]] = []
self.lines: list[Line] = []
self.declare = ""
self.declare_maybe_reference = ""
self.ending = ""
@ -953,10 +1121,10 @@ class PythonWrapperCodegen(CodeGen):
if config.nan_asserts:
self.codegen_input_nan_asserts()
# this function (and below) takes a graph as input so
# this function (and below) takes the graph name as input so
# that stream caching happens per graph instance. this
# is important for nested subgraph codegening.
def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
self.write_get_raw_stream_header_once()
name = f"stream{device_idx}"
if config.triton.autotune_at_compile_time:
@ -1039,10 +1207,16 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return
def generate_fallback_kernel(self, fallback_kernel, args):
self.generate_extern_kernel_alloc(fallback_kernel, args)
def generate_fallback_kernel(self, node: ir.FallbackKernel):
self.writeline(ExternKernelAllocLine(self, node))
def generate_extern_kernel_alloc(self, extern_kernel, args):
def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
self.writeline(ExternKernelAllocLine(self, node))
if isinstance(node.layout, ir.Layout):
node.codegen_size_asserts(self)
def _generate_extern_kernel_alloc_helper(self, extern_kernel, args):
# If it's a NoneLayout then the extern_kernel should essentially be
# treated as if it doesn't return anything
no_return = isinstance(extern_kernel.layout, ir.NoneLayout)
@ -1072,6 +1246,13 @@ class PythonWrapperCodegen(CodeGen):
)
def generate_extern_kernel_out(
self,
node: ir.ExternKernelOut,
) -> None:
node.codegen_comment(self)
self.writeline(ExternKernelOutLine(self, node))
def _generate_extern_kernel_out_helper(
self,
kernel: str,
out: str,
@ -1159,20 +1340,25 @@ class PythonWrapperCodegen(CodeGen):
else:
return 1
@contextlib.contextmanager
def set_writeline(self, new: Callable[..., None]) -> Iterator[Callable[..., None]]:
old = self.writeline
try:
self.writeline = new # type: ignore[method-assign]
yield new
finally:
self.writeline = old # type: ignore[method-assign]
def _write_multi_kernel_defs(self) -> None:
kernel_defs = self.multi_kernel_state.kernel_defs
if config.triton.autotune_at_compile_time:
self.kernel_autotune_defs.splice(kernel_defs)
else:
self.header.splice(kernel_defs)
def _generate(self, is_inference):
if config.profile_bandwidth:
self.write_triton_header_once()
result = IndentedBuffer()
result.splice(self.imports)
result.writeline("")
result.splice(self.header)
# We do not want the cpp header for intermediate const graph. Headers would be
# rendered by the main module instead.
if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
result = IndentedBuffer()
# Add subgraph definitions to the result
result.splice(self.subgraph_definitions)
with contextlib.ExitStack() as stack:
stack.enter_context(self.wrapper_call.indent())
@ -1190,11 +1376,16 @@ class PythonWrapperCodegen(CodeGen):
if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
self.generate_reset_kernel_saved_flags()
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen(self.wrapper_call)
else:
self.wrapper_call.writeline(line)
# At this point, we shouldn't generate any new memory planning lines.
# Override writeline to point at the wrapper call, in case it gets called.
with self.set_writeline(self.wrapper_call.writeline):
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen(self.wrapper_call)
else:
self.wrapper_call.writeline(line)
self._write_multi_kernel_defs()
output_refs = self.get_output_refs()
self.mark_output_type()
@ -1217,6 +1408,18 @@ class PythonWrapperCodegen(CodeGen):
)
self.generate_return(output_refs)
# Assemble the final code from sections.
result = IndentedBuffer()
result.splice(self.imports)
result.writeline("")
result.splice(self.header)
# We do not want the cpp header for intermediate const graph. Headers would be
# rendered by the main module instead.
if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
result = IndentedBuffer()
# Add subgraph definitions to the result
result.splice(self.subgraph_definitions)
self.finalize_prefix()
result.splice(self.prefix)
@ -1454,8 +1657,10 @@ class PythonWrapperCodegen(CodeGen):
def codegen_device_copy(self, src, dst, non_blocking: bool):
self.writeline(f"{dst}.copy_({src}, {non_blocking})")
def codegen_multi_output(self, name, value):
self.writeline(f"{self.declare}{name} = {value}{self.ending}")
def codegen_multi_output(self, node: ir.MultiOutput):
result_name = node.get_name()
arg_name = node.inputs[0].get_name()
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
@ -1596,16 +1801,46 @@ class PythonWrapperCodegen(CodeGen):
metadata: Optional[str] = None,
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
self.writeline(
KernelDefinitionLine(
self,
kernel_name,
kernel_body,
metadata=metadata,
gpu=gpu,
cpp_definition=cpp_definition,
)
)
def _format_kernel_definition(
self, kernel_name: str, kernel_body: str, metadata: Optional[str] = None
):
metadata_comment = f"{metadata}\n" if metadata else ""
body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}"
return body
def _define_kernel_helper(
self,
kernel_name: str,
kernel_body: str,
metadata: Optional[str] = None,
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
# Skip inserting comments for the autotune block as they may contain cpp style comments
body = f"\n\n{kernel_name} = {kernel_body}"
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=None
)
self.kernel_autotune_defs.splice(body)
if V.graph.cpp_wrapper:
# For cpp wrapper, no need to continue codegen for the main body
return
metadata_comment = f"{metadata}\n" if metadata else ""
body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}"
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
self.header.splice(body)
def define_subgraph_launcher_fn(self, fn_code: str):
@ -1997,10 +2232,10 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(raw_arg, ir.TMADescriptor):
# first we generate the underlying buffer
buf_name = raw_arg.tensor.get_name()
buf = V.graph.get_buffer(buf_name)
elif V.graph.try_get_buffer(arg) is not None:
buf = self.args_to_buffers[arg]
elif self.args_to_buffers.get(arg):
buf_name = arg
buf = V.graph.get_buffer(arg)
buf = self.args_to_buffers[arg]
else:
assert raw_arg is not None, (
"V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
@ -2009,6 +2244,7 @@ class PythonWrapperCodegen(CodeGen):
buf = raw_arg
self.kernel_autotune_tmp_arg_idx += 1
assert buf is not None, f"Failed to find a buffer for arg {arg}"
size = tuple(
V.graph.sizevars.atomically_apply_size_hint(
e,
@ -2103,6 +2339,48 @@ class PythonWrapperCodegen(CodeGen):
triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True,
and C++ when gpu=False.
"""
# Store buffers corresponding to each call arg.
# This is used to generate example args for autotuning later on.
self.args_to_buffers.update(
{
arg: V.graph.try_get_buffer(arg)
for arg in call_args
if isinstance(arg, str)
}
)
device = device or V.graph.get_current_device_or_throw()
self.writeline(
KernelCallLine(
self,
kernel_name=kernel_name,
call_args=call_args,
raw_keys=raw_keys,
raw_args=raw_args,
arg_types=arg_types,
triton=triton,
triton_meta=triton_meta,
device=device,
graph_name=V.graph.name,
original_fxnode_name=original_fxnode_name,
)
)
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,
*,
device=None,
triton=True,
arg_types=None,
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
device = device or V.graph.get_current_device_or_throw()
if not (triton or device.type != "cpu"):
self.writeline(self.wrap_kernel_call(kernel_name, call_args))
@ -2111,7 +2389,7 @@ class PythonWrapperCodegen(CodeGen):
call_args_str = self.prepare_triton_kernel_call(call_args)
call_args_str = ", ".join(call_args_str)
stream_name = PythonWrapperCodegen.write_get_raw_stream(
self, device.index, V.graph
self, device.index, graph_name
)
if not triton:
stream_ptr = f"c_void_p({stream_name})"
@ -2209,6 +2487,7 @@ class PythonWrapperCodegen(CodeGen):
debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
with debug_printer_manager:
self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})")
self.write_triton_header_once()
def writeline(self, line):
self.lines.append(line)
@ -2297,6 +2576,9 @@ class PythonWrapperCodegen(CodeGen):
out = out + f".as_strided({codegen_shape_tuple}, {codegen_stride_tuple})"
return out
def make_comment(self, line):
self.writeline(CommentLine(line))
def make_tensor_alias(self, new_name, old_name, comment=""):
return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"
@ -2361,10 +2643,12 @@ class PythonWrapperCodegen(CodeGen):
assert isinstance(layout.view, ir.ReinterpretView), (
f"unexpected {type(layout.view)}: {layout.view}"
)
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
self.codegen_allocation(layout.view.data.data)
self.codegen_deferred_allocation(name, layout.view)
box = layout.view.data
assert isinstance(box, ir.StorageBox), type(box)
input_buffer = box.data
assert isinstance(input_buffer, ir.Buffer), type(box)
self.codegen_allocation(input_buffer)
self.writeline(ReinterpretLine(self, input_buffer, buffer, layout))
return
if isinstance(layout, ir.CommBufferLayout):
@ -2378,7 +2662,7 @@ class PythonWrapperCodegen(CodeGen):
# can be freed but not reused
if isinstance(buffer, (ir.InputBuffer, ir.TorchBindObject)):
self.writeline(self.make_buffer_free(buffer))
self.writeline(FreeLine(self, buffer))
return
if isinstance(buffer.get_output_spec(), ir.CommBufferLayout):

View File

@ -5080,7 +5080,7 @@ class ExternKernel(InputsKernel):
def codegen_comment(self, wrapper) -> None: # type: ignore[no-untyped-def]
origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
if origin_str:
wrapper.writeline(origin_str)
wrapper.make_comment(origin_str)
def codegen(self, wrapper): # type: ignore[no-untyped-def]
raise NotImplementedError
@ -5786,25 +5786,7 @@ class ExternKernel(InputsKernel):
@ir_dataclass(frozen=False)
class ExternKernelOut(ExternKernel):
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs(skip_out=True)]
kernel_name = self.get_kernel_name()
if (
V.graph.cpp_wrapper
and self.cpp_kernel_name == "torch::inductor::_mm_plus_mm"
):
# For https://github.com/pytorch/pytorch/issues/128474
kernel_name = "aoti_torch__mm_plus_mm_out"
else:
kernel_name = self.get_kernel_name()
device = d.type if (d := self.get_device()) else V.graph.device_type
wrapper.generate_extern_kernel_out(
kernel_name,
self.codegen_reference(),
self.output_view.codegen_reference() if self.output_view else None,
args,
device,
)
wrapper.generate_extern_kernel_out(self)
def __init__( # type: ignore[no-untyped-def]
self,
@ -5859,11 +5841,7 @@ class RandomSeeds(ExternKernelOut):
class ExternKernelAlloc(ExternKernel):
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
wrapper.generate_extern_kernel_alloc(self)
def __init__( # type: ignore[no-untyped-def]
self,
@ -6967,7 +6945,7 @@ class FallbackKernel(ExternKernelAlloc):
# dispatch.
do_runtime_dispatch()
else:
V.graph.wrapper_code.generate_fallback_kernel(self, args)
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
@ -7118,32 +7096,8 @@ class MultiOutputLayout(OutputSpec):
class MultiOutput(ExternKernel):
# Given an input MultiOutputLayout buffer, indexes out an actual buffer
# from that result. This doesn't actually produce multiple outputs,
# that's MultiOutputLayout!
def codegen_list_tuple_access(self, basename, indices): # type: ignore[no-untyped-def]
if len(indices) > 0:
itype, i = indices[0]
if issubclass(itype, list):
return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
elif issubclass(itype, tuple):
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = V.graph.wrapper_code.codegen_tuple_access(
basename, self.get_name(), str(i)
)
return self.codegen_list_tuple_access(tuple_access, indices[1:])
elif issubclass(itype, dict):
return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
else:
raise AssertionError("non supported index type: ", itype)
else:
return basename
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
wrapper.codegen_multi_output(
self.get_name(),
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
)
wrapper.codegen_multi_output(self)
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)