[Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below. The IR currently supports the following ops:
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)

There are two main motivations for this refactor:
1. Unlike free-form C++ and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX, or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overridden by the various Python and C++ backends to generate the final wrapper code.

The main difficulty in implementing this is ensuring that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines`, which mainly contains Wrapper IR but can also contain raw Python or C++ lines in some situations. Then we convert the Wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it's difficult to ensure that code is written to the correct buffer. The easiest solution was a context manager that overrides the `writeline` method to write to `self.wrapper_call` once memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass and to `self.wrapper_call` in the second. This obviated the need to pass `code` or `writeline` variables all the way through the call stack, which would have touched most of the existing macros.
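To make the two-phase flow concrete, here is a minimal, self-contained sketch of the pattern (condensed and simplified for illustration; the `Backend` class and its list buffers are stand-ins rather than the actual Inductor classes, although `KernelCallLine`, `writeline`, and `_generate_kernel_call_helper` are the names used in this PR):

```python
import dataclasses
from typing import Any


class Backend:
    """Stand-in for a target-specific wrapper codegen backend (Python, C++, FX, ...)."""

    def __init__(self) -> None:
        self.lines: list[Any] = []          # pass 1: Wrapper IR lines
        self.wrapper_call: list[str] = []   # pass 2: final wrapper code

    def writeline(self, line: Any) -> None:
        self.lines.append(line)

    # Each backend overrides this to emit its own syntax (a Python-like call shown here).
    def _generate_kernel_call_helper(self, kernel_name: str, call_args: tuple) -> None:
        self.wrapper_call.append(f"{kernel_name}.run({', '.join(call_args)})")


@dataclasses.dataclass
class KernelCallLine:
    wrapper: Backend
    kernel_name: str
    call_args: tuple

    def codegen(self) -> None:
        # Pass 2: lower the structured record into target-specific code.
        self.wrapper._generate_kernel_call_helper(self.kernel_name, self.call_args)


backend = Backend()
# Pass 1: record *what* the wrapper should do, not how it is spelled.
backend.writeline(KernelCallLine(backend, "triton_poi_fused_add_0", ("buf0", "arg0_1")))
# Pass 2: convert every IR line into wrapper code.
for line in backend.lines:
    line.codegen()
print(backend.wrapper_call)  # ['triton_poi_fused_add_0.run(buf0, arg0_1)']
```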
# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage. The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` only contains Wrapper IR lines and no free-form code. While this would not be true of all programs today, the tests suggest that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

The two goals described above are only partially realized by this PR. Several important steps still undergo direct Python/C++ codegen in core files:
- User-defined Triton kernels.
- Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
- Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
- Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once. Some Python and C++ codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases inside `PythonWrapperCodegen`. First, we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split Wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions take a lot of call args. It's also possible to clean this up by passing Wrapper IR ops, e.g. `KernelCallLine`, into helper functions like `_generate_kernel_call_helper`, since they already store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible; a sketch of the idea follows below.
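As a rough, purely illustrative sketch of that follow-up idea (hypothetical code, not part of this PR; `_generate_kernel_call_from_line` is an invented name), the helper could accept the Wrapper IR op itself instead of unpacking its fields into a long argument list:

```python
import dataclasses


@dataclasses.dataclass
class KernelCallLine:
    kernel_name: str
    call_args: tuple
    triton: bool = True


# Current style: codegen() unpacks the IR op's fields into many keyword arguments.
def _generate_kernel_call_helper(kernel_name: str, call_args: tuple, *, triton: bool = True) -> str:
    args = ", ".join(call_args)
    return f"{kernel_name}.run({args})" if triton else f"{kernel_name}({args})"


# Possible follow-up (hypothetical): pass the op itself, since it already stores the arguments.
def _generate_kernel_call_from_line(line: KernelCallLine) -> str:
    return _generate_kernel_call_helper(line.kernel_name, line.call_args, triton=line.triton)


line = KernelCallLine("triton_poi_fused_add_0", ("buf0", "arg0_1"))
assert _generate_kernel_call_from_line(line) == "triton_poi_fused_add_0.run(buf0, arg0_1)"
```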
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison

This commit is contained in:
parent 8f440a8e70
commit c0a0761871
@@ -72,7 +72,7 @@ class TestMemoryPlanning(TestCase):
result, code = run_and_get_cpp_code(compiled, *args)
FileCheck().check(
"aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_4, int_array_5, &tmp_tensor_handle_0)"
"aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_2, int_array_3, &tmp_tensor_handle_0)"
).check_next("auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_0);").check(
"auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_1);"
).run(

@@ -97,17 +97,17 @@ class TestMemoryPlanning(TestCase):
)
FileCheck().check(
"int64_t int_array_2[] = {24L + align(12L*s77), };"
).check_next("int64_t int_array_3[] = {1L, };").check_next(
"int64_t int_array_0[] = {24L + align(12L*s77), };"
).check_next("int64_t int_array_1[] = {1L, };").check_next(
"AtenTensorHandle pool1_handle;"
).check_next(
"aoti_torch_empty_strided(1, int_array_2, int_array_3,"
"aoti_torch_empty_strided(1, int_array_0, int_array_1,"
).check_next(
"RAIIAtenTensorHandle pool1(pool1_handle);"
).check_next(
"int64_t int_array_4[] = {s77, 3L};"
"int64_t int_array_2[] = {s77, 3L};"
).check_next(
"int64_t int_array_5[] = {3L, 1L};"
"int64_t int_array_3[] = {3L, 1L};"
).check_next(
"AtenTensorHandle tmp_tensor_handle_0;"
).check_next(
@@ -102,7 +102,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}"
)
def generate_kernel_call(
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,

@@ -113,6 +113,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
"""

@@ -908,7 +909,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
self.prefix.splice(aot_mode_decls)
self.prefix.splice(prior)
def define_kernel(
def _define_kernel_helper(
self,
kernel_name: str,
kernel_body: str,

@@ -1160,7 +1161,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
if not is_inplace:
self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
def generate_extern_kernel_alloc(self, extern_kernel, args):
def _generate_extern_kernel_alloc_helper(self, extern_kernel, args):
if getattr(extern_kernel, "outputs", None):
# ir.ExternKernelAlloc may have outputs if it returns a tuple
self.generate_c_shim_fallback_kernel(extern_kernel, args)

@@ -1209,10 +1210,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
for raii_handle in output_raii_handles:
self.writeline(raii_handle)
def generate_fallback_kernel(self, fallback_kernel, args):
self.generate_c_shim_fallback_kernel(fallback_kernel, args)
def generate_extern_kernel_out(
def _generate_extern_kernel_out_helper(
self,
kernel: str,
out: str,

@@ -1652,7 +1650,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));"
)
def codegen_multi_output(self, name, value):
def codegen_multi_output(self, node: ir.MultiOutput):
# in the abi_compatible mode, outputs are retrieved by passing
# output pointers, so we skip its codegen here.
pass
@@ -87,7 +87,22 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
numel = buf.get_numel()
self.prefix.writeline(f"assert_numel({name}, {numel});")
def generate_kernel_call(
def generate_extern_kernel_alloc(self, *args, **kwargs):
# Disable stack allocation for extern kernels.
self.allow_stack_allocation = False
super().generate_extern_kernel_alloc(*args, **kwargs)
def generate_extern_kernel_out(self, *args, **kwargs):
# Disable stack allocation for extern kernels.
self.allow_stack_allocation = False
super().generate_extern_kernel_out(*args, **kwargs)
def generate_fallback_kernel(self, *args, **kwargs):
# Disable stack allocation for extern kernels.
self.allow_stack_allocation = False
super().generate_fallback_kernel(*args, **kwargs)
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,

@@ -98,6 +113,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
"""
@@ -240,7 +240,7 @@ class CppWrapperGpu(CppWrapperCpu):
def write_tma_descriptor_helpers_once(self):
self.header.splice(self.device_codegen.tma_descriptor_helpers())
def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
name = f"stream{device_idx}"
self.writeline(
maybe_hipify_code_wrapper(

@@ -294,7 +294,7 @@ class CppWrapperGpu(CppWrapperCpu):
super().codegen_inputs()
def define_kernel(
def _define_kernel_helper(
self,
kernel_name: str,
kernel_body: str,

@@ -306,11 +306,11 @@ class CppWrapperGpu(CppWrapperCpu):
self._kernel_name_to_body[kernel_name] = kernel_body
if config.triton.autotune_at_compile_time:
# Call PythonWrapperCodegen to create the autotune code block
PythonWrapperCodegen.define_kernel(
PythonWrapperCodegen._define_kernel_helper(
self, kernel_name, kernel_body, metadata, gpu, cpp_definition
)
else:
return CppWrapperCpu.define_kernel(
return CppWrapperCpu._define_kernel_helper(
self, kernel_name, kernel_body, metadata, gpu, cpp_definition
)

@@ -445,7 +445,7 @@ class CppWrapperGpu(CppWrapperCpu):
return ", ".join(new_args)
def generate_kernel_call(
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,

@@ -456,6 +456,7 @@ class CppWrapperGpu(CppWrapperCpu):
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
"""

@@ -466,7 +467,7 @@ class CppWrapperGpu(CppWrapperCpu):
device = device or V.graph.get_current_device_or_throw()
if device.type == "cpu":
# Even in CppWrapperGpu, we may see cpp kernels
return CppWrapperCpu.generate_kernel_call(
return CppWrapperCpu._generate_kernel_call_helper(
self,
kernel_name,
call_args,

@@ -484,7 +485,7 @@ class CppWrapperGpu(CppWrapperCpu):
and kernel_name not in self.kernel_autotune_names
):
# Call PythonWrapperCodegen to create the autotune code block
PythonWrapperCodegen.generate_kernel_call(
PythonWrapperCodegen._generate_kernel_call_helper(
self,
kernel_name,
call_args,

@@ -500,7 +501,7 @@ class CppWrapperGpu(CppWrapperCpu):
stream = (
"stream"
if V.graph.aot_mode
else self.write_get_raw_stream(device.index, V.graph)
else self.write_get_raw_stream(device.index, graph_name)
)
if triton:
@@ -5,7 +5,7 @@ import functools
import logging
import os
from enum import Enum
from typing import Optional
from typing import Callable, Optional
import torch
from torch import dtype as torch_dtype

@@ -57,6 +57,7 @@ class DebugPrinterManager:
self,
debug_printer_level,
use_array_ref: bool,
writeline: Optional[Callable[..., None]] = None,
args_to_print_or_save: Optional[list[str]] = None,
kernel_name: str = "",
kernel=None,
@@ -1637,7 +1637,9 @@ class HalideKernel(SIMDKernel):
call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None]
current_device = V.graph.get_current_device_or_throw()
if current_device.type == "cuda":
stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph)
stream_name = wrapper.write_get_raw_stream(
current_device.index, V.graph.name
)
call_args.append(stream_name)
wrapper.generate_kernel_call(
name,
@@ -87,6 +87,7 @@ class MultiKernelState:
def __init__(self):
self.subkernel_to_kernel_name = {}
self.kernel_defs = IndentedBuffer()
def define_kernel(self, kernels):
"""

@@ -116,7 +117,7 @@ class MultiKernelState:
# the second pass of cpp-wrapper.
return multi_kernel_name
buf = IndentedBuffer()
buf = self.kernel_defs
buf.writeline("")
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["

@@ -126,12 +127,10 @@ class MultiKernelState:
buf.writeline(f"{name},")
buf.writeline("])")
wrapper = V.graph.wrapper_code
if config.triton.autotune_at_compile_time:
wrapper.kernel_autotune_defs.splice(buf)
wrapper.src_to_kernel["\n".join(kernel_names)] = multi_kernel_name
else:
wrapper.header.splice(buf)
V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
multi_kernel_name
)
return multi_kernel_name
@@ -4089,7 +4089,7 @@ class TritonScheduling(SIMDScheduling):
wrapper = V.graph.wrapper_code
origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
if origins:
wrapper.writeline(origins)
wrapper.make_comment(origins)
if config.debug_fusion:
from torch._inductor.scheduler import (

@@ -4107,7 +4107,7 @@ class TritonScheduling(SIMDScheduling):
for n in node_schedule
if isinstance(n, BaseSchedulerNode)
]
wrapper.writeline(
wrapper.make_comment(
f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
)
@@ -75,6 +75,8 @@ if TYPE_CHECKING:
from ..graph import GraphLowering
log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint

@@ -362,6 +364,14 @@ class EnterSubgraphLine(WrapperLine):
code.do_indent()
@dataclasses.dataclass
class CommentLine(WrapperLine):
line: LineContext
def codegen(self, code: IndentedBuffer) -> None:
code.writeline(self.line)
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
wrapper: PythonWrapperCodegen

@@ -415,6 +425,102 @@ class ExitDeviceContextManagerLine(WrapperLine):
code.do_unindent()
@dataclasses.dataclass
class ExternKernelAllocLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.ExternKernelAlloc
def codegen(self, code: IndentedBuffer) -> None:
node = self.node
args = [*node.codegen_args(), *node.codegen_kwargs()]
self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
@dataclasses.dataclass
class ExternKernelOutLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.ExternKernelOut
def codegen(self, code: IndentedBuffer) -> None:
node = self.node
args = [*node.codegen_args(), *node.codegen_kwargs(skip_out=True)]
kernel_name = node.get_kernel_name()
if (
V.graph.cpp_wrapper
and node.cpp_kernel_name == "torch::inductor::_mm_plus_mm"
):
# For https://github.com/pytorch/pytorch/issues/128474
kernel_name = "aoti_torch__mm_plus_mm_out"
else:
kernel_name = node.get_kernel_name()
device = d.type if (d := node.get_device()) else V.graph.device_type
self.wrapper._generate_extern_kernel_out_helper(
kernel_name,
node.codegen_reference(),
node.output_view.codegen_reference() if node.output_view else None,
args,
device,
)
@dataclasses.dataclass
class FreeLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: Union[BufferLike, ir.TorchBindObject]
def codegen(self, code: IndentedBuffer) -> None:
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(self.wrapper.make_buffer_free(self.node))
@dataclasses.dataclass
class KernelCallLine(WrapperLine):
wrapper: PythonWrapperCodegen
kernel_name: str
call_args: tuple[Any, ...]
raw_keys: tuple[Any, ...]
raw_args: tuple[Any, ...]
arg_types: list[str]
triton: bool
triton_meta: dict[str, Any]
device: torch.device
graph_name: str
original_fxnode_name: str
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._generate_kernel_call_helper(
self.kernel_name,
self.call_args,
triton=self.triton,
arg_types=self.arg_types,
raw_keys=self.raw_keys,
raw_args=self.raw_args,
triton_meta=self.triton_meta,
device=self.device,
graph_name=self.graph_name,
original_fxnode_name=self.original_fxnode_name,
)
@dataclasses.dataclass
class KernelDefinitionLine(WrapperLine):
wrapper: PythonWrapperCodegen
kernel_name: str
kernel_body: str
metadata: Optional[str] = None
gpu: bool = True
cpp_definition: Optional[str] = None
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._define_kernel_helper(
self.kernel_name,
self.kernel_body,
metadata=self.metadata,
gpu=self.gpu,
cpp_definition=self.cpp_definition,
)
@dataclasses.dataclass
class MemoryPlanningLine(WrapperLine):
wrapper: PythonWrapperCodegen

@@ -494,6 +600,23 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
code.writeline(self.wrapper.make_buffer_free(self.node))
@dataclasses.dataclass
class ReinterpretLine(MemoryPlanningLine):
node: BufferLike
reused_as: BufferLike
layout: ir.Layout
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
return self
def codegen(self, code: IndentedBuffer) -> None:
assert isinstance(self.layout, ir.NonOwningLayout)
assert isinstance(self.layout.view, ir.ReinterpretView)
self.wrapper.codegen_deferred_allocation(
self.reused_as.get_name(), self.layout.view
)
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
node: BufferLike

@@ -598,7 +721,49 @@ class CommBufferFreeLine(CommBufferLine):
code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
@dataclasses.dataclass
class MultiOutputLine(WrapperLine):
"""
Given a MultiOutputLayout buffer, indexes actual buffer(s) from the result.
"""
wrapper: PythonWrapperCodegen
result_name: str
arg_name: str
indices: Sequence[Any]
def codegen(self, code: IndentedBuffer) -> None:
def codegen_list_tuple_access(basename, indices): # type: ignore[no-untyped-def]
if len(indices) > 0:
itype, i = indices[0]
if issubclass(itype, list):
return codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
elif issubclass(itype, tuple):
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = self.wrapper.codegen_tuple_access(
basename, self.result_name, str(i)
)
return codegen_list_tuple_access(tuple_access, indices[1:])
elif issubclass(itype, dict):
return codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
else:
raise AssertionError("non supported index type: ", itype)
else:
return basename
value = codegen_list_tuple_access(self.arg_name, self.indices)
code.writeline(
f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
)
@dataclasses.dataclass
class OutputLine(WrapperLine):
buffers: tuple[BufferLike, ...]
BufferName = str
Line = Union[MemoryPlanningLine, LineContext]
class PythonWrapperCodegen(CodeGen):

@@ -609,6 +774,9 @@ class PythonWrapperCodegen(CodeGen):
def __init__(self):
super().__init__()
self._names_iter: Iterator[int] = count()
self.args_to_buffers: dict[
str, Union[None, ir.TensorBox, ir.Buffer, ir.TorchBindObject]
] = {}
self.imports = IndentedBuffer()
self.header = IndentedBuffer()
self.prefix = IndentedBuffer()

@@ -627,7 +795,7 @@ class PythonWrapperCodegen(CodeGen):
# pre-existing kernel for it
self.src_to_kernel: dict[str, str] = {}
self.kernel_numel_expr: OrderedSet[tuple[str, GraphLowering]] = OrderedSet()
self.lines: list[Union[MemoryPlanningLine, LineContext]] = []
self.lines: list[Line] = []
self.declare = ""
self.declare_maybe_reference = ""
self.ending = ""

@@ -953,10 +1121,10 @@ class PythonWrapperCodegen(CodeGen):
if config.nan_asserts:
self.codegen_input_nan_asserts()
# this function (and below) takes a graph as input so
# this function (and below) takes the graph name as input so
# that stream caching happens per graph instance. this
# is important for nested subgraph codegening.
def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
self.write_get_raw_stream_header_once()
name = f"stream{device_idx}"
if config.triton.autotune_at_compile_time:

@@ -1039,10 +1207,16 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return
def generate_fallback_kernel(self, fallback_kernel, args):
self.generate_extern_kernel_alloc(fallback_kernel, args)
def generate_fallback_kernel(self, node: ir.FallbackKernel):
self.writeline(ExternKernelAllocLine(self, node))
def generate_extern_kernel_alloc(self, extern_kernel, args):
def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
self.writeline(ExternKernelAllocLine(self, node))
if isinstance(node.layout, ir.Layout):
node.codegen_size_asserts(self)
def _generate_extern_kernel_alloc_helper(self, extern_kernel, args):
# If it's a NoneLayout then the extern_kernel should essentially be
# treated as if it doesn't return anything
no_return = isinstance(extern_kernel.layout, ir.NoneLayout)

@@ -1072,6 +1246,13 @@ class PythonWrapperCodegen(CodeGen):
)
def generate_extern_kernel_out(
self,
node: ir.ExternKernelOut,
) -> None:
node.codegen_comment(self)
self.writeline(ExternKernelOutLine(self, node))
def _generate_extern_kernel_out_helper(
self,
kernel: str,
out: str,

@@ -1159,20 +1340,25 @@ class PythonWrapperCodegen(CodeGen):
else:
return 1
@contextlib.contextmanager
def set_writeline(self, new: Callable[..., None]) -> Iterator[Callable[..., None]]:
old = self.writeline
try:
self.writeline = new # type: ignore[method-assign]
yield new
finally:
self.writeline = old # type: ignore[method-assign]
def _write_multi_kernel_defs(self) -> None:
kernel_defs = self.multi_kernel_state.kernel_defs
if config.triton.autotune_at_compile_time:
self.kernel_autotune_defs.splice(kernel_defs)
else:
self.header.splice(kernel_defs)
def _generate(self, is_inference):
if config.profile_bandwidth:
self.write_triton_header_once()
result = IndentedBuffer()
result.splice(self.imports)
result.writeline("")
result.splice(self.header)
# We do not want the cpp header for intermediate const graph. Headers would be
# rendered by the main module instead.
if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
result = IndentedBuffer()
# Add subgraph definitions to the result
result.splice(self.subgraph_definitions)
with contextlib.ExitStack() as stack:
stack.enter_context(self.wrapper_call.indent())

@@ -1190,11 +1376,16 @@ class PythonWrapperCodegen(CodeGen):
if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
self.generate_reset_kernel_saved_flags()
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen(self.wrapper_call)
else:
self.wrapper_call.writeline(line)
# At this point, we shouldn't generate any new memory planning lines.
# Override writeline to point at the wrapper call, in case it gets called.
with self.set_writeline(self.wrapper_call.writeline):
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen(self.wrapper_call)
else:
self.wrapper_call.writeline(line)
self._write_multi_kernel_defs()
output_refs = self.get_output_refs()
self.mark_output_type()

@@ -1217,6 +1408,18 @@ class PythonWrapperCodegen(CodeGen):
)
self.generate_return(output_refs)
# Assemble the final code from sections.
result = IndentedBuffer()
result.splice(self.imports)
result.writeline("")
result.splice(self.header)
# We do not want the cpp header for intermediate const graph. Headers would be
# rendered by the main module instead.
if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
result = IndentedBuffer()
# Add subgraph definitions to the result
result.splice(self.subgraph_definitions)
self.finalize_prefix()
result.splice(self.prefix)

@@ -1454,8 +1657,10 @@ class PythonWrapperCodegen(CodeGen):
def codegen_device_copy(self, src, dst, non_blocking: bool):
self.writeline(f"{dst}.copy_({src}, {non_blocking})")
def codegen_multi_output(self, name, value):
self.writeline(f"{self.declare}{name} = {value}{self.ending}")
def codegen_multi_output(self, node: ir.MultiOutput):
result_name = node.get_name()
arg_name = node.inputs[0].get_name()
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)

@@ -1596,16 +1801,46 @@ class PythonWrapperCodegen(CodeGen):
metadata: Optional[str] = None,
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
self.writeline(
KernelDefinitionLine(
self,
kernel_name,
kernel_body,
metadata=metadata,
gpu=gpu,
cpp_definition=cpp_definition,
)
)
def _format_kernel_definition(
self, kernel_name: str, kernel_body: str, metadata: Optional[str] = None
):
metadata_comment = f"{metadata}\n" if metadata else ""
body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}"
return body
def _define_kernel_helper(
self,
kernel_name: str,
kernel_body: str,
metadata: Optional[str] = None,
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
# Skip inserting comments for the autotune block as they may contain cpp style comments
body = f"\n\n{kernel_name} = {kernel_body}"
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=None
)
self.kernel_autotune_defs.splice(body)
if V.graph.cpp_wrapper:
# For cpp wrapper, no need to continue codegen for the main body
return
metadata_comment = f"{metadata}\n" if metadata else ""
body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}"
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
self.header.splice(body)
def define_subgraph_launcher_fn(self, fn_code: str):

@@ -1997,10 +2232,10 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(raw_arg, ir.TMADescriptor):
# first we generate the underlying buffer
buf_name = raw_arg.tensor.get_name()
buf = V.graph.get_buffer(buf_name)
elif V.graph.try_get_buffer(arg) is not None:
buf = self.args_to_buffers[arg]
elif self.args_to_buffers.get(arg):
buf_name = arg
buf = V.graph.get_buffer(arg)
buf = self.args_to_buffers[arg]
else:
assert raw_arg is not None, (
"V.graph.get_buffer(arg) and raw_arg can't be None at the same time"

@@ -2009,6 +2244,7 @@ class PythonWrapperCodegen(CodeGen):
buf = raw_arg
self.kernel_autotune_tmp_arg_idx += 1
assert buf is not None, f"Failed to find a buffer for arg {arg}"
size = tuple(
V.graph.sizevars.atomically_apply_size_hint(
e,

@@ -2103,6 +2339,48 @@ class PythonWrapperCodegen(CodeGen):
triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True,
and C++ when gpu=False.
"""
# Store buffers corresponding to each call arg.
# This is used to generate example args for autotuning later on.
self.args_to_buffers.update(
{
arg: V.graph.try_get_buffer(arg)
for arg in call_args
if isinstance(arg, str)
}
)
device = device or V.graph.get_current_device_or_throw()
self.writeline(
KernelCallLine(
self,
kernel_name=kernel_name,
call_args=call_args,
raw_keys=raw_keys,
raw_args=raw_args,
arg_types=arg_types,
triton=triton,
triton_meta=triton_meta,
device=device,
graph_name=V.graph.name,
original_fxnode_name=original_fxnode_name,
)
)
def _generate_kernel_call_helper(
self,
kernel_name: str,
call_args,
*,
device=None,
triton=True,
arg_types=None,
raw_keys=None,
raw_args=None,
triton_meta=None,
graph_name="",
original_fxnode_name=None,
):
device = device or V.graph.get_current_device_or_throw()
if not (triton or device.type != "cpu"):
self.writeline(self.wrap_kernel_call(kernel_name, call_args))

@@ -2111,7 +2389,7 @@ class PythonWrapperCodegen(CodeGen):
call_args_str = self.prepare_triton_kernel_call(call_args)
call_args_str = ", ".join(call_args_str)
stream_name = PythonWrapperCodegen.write_get_raw_stream(
self, device.index, V.graph
self, device.index, graph_name
)
if not triton:
stream_ptr = f"c_void_p({stream_name})"

@@ -2209,6 +2487,7 @@ class PythonWrapperCodegen(CodeGen):
debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
with debug_printer_manager:
self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})")
self.write_triton_header_once()
def writeline(self, line):
self.lines.append(line)

@@ -2297,6 +2576,9 @@ class PythonWrapperCodegen(CodeGen):
out = out + f".as_strided({codegen_shape_tuple}, {codegen_stride_tuple})"
return out
def make_comment(self, line):
self.writeline(CommentLine(line))
def make_tensor_alias(self, new_name, old_name, comment=""):
return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"

@@ -2361,10 +2643,12 @@ class PythonWrapperCodegen(CodeGen):
assert isinstance(layout.view, ir.ReinterpretView), (
f"unexpected {type(layout.view)}: {layout.view}"
)
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
self.codegen_allocation(layout.view.data.data)
self.codegen_deferred_allocation(name, layout.view)
box = layout.view.data
assert isinstance(box, ir.StorageBox), type(box)
input_buffer = box.data
assert isinstance(input_buffer, ir.Buffer), type(box)
self.codegen_allocation(input_buffer)
self.writeline(ReinterpretLine(self, input_buffer, buffer, layout))
return
if isinstance(layout, ir.CommBufferLayout):

@@ -2378,7 +2662,7 @@ class PythonWrapperCodegen(CodeGen):
# can be freed but not reused
if isinstance(buffer, (ir.InputBuffer, ir.TorchBindObject)):
self.writeline(self.make_buffer_free(buffer))
self.writeline(FreeLine(self, buffer))
return
if isinstance(buffer.get_output_spec(), ir.CommBufferLayout):
@@ -5080,7 +5080,7 @@ class ExternKernel(InputsKernel):
def codegen_comment(self, wrapper) -> None: # type: ignore[no-untyped-def]
origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
if origin_str:
wrapper.writeline(origin_str)
wrapper.make_comment(origin_str)
def codegen(self, wrapper): # type: ignore[no-untyped-def]
raise NotImplementedError

@@ -5786,25 +5786,7 @@ class ExternKernel(InputsKernel):
@ir_dataclass(frozen=False)
class ExternKernelOut(ExternKernel):
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs(skip_out=True)]
kernel_name = self.get_kernel_name()
if (
V.graph.cpp_wrapper
and self.cpp_kernel_name == "torch::inductor::_mm_plus_mm"
):
# For https://github.com/pytorch/pytorch/issues/128474
kernel_name = "aoti_torch__mm_plus_mm_out"
else:
kernel_name = self.get_kernel_name()
device = d.type if (d := self.get_device()) else V.graph.device_type
wrapper.generate_extern_kernel_out(
kernel_name,
self.codegen_reference(),
self.output_view.codegen_reference() if self.output_view else None,
args,
device,
)
wrapper.generate_extern_kernel_out(self)
def __init__( # type: ignore[no-untyped-def]
self,

@@ -5859,11 +5841,7 @@ class RandomSeeds(ExternKernelOut):
class ExternKernelAlloc(ExternKernel):
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
wrapper.generate_extern_kernel_alloc(self)
def __init__( # type: ignore[no-untyped-def]
self,

@@ -6979,7 +6957,7 @@ class FallbackKernel(ExternKernelAlloc):
# dispatch.
do_runtime_dispatch()
else:
V.graph.wrapper_code.generate_fallback_kernel(self, args)
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)

@@ -7130,32 +7108,8 @@ class MultiOutputLayout(OutputSpec):
class MultiOutput(ExternKernel):
# Given an input MultiOutputLayout buffer, indexes out an actual buffer
# from that result. This doesn't actually produce multiple outputs,
# that's MultiOutputLayout!
def codegen_list_tuple_access(self, basename, indices): # type: ignore[no-untyped-def]
if len(indices) > 0:
itype, i = indices[0]
if issubclass(itype, list):
return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
elif issubclass(itype, tuple):
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = V.graph.wrapper_code.codegen_tuple_access(
basename, self.get_name(), str(i)
)
return self.codegen_list_tuple_access(tuple_access, indices[1:])
elif issubclass(itype, dict):
return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
else:
raise AssertionError("non supported index type: ", itype)
else:
return basename
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
wrapper.codegen_multi_output(
self.get_name(),
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
)
wrapper.codegen_multi_output(self)
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)