mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
A PR to generate benchmark code for individual Triton kernels. We can explore improving autotuning with the saved compiled kernel directly. This can potentially speed up our iteration and separate this concern from the upstream components that generate the compiled module. Since I'm still ramping up on inductor, I'll reflect what I learned here so people can correct me if I'm wrong. In inductor, the WrapperCodeGen class is used to generate the compiled module for CUDA (or Triton). Here is an example compiled module for a toy model like `def f(x): return sin(x) + cos(x)`: https://gist.github.com/shunting314/c6ed9f571919e3b414166f1696dcc61b . A compiled module contains the following parts: - various Triton kernels - a wrapper (a method named `call`; the name is hardcoded) that calls the Triton kernels and potentially ATen kernels to efficiently do the same work as the original FX graph being compiled by inductor - some utility code that generates random inputs and runs the wrapper. The Triton kernels in the compiled module are annotated with decorators like `@pointwise`, which are used for autotuning. This PR adds a config option; enabling it causes the path of the compiled module to be printed. It can also be controlled from an environment variable. The path to each compiled Triton kernel is added as a comment in the compiled module. E.g. ``` # kernel path: /tmp/torchinductor_shunting/gn/cgn6x3mqoltu7q77gjnu2elwfupinsvcovqwibc6fhsoiy34tvga.py triton__0 = async_compile.triton(''' import triton import triton.language as tl ... ''') ``` Example command: ``` TORCHINDUCTOR_OUTPUT_COMPILED_MODULE_PATH=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --training --dashboard --only AlbertForMaskedLM --disable-cudagraphs ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/95506 Approved by: https://github.com/Chillee
820 lines
28 KiB
Python
820 lines
28 KiB
Python
import collections
|
|
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import hashlib
|
|
from itertools import count
|
|
from typing import Any, Dict, List
|
|
|
|
import sympy
|
|
|
|
from torch._dynamo.utils import dynamo_timed
|
|
|
|
from .. import codecache, config, ir
|
|
from ..codecache import code_hash, cpp_compile_command, get_code_path
|
|
from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product
|
|
from ..virtualized import V
|
|
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter
|
|
|
|
# Module-level shortcut for PythonPrinter().doprint: renders an expression as
# Python source text (presumably sympy size/index expressions — confirm at
# the call sites).
pexpr = PythonPrinter().doprint
|
|
|
|
|
|
def buffer_reuse_key(node: ir.Buffer):
    """Return the key under which ``node`` may take over a freed buffer.

    Two buffers with equal keys agree on device, dtype, simplified element
    count, and the size-hinted offset of their last element. The last item
    detects gaps that the strides leave in the underlying storage.
    """
    sizes = node.get_size()
    strides = node.get_stride()
    sizevars = V.graph.sizevars
    # Offset of the last addressed element: sympy_dot of (size - 1) with the
    # strides. It differs from numel - 1 when the layout has holes.
    last_offset = sympy_dot([dim - 1 for dim in sizes], strides)
    device, dtype = node.get_device(), node.get_dtype()
    numel = sizevars.simplify(sympy_product(sizes))
    return device, dtype, numel, sizevars.size_hint(last_offset)
|
|
|
|
|
|
def make_buffer_reuse(old, new, del_func, declare, ending, as_strided):
    """Build one line of code that makes buffer ``new`` reuse ``old``'s storage.

    Args:
        old: buffer whose storage is handed over.
        new: buffer that takes the storage over; must have ``old``'s dtype.
        del_func: maps a buffer name to the text that releases it. The text is
            appended only when ``old`` is not a graph output.
        declare: text placed before the assigned name (e.g. ``"auto "``).
        ending: statement terminator appended last.
        as_strided: name of the restriding function. It is used when the size
            or the stride of the two buffers differ.

    Returns:
        The generated line as a string.
    """
    assert old.get_dtype() == new.get_dtype()
    old_name = old.get_name()
    new_name = new.get_name()

    if old_name in V.graph.get_output_names():
        # Graph outputs must stay alive, so never release them here.
        release = ""
    else:
        release = del_func(old_name)

    same_layout = (
        old.get_size() == new.get_size() and old.get_stride() == new.get_stride()
    )
    if same_layout:
        source = old_name
    else:
        shape_tuple = V.graph.sizevars.codegen_shape_tuple
        source = (
            f"{as_strided}({old_name}, "
            f"{shape_tuple(new.get_size())}, "
            f"{shape_tuple(new.get_stride())})"
        )
    return f"{declare}{new_name} = {source}{release}{ending}"
|
|
|
|
|
|
def make_buffer_allocation(buffer):
    """Return the Python ``empty_strided(...)`` line that allocates ``buffer``."""
    shape_tuple = V.graph.sizevars.codegen_shape_tuple
    size_text = shape_tuple(tuple(buffer.get_size()))
    stride_text = shape_tuple(tuple(buffer.get_stride()))
    device_kind = buffer.get_device().type
    element_type = buffer.get_dtype()
    return (
        f"{buffer.get_name()} = empty_strided({size_text}, {stride_text}, "
        f"device='{device_kind}', dtype={element_type})"
    )
|
|
|
|
|
|
def make_cpp_buffer_allocation(buffer):
    """Return the C++ ``at::empty_strided(...)`` statement allocating ``buffer``.

    Only size, stride and dtype are passed. The device and layout are not
    mapped yet (see the TODO).
    """
    from .cpp import DTYPE_TO_ATEN

    # TODO: map layout and device here
    shape_tuple = V.graph.sizevars.codegen_shape_tuple
    size_text = shape_tuple(tuple(buffer.get_size()))
    stride_text = shape_tuple(tuple(buffer.get_stride()))
    aten_dtype = DTYPE_TO_ATEN[buffer.get_dtype()]
    return (
        f"auto {buffer.get_name()} = at::empty_strided("
        f"{size_text}, {stride_text}, {aten_dtype}); "
    )
|
|
|
|
|
|
class MemoryPlanningState:
    """Pool of freed buffers awaiting reuse, grouped by reuse key.

    Each key maps to a list used as a stack, so ``pop`` hands back the most
    recently ``push``-ed entry first.
    """

    def __init__(self):
        super().__init__()
        self.reuse_pool: Dict[
            Any, List["FreeIfNotReusedLine"]
        ] = collections.defaultdict(list)

    def __contains__(self, key):
        # A key only counts as present while it still has a free entry.
        # .get() avoids inserting an empty list into the defaultdict.
        candidates = self.reuse_pool.get(key, None)
        return bool(candidates)

    def pop(self, key) -> "FreeIfNotReusedLine":
        entry = self.reuse_pool[key].pop()
        # An entry that was already handed out must never come back.
        assert not entry.is_reused
        return entry

    def push(self, key, item: "FreeIfNotReusedLine"):
        assert not item.is_reused
        bucket = self.reuse_pool[key]
        bucket.append(item)
|
|
|
|
|
|
@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
    """Marker line that opens a CUDA device-guard ``with`` block."""

    device_idx: int

    def codegen(self, code: IndentedBuffer):
        # _DeviceGuard has less overhead than torch.cuda.device, but it only
        # accepts an integer index.
        guard_expr = f"torch.cuda._DeviceGuard({self.device_idx})"
        code.writeline(f"with {guard_expr}:")
|
|
|
|
|
|
class ExitCudaDeviceContextManagerLine:
    """Marker line that ends the indented CUDA device-guard region."""
|
|
|
|
|
|
class MemoryPlanningLine:
    """Base class for wrapper lines that take part in memory planning.

    Such lines are processed in two passes. ``plan`` may return a
    replacement line, and ``codegen`` then writes the chosen line's text.
    The base class keeps itself and emits nothing.
    """

    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass: return the line to keep (by default ``self``)."""
        return self

    def codegen(self, code: IndentedBuffer):
        """Second pass: write this line's code into ``code`` (no-op here)."""
|
|
|
|
|
|
@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    """Allocation of ``node``; planning may turn it into reuse of a freed buffer."""

    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        name = self.node.get_name()
        if name in V.graph.removed_buffers:
            return NullLine()

        # Prefer taking over a recently freed buffer with a matching key.
        key = buffer_reuse_key(self.node)
        if key not in state:
            return self
        donor = state.pop(key)
        donor.is_reused = True
        return ReuseLine(donor.node, self.node)

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(make_buffer_allocation(self.node))
|
|
|
|
|
|
@dataclasses.dataclass
class CppAllocateLine(AllocateLine):
    """C++ variant of AllocateLine (``at::empty_strided`` / C++ reuse line)."""

    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()

        # Prefer taking over a recently freed buffer, emitted in C++ form.
        reuse_key = buffer_reuse_key(self.node)
        if reuse_key not in state:
            return self
        freed = state.pop(reuse_key)
        freed.is_reused = True
        return CppReuseLine(freed.node, self.node)

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(make_cpp_buffer_allocation(self.node))
|
|
|
|
|
|
@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    """Release of ``node`` that is suppressed if a later allocation reuses it."""

    node: ir.Buffer
    # Flipped to True during planning when an allocation takes this buffer.
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState):
        assert not self.is_reused
        if self.node.get_name() not in V.graph.removed_buffers:
            # Offer this buffer to later allocations with a matching key.
            state.push(buffer_reuse_key(self.node), self)
            return self
        return NullLine()

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        if self.is_reused:
            return
        code.writeline(f"del {self.node.get_name()}")
|
|
|
|
|
|
@dataclasses.dataclass
class CppFreeIfNotReusedLine(FreeIfNotReusedLine):
    """C++ variant of FreeIfNotReusedLine: releases the tensor via ``.reset()``."""

    # Fields (node, is_reused=False) are inherited from FreeIfNotReusedLine.

    def codegen(self, code: IndentedBuffer):
        name = self.node.get_name()
        assert name not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(f"{name}.reset();")
|
|
|
|
|
|
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    """Python line handing the storage of ``node`` over to ``reused_as``."""

    node: ir.Buffer
    reused_as: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        # Already a reuse decision: only sanity-check both buffers.
        for buf in (self.node, self.reused_as):
            assert buf.get_name() not in V.graph.removed_buffers
        return self

    def codegen(self, code: IndentedBuffer):
        for buf in (self.node, self.reused_as):
            assert buf.get_name() not in V.graph.removed_buffers
        reuse_code = make_buffer_reuse(
            self.node,
            self.reused_as,
            del_func=lambda name: f"; del {name}",
            declare="",
            ending="",
            as_strided="as_strided",
        )
        code.writeline(f"{reuse_code} # reuse")
|
|
|
|
|
|
@dataclasses.dataclass
class CppReuseLine(ReuseLine):
    """C++ variant of ReuseLine (``auto`` declaration, ``at::as_strided``)."""

    # Fields (node, reused_as) are inherited from ReuseLine.

    def codegen(self, code: IndentedBuffer):
        for buf in (self.node, self.reused_as):
            assert buf.get_name() not in V.graph.removed_buffers
        statement = make_buffer_reuse(
            self.node,
            self.reused_as,
            del_func=lambda name: f"; {name}.reset()",
            declare="auto ",
            ending=";",
            as_strided="at::as_strided",
        )
        code.writeline(statement + " // reuse")
|
|
|
|
|
|
@dataclasses.dataclass
class FreeLine(MemoryPlanningLine):
    """Unconditional ``del`` of ``node`` in the Python wrapper."""

    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        removed = self.node.get_name() in V.graph.removed_buffers
        return NullLine() if removed else self

    def codegen(self, code: IndentedBuffer):
        buffer_name = self.node.get_name()
        assert buffer_name not in V.graph.removed_buffers
        code.writeline(f"del {buffer_name}")
|
|
|
|
|
|
class NullLine(MemoryPlanningLine):
    """Placeholder returned by ``plan`` when a line should emit nothing."""
|
|
|
|
|
|
class WrapperCodeGen(CodeGen):
    """
    The outer wrapper that calls the kernels.

    ``generate`` assembles the output module from three buffers:
    ``header`` (imports, kernel definitions, constants, metas), ``prefix``
    (the opening of ``def call(args):``, input unpacking and size
    computations) and ``wrapper_call`` (the body of ``call``, rendered from
    ``self.lines``). It then optionally appends a benchmark harness.
    """

    def __init__(self):
        super().__init__()
        # Counter behind next_kernel_suffix().
        self._names_iter = count()
        # Module-level code of the generated file.
        self.header = IndentedBuffer()
        # Opening of the generated `call` function (see write_prefix).
        self.prefix = IndentedBuffer()
        # Body of the generated `call` function; filled in by generate().
        self.wrapper_call = IndentedBuffer()
        self.kernels = {}
        # Pending wrapper lines (strings, DeferredLines, MemoryPlanningLines,
        # CUDA device-guard markers). generate() renders them.
        self.lines = []
        self.header.splice(
            f"""
                from ctypes import c_void_p, c_long
                import torch
                import math
                import random
                from torch import empty_strided, as_strided, device
                from {codecache.__name__} import AsyncCompile
                from torch._inductor.select_algorithm import extern_kernels

                aten = torch.ops.aten
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
                async_compile = AsyncCompile()

            """
        )

        if has_triton():
            self.header.splice(
                """
                import triton
                import triton.language as tl
                from torch._inductor.triton_ops.autotune import grid, start_graph, end_graph
                from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
                """
            )

            if config.triton.convolution != "aten":
                self.header.splice(
                    """
                    from torch._inductor.triton_ops.conv_perf_model import early_config_prune
                    from torch._inductor.triton_ops.conv_perf_model import estimate_conv_time
                    from torch._inductor.triton_ops.autotune import conv_heuristics
                    """
                )

        self.write_prefix()

        # Constants are declared as `None` placeholders in the header.
        for name, value in V.graph.constants.items():
            # include a hash so our code cache gives different constants different files
            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
            self.header.writeline(f"{name} = None # {hashed}")

        # Names of buffers already allocated / freed by the wrapper.
        self.allocated = set()
        self.freed = set()

        # maps from reusing buffer to reused buffer
        self.reuses = dict()

        # Per-instance cache: each stream index is fetched only once.
        self.write_get_cuda_stream = functools.lru_cache(None)(
            self.write_get_cuda_stream
        )

        # Cached so that each distinct import line is written only once.
        @functools.lru_cache(None)
        def add_import_once(line):
            self.header.writeline(line)

        self.add_import_once = add_import_once
        # repr(meta) -> name of the header variable holding it.
        self._metas = {}

    def add_meta_once(self, meta):
        """Declare ``meta`` in the header once; return its variable name."""
        meta = repr(meta)
        if meta not in self._metas:
            var = f"meta{len(self._metas)}"
            self._metas[meta] = var
            self.header.writeline(f"{var} = {meta}")
        return self._metas[meta]

    @cache_on_self
    def get_output_refs(self):
        """Code references to the graph outputs (computed once per instance)."""
        return [x.codegen_reference() for x in V.graph.graph_outputs]

    def write_prefix(self):
        """Write the opening of ``def call(args):`` into ``self.prefix``."""
        self.prefix.splice(
            """

            async_compile.wait(globals())
            del async_compile

            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline("torch.cuda.synchronize()")
            inp_len = len(V.graph.graph_inputs.keys())
            if inp_len != 0:
                # Unpack the inputs; a single input needs a trailing comma.
                lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
                self.prefix.writeline(f"{lhs} = args")
                # Drop the caller's references held by the args list.
                self.prefix.writeline("args.clear()")
            for name in V.graph.randomness_seeds:
                self.prefix.writeline(
                    f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
                )
            V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs)

    def append_precomputed_sizes_to_prefix(self):
        with self.prefix.indent():
            V.graph.sizevars.codegen_precomputed_sizes(self.prefix)

    def write_get_cuda_stream(self, index):
        """Emit ``stream{index} = get_cuda_stream(index)``; return the name.

        Wrapped in a per-instance lru_cache in __init__, so the line is
        emitted at most once per index.
        """
        name = f"stream{index}"
        self.writeline(f"{name} = get_cuda_stream({index})")
        return name

    def next_kernel_suffix(self):
        return f"{next(self._names_iter)}"

    def write_allocate_line(self, buffer):
        self.writeline(AllocateLine(buffer))

    def get_deferred_line(self, name, layout):
        """Line binding ``name`` to the view of an aliased layout."""
        return DeferredLine(
            name, f"{name} = {layout.view.codegen_reference()} # alias"
        )

    def codegen_allocation(self, buffer):
        """Emit whatever is needed to make ``buffer`` available (at most once)."""
        name = buffer.get_name()
        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)

        # These buffers get no allocation line from the wrapper.
        if isinstance(
            buffer,
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayout):
            return
        if isinstance(layout, ir.AliasedLayout):
            assert isinstance(layout.view, ir.ReinterpretView)
            if not layout.maybe_guard_aligned():
                V.graph.unaligned_buffers.add(name)
            # The aliased base buffer must be allocated before the alias.
            self.codegen_allocation(layout.view.data)
            allocation = self.get_deferred_line(name, layout)
            self.writeline(allocation)
            return

        self.write_allocate_line(buffer)

    def write_del_line(self, name):
        self.writeline(f"del {name}")

    def write_free_if_not_reused_line(self, buffer):
        self.writeline(FreeIfNotReusedLine(buffer))

    def codegen_free(self, buffer):
        """Emit the code that releases ``buffer``."""
        name = buffer.get_name()

        # can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.write_del_line(name)
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        # Aliased / multi-output layouts are deleted outright instead of
        # being offered for reuse.
        layout = buffer.get_layout()
        if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
            self.write_del_line(name)
            return

        self.write_free_if_not_reused_line(buffer)

    def can_reuse(self, buffer):
        """False for removed buffers, graph inputs, constants, or already-freed buffers."""
        name = buffer.get_name()
        if (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in self.freed
        ):
            return False
        return True

    def did_reuse(self, buffer, reused_buffer):
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        return (
            buffer.get_name() in self.reuses
            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
        )

    def write_reuse_line(self, input_buffer, output_buffer):
        self.writeline(ReuseLine(input_buffer, output_buffer))

    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        """Let ``output_buffer`` take over ``input_buffer``'s storage.

        Both buffers must have the same buffer_reuse_key.
        """
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.write_reuse_line(input_buffer, output_buffer)

    def codegen_cuda_device_guard_enter(self, device_idx):
        self.lines.append(EnterCudaDeviceContextManagerLine(device_idx))

    def codegen_cuda_device_guard_exit(self):
        self.lines.append(ExitCudaDeviceContextManagerLine())

    def generate_return(self, output_refs):
        """Write the ``return`` statement of ``call`` (always a tuple)."""
        if output_refs:
            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
        else:
            self.wrapper_call.writeline("return ()")

    def generate_end(self, result):
        """Hook for trailing module code; a no-op for the Python wrapper."""
        return

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        """Call an extern kernel, passing its output via ``out=``.

        ``cpp_kernel`` is unused here. Note that ``args`` is mutated: the
        ``out=`` argument is appended to it.
        """
        if output_view:
            args.append(f"out={output_view.codegen_reference()}")
        else:
            args.append(f"out={codegen_reference}")
        self.writeline(f"{kernel}({', '.join(args)})")

    @dynamo_timed
    def generate(self):
        """Assemble the generated module and return its source text."""
        result = IndentedBuffer()
        result.splice(self.header)

        out_names = V.graph.get_output_names()
        # Everything written while this stack is open lands in the body of
        # `call`. The stack also unwinds the optional record_function indent.
        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
            if config.profiler_mark_wrapper_call:
                self.wrapper_call.writeline(
                    "from torch.profiler import record_function"
                )
                self.wrapper_call.writeline(
                    "with record_function('inductor_wrapper_call'):"
                )
                stack.enter_context(self.wrapper_call.indent())
            if config.profile_bandwidth:
                self.wrapper_call.writeline("start_graph()")

            # Drop trailing memory-planning lines for non-output buffers.
            # NullLine (which has no `node`) is only produced by plan(), and
            # plan() has not run yet.
            # NOTE(review): for a trailing ReuseLine this checks the donor
            # `node`, not `reused_as`. Confirm that this is intended.
            while (
                self.lines
                and isinstance(self.lines[-1], MemoryPlanningLine)
                and self.lines[-1].node.name not in out_names
            ):
                # these lines will be pointless
                self.lines.pop()

            # codegen allocations in two passes
            # Pass 1: plan() may replace a line, e.g. an allocation by a
            # ReuseLine (taking a freed buffer) or a removed buffer's line
            # by a NullLine.
            planning_state = MemoryPlanningState()
            for i in range(len(self.lines)):
                if isinstance(self.lines[i], MemoryPlanningLine):
                    self.lines[i] = self.lines[i].plan(planning_state)

            # Pass 2: render. A device-guard Enter line opens an indented
            # `with` block; the Exit marker closes every indent on the stack.
            device_cm_stack = contextlib.ExitStack()
            for line in self.lines:
                if isinstance(line, MemoryPlanningLine):
                    line.codegen(self.wrapper_call)
                elif isinstance(line, EnterCudaDeviceContextManagerLine):
                    line.codegen(self.wrapper_call)
                    device_cm_stack.enter_context(self.wrapper_call.indent())
                    self.wrapper_call.writeline(
                        f"torch.cuda.set_device({line.device_idx}) # no-op to ensure context"
                    )
                elif isinstance(line, ExitCudaDeviceContextManagerLine):
                    device_cm_stack.close()
                else:
                    self.wrapper_call.writeline(line)

            output_refs = self.get_output_refs()
            if config.triton.debug_sync_graph:
                self.wrapper_call.writeline("torch.cuda.synchronize()")

            if config.profile_bandwidth:
                self.wrapper_call.writeline("end_graph()")

            self.generate_return(output_refs)

        self.append_precomputed_sizes_to_prefix()
        result.splice(self.prefix)

        with result.indent():
            result.splice(self.wrapper_call)

        self.generate_end(result)

        self.add_benchmark_harness(result)

        return result.getvalue()

    def add_benchmark_harness(self, output):
        """
        Append a benchmark harness to generated code for debugging
        """
        if not config.benchmark_harness:
            return

        def add_fake_input(name, shape, stride, device, dtype):
            output.writeline(
                f"{name} = rand_strided("
                f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, "
                f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, "
                f"device='{device}', dtype={dtype})"
            )

        def add_expr_input(name, val):
            output.writeline(f"{name} = {val}")

        output.writelines(["", "", 'if __name__ == "__main__":'])
        with output.indent():
            output.splice(
                """
                from torch._dynamo.testing import rand_strided
                from torch._inductor.utils import print_performance
                """,
                strip=True,
            )

            for name, value in V.graph.constants.items():
                add_fake_input(
                    name, value.size(), value.stride(), value.device, value.dtype
                )

            for name, value in V.graph.graph_inputs.items():
                if isinstance(value, sympy.Expr):  # Don't need to add symbolic
                    add_expr_input(name, V.graph.sizevars.size_hint(value))
                else:
                    # Symbolic sizes/strides are replaced by their size hints.
                    shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
                    stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
                    add_fake_input(
                        name, shape, stride, value.get_device(), value.get_dtype()
                    )

            output.writeline(
                f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
            )

    def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
        """Add ``name = kernel`` to the header.

        When ``kernel_path`` is given, a ``# kernel path:`` comment precedes
        the definition.
        """
        kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else ""
        self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}")

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        """Hook for loading a compiled kernel; a no-op for the Python wrapper."""
        return

    def wrap_kernel_call(self, name, call_args):
        return "{}({})".format(name, ", ".join(call_args))

    def generate_kernel_call(self, name, call_args):
        self.writeline(
            self.wrap_kernel_call(name, call_args),
        )

    def call_kernel(self, name: str, kernel: Kernel):
        """Queue ``kernel``'s launch code as wrapper lines.

        The kernel writes into a scratch buffer; each non-empty line is then
        stripped and queued.
        """
        tmp = IndentedBuffer()
        kernel.call_kernel(self, tmp, name)
        for line in tmp.getvalue().split("\n"):
            line = line.strip()
            if line:
                self.writeline(line)

    def writeline(self, line):
        """Queue ``line``; nothing is rendered until generate()."""
        self.lines.append(line)
|
|
|
|
|
|
class CppWrapperCodeGen(WrapperCodeGen):
    """
    The outer wrapper that calls the kernels.

    C++ variant. The body of ``call_<id>`` is emitted as C++ source inside a
    Python string named ``wrapper``: it is opened in ``write_prefix`` and
    closed in ``generate_return``. ``generate_end`` compiles it with
    ``torch.utils.cpp_extension.load_inline``.
    """

    # Class-wide counter so that each instance gets a distinct call_<id>.
    call_func_id = count()

    def __init__(self):
        # Must be set before super().__init__(): the base __init__ calls
        # write_prefix(), which reads self._call_func_id.
        self._call_func_id = next(CppWrapperCodeGen.call_func_id)
        super().__init__()

    @cache_on_self
    def get_output_refs(self):
        """Output references, preferring ``cpp_wrapper_codegen_reference()``."""

        def has_cpp_codegen_func(x):
            return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
                x.cpp_wrapper_codegen_reference
            )

        return [
            x.cpp_wrapper_codegen_reference()
            if has_cpp_codegen_func(x)
            else x.codegen_reference()
            for x in V.graph.graph_outputs
        ]

    def write_prefix(self):
        """Open the C++ ``wrapper`` source string and the C++ call function.

        The prefix defines ``load_cpp_kernel``, which dlopen()s a shared
        object and looks up its ``kernel`` symbol. The C++ function signature
        and the unpacking of the inputs go into ``wrapper_call``.
        """
        self.prefix.splice(
            """
            async_compile.wait(globals())
            del async_compile
            from torch.utils.cpp_extension import load_inline
            wrapper = (
            '''
            #include <dlfcn.h>
            #include <assert.h>

            template <typename KernelFunc>
            KernelFunc load_cpp_kernel(const char* so_filename) {
                KernelFunc kernel_cpp;
                auto kernel_cpp_lib = dlopen(so_filename, RTLD_NOW);
                assert(kernel_cpp_lib != nullptr);
                *(void **) (&kernel_cpp) = dlsym(kernel_cpp_lib, "kernel");
                return kernel_cpp;
            }

            """
        )
        with self.wrapper_call.indent():
            inputs_len = len(V.graph.graph_inputs.keys())
            output_refs = self.get_output_refs()
            # The return type mirrors generate_return: a single tensor, a
            # vector of tensors, or void.
            if output_refs:
                if len(output_refs) == 1:
                    output_types = "at::Tensor"
                else:
                    output_types = "std::vector<at::Tensor>"
            else:
                output_types = "void"

            inputs_types = "std::vector<at::Tensor>"
            self.wrapper_call.writeline(
                f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
            )
            if inputs_len != 0:
                inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
                self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    self.wrapper_call.writeline(f"{input_key} = args[{idx}];")

            for name in V.graph.randomness_seeds:
                self.wrapper_call.writeline(f"at::Tensor {name};")
                self.wrapper_call.writeline(
                    f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
                )
            V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs)

    def write_allocate_line(self, buffer):
        self.writeline(CppAllocateLine(buffer))

    def write_del_line(self, name):
        self.writeline(f"{name}.reset();")
        return

    def write_free_if_not_reused_line(self, buffer):
        self.writeline(CppFreeIfNotReusedLine(buffer))
        return

    def write_reuse_line(self, input_buffer, output_buffer):
        self.writeline(CppReuseLine(input_buffer, output_buffer))

    def get_deferred_line(self, name, layout):
        return DeferredLine(
            name, f"auto {name} = {layout.view.codegen_reference()}; // alias"
        )

    def get_kernel_path(self, code):
        """Path of the ``.so`` that the code cache derives for ``code``."""
        from ..codecache import pick_vec_isa

        picked_vec_isa = pick_vec_isa()
        ext = "so"
        extra = code_hash(repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa)))
        # \n is required to match with the CodeCache behavior
        # For reductions, the code string gotten from code.getvalue() will use backslash '\'
        # at the end of lines for readability purpose:
        #       #pragma omp declare reduction(xxx :\
        #                omp_out.value = xxx,\
        # While the code string loaded during the execution will escape the backslash '\':
        #       #pragma omp declare reduction(xxx :                omp_out.value = xxx,
        # Use code.getrawvalue() here to escape the backslash to
        # make sure the same code string is used during compilation and execution,
        # so that the hash value is the same.
        source_code = "\n" + code.getrawvalue()
        _, _, kernel_path = get_code_path(source_code, ext, extra)
        return kernel_path

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        """Emit a static C++ function pointer ``name`` bound to the kernel .so.

        NOTE(review): ``arg_types`` is interpolated directly into
        ``void (*)(...)``, so callers presumably pass a comma-separated
        C++ type string despite the ``List`` annotation. Confirm this.
        """
        kernel_path = self.get_kernel_path(kernel)
        self.writeline(
            f'static auto {name} = load_cpp_kernel<void (*)({arg_types})>("{kernel_path}");'
        )

    def wrap_kernel_call(self, name, call_args):
        return "{}({});".format(name, ", ".join(call_args))

    def generate_return(self, output_refs):
        """Return from the C++ function and close the ``wrapper`` string.

        The trailing ``}''' )`` closes the C++ function body, the ``'''``
        string and the parenthesis opened by ``wrapper = (`` in write_prefix.
        """
        if output_refs:
            if len(output_refs) == 1:
                self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
            else:
                self.wrapper_call.writeline(
                    "return std::vector<at::Tensor>({"
                    + ", ".join(output_refs)
                    + "}); }''' )"
                )
        else:
            self.wrapper_call.writeline("return; }''' )")

    def generate_end(self, result):
        """Append the code that compiles ``wrapper`` and exposes it as ``call``."""
        shared = codecache.get_shared()
        warning_all_flag = codecache.get_warning_all_flag()
        cpp_flags = codecache.cpp_flags()
        ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths()
        optimization_flags = codecache.optimization_flags()
        use_custom_generated_macros = codecache.use_custom_generated_macros()

        extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
        extra_ldflags = f"{shared} {lpaths} {libs}"
        extra_include_paths = f"{ipaths}"

        # get the hash of the wrapper code to name the extension
        wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
        result.splice(
            f"""
            module = load_inline(
                name='inline_extension_{wrapper_call_hash}',
                cpp_sources=[wrapper],
                functions=['call_{self._call_func_id}'],
                extra_cflags=['{extra_cflags}'],
                extra_ldflags=['{extra_ldflags}'],
                extra_include_paths=['{extra_include_paths}'])
            """
        )
        # Wrap the func to support setting result._boxed_call = True
        result.splice(
            f"""
            def _wrap_func(f):
                def g(args):
                    return f(args)
                return g
            call = _wrap_func(module.call_{self._call_func_id})
            """
        )

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        """Call ``cpp_kernel`` with the output tensor as its first argument.

        When ``output_view`` is given, it is first bound to a local
        ``<name>_as_strided`` variable. ``args`` is mutated in place.
        """
        if output_view:
            output_as_strided = f"{output_view.codegen_reference()}"
            output_name = f"{output_view.get_name()}_as_strided"
            self.writeline(f"auto {output_name} = {output_as_strided};")

            args.insert(0, output_name)
        else:
            args.insert(0, f"{codegen_reference}")
        self.writeline(f"{cpp_kernel}({', '.join(args)});")
|