# pytorch/torch/_inductor/codegen/wrapper.py


import collections
import contextlib
import dataclasses
import functools
import hashlib
from itertools import count
from typing import Any, Dict, List
from .. import codecache, config, ir
from ..codecache import cpp_compile_command, get_code_path
from ..utils import cache_on_self, dynamo_utils, has_triton, sympy_dot, sympy_product
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel
from .triton import texpr
pexpr = texpr
def buffer_reuse_key(node: ir.Buffer):
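    """
    Key used to match freed buffers with new allocations for storage reuse:
    device, dtype, simplified element count, and a hint for the offset of the
    last element (which exposes gaps introduced by non-contiguous strides).
    """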
size = node.get_size()
stride = node.get_stride()
last_element = sympy_dot([s - 1 for s in size], stride)
return (
node.get_device(),
node.get_dtype(),
V.graph.sizevars.simplify(sympy_product(size)),
# Detect gaps in tensor storage caused by strides
V.graph.sizevars.size_hint(last_element),
)
def make_buffer_reuse(old, new, del_func, declare, ending, as_strided):
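    """
    Emit a line that rebinds `old`'s storage to `new`: a plain alias when size
    and stride match, otherwise an `as_strided` view.  `del_func`, `declare`,
    `ending`, and `as_strided` are the hooks that let the Python and C++
    wrappers share this helper.  For example (illustrative names only), the
    Python wrapper might produce:
        buf3 = as_strided(buf1, (4, 4), (4, 1)); del buf1
    """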
assert old.get_dtype() == new.get_dtype()
del_line = ""
if old.get_name() not in V.graph.get_output_names():
del_line = del_func(old.get_name())
if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
return f"{declare}{new.get_name()} = {old.get_name()}{del_line}{ending}"
return (
f"{declare}{new.get_name()} = {as_strided}({old.get_name()}, "
f"{V.graph.sizevars.codegen_shape_tuple(new.get_size())}, "
f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}){del_line}{ending}"
)
def make_buffer_allocation(buffer):
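    """
    Emit a Python `empty_strided(...)` allocation for `buffer`.  For example
    (illustrative), a contiguous (8, 16) float32 CUDA buffer would produce
    roughly:
        buf0 = empty_strided((8, 16), (16, 1), device='cuda', dtype=torch.float32)
    """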
device = buffer.get_device()
dtype = buffer.get_dtype()
shape = tuple(buffer.get_size())
stride = tuple(buffer.get_stride())
return (
f"{buffer.get_name()} = empty_strided("
f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
f"device='{device.type}', dtype={dtype})"
)
def make_cpp_buffer_allocation(buffer):
from .cpp import DTYPE_TO_ATEN
# TODO: map layout and device here
dtype = buffer.get_dtype()
shape = tuple(buffer.get_size())
stride = tuple(buffer.get_stride())
return (
f"auto {buffer.get_name()} = at::empty_strided("
f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
f"{DTYPE_TO_ATEN[dtype]}); "
)
class MemoryPlanningState:
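    """Pool of freed-but-not-yet-reused buffers, grouped by buffer_reuse_key."""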
def __init__(self):
super().__init__()
self.reuse_pool: Dict[
Any, List["FreeIfNotReusedLine"]
] = collections.defaultdict(list)
def __contains__(self, key):
return bool(self.reuse_pool.get(key, None))
def pop(self, key) -> "FreeIfNotReusedLine":
item = self.reuse_pool[key].pop()
assert not item.is_reused
return item
def push(self, key, item: "FreeIfNotReusedLine"):
assert not item.is_reused
self.reuse_pool[key].append(item)
@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
device_idx: int
def codegen(self, code: IndentedBuffer):
code.writeline(f"with torch.cuda.device({self.device_idx}):")
class ExitCudaDeviceContextManagerLine:
def codegen(self, code: IndentedBuffer):
pass
class MemoryPlanningLine:
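    """
    Element of WrapperCodeGen.lines used for two-pass memory planning:
    `plan` may rewrite the line (e.g. turning an allocation into a reuse),
    and `codegen` emits the final wrapper code.
    """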
def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
"""First pass to find reuse"""
return self
def codegen(self, code: IndentedBuffer):
"""Second pass to output code"""
pass
@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
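    """Allocate `node`, or reuse a compatible freed buffer found during plan()."""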
node: ir.Buffer
def plan(self, state: MemoryPlanningState):
if self.node.get_name() in V.graph.removed_buffers:
return NullLine()
# try to reuse a recently freed buffer
key = buffer_reuse_key(self.node)
if key in state:
free_line = state.pop(key)
free_line.is_reused = True
return ReuseLine(free_line.node, self.node)
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(make_buffer_allocation(self.node))
@dataclasses.dataclass
class CppAllocateLine(AllocateLine):
def plan(self, state: MemoryPlanningState):
if self.node.get_name() in V.graph.removed_buffers:
return NullLine()
# try to reuse a recently freed buffer
key = buffer_reuse_key(self.node)
if key in state:
free_line = state.pop(key)
free_line.is_reused = True
return CppReuseLine(free_line.node, self.node)
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(make_cpp_buffer_allocation(self.node))
@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
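    """Free `node` at this point unless a later allocation claims its storage."""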
node: ir.Buffer
is_reused: bool = False
def plan(self, state: MemoryPlanningState):
assert not self.is_reused
if self.node.get_name() in V.graph.removed_buffers:
return NullLine()
state.push(buffer_reuse_key(self.node), self)
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
if not self.is_reused:
code.writeline(f"del {self.node.get_name()}")
@dataclasses.dataclass
class CppFreeIfNotReusedLine(FreeIfNotReusedLine):
node: ir.Buffer
is_reused: bool = False
def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
if not self.is_reused:
code.writeline(f"{self.node.get_name()}.reset();")
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
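    """Rebind `node`'s storage as the backing store of `reused_as`."""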
node: ir.Buffer
reused_as: ir.Buffer
def plan(self, state: MemoryPlanningState):
assert self.node.get_name() not in V.graph.removed_buffers
assert self.reused_as.get_name() not in V.graph.removed_buffers
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
assert self.reused_as.get_name() not in V.graph.removed_buffers
code.writeline(
make_buffer_reuse(
self.node,
self.reused_as,
del_func=lambda name: f"; del {name}",
declare="",
ending="",
as_strided="as_strided",
)
+ " # reuse"
)
@dataclasses.dataclass
class CppReuseLine(ReuseLine):
node: ir.Buffer
reused_as: ir.Buffer
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
assert self.reused_as.get_name() not in V.graph.removed_buffers
code.writeline(
make_buffer_reuse(
self.node,
self.reused_as,
del_func=lambda name: f"; {name}.reset()",
declare="auto ",
ending=";",
as_strided="at::as_strided",
)
+ " // reuse"
)
@dataclasses.dataclass
class FreeLine(MemoryPlanningLine):
node: ir.Buffer
def plan(self, state: MemoryPlanningState):
if self.node.get_name() in V.graph.removed_buffers:
return NullLine()
return self
def codegen(self, code: IndentedBuffer):
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(f"del {self.node.get_name()}")
class NullLine(MemoryPlanningLine):
pass
class WrapperCodeGen(CodeGen):
"""
    Generate the Python wrapper that calls the compiled kernels: it allocates
    and reuses buffers, launches each kernel, and returns the graph outputs.
"""
def __init__(self):
super().__init__()
self._names_iter = count()
self.header = IndentedBuffer()
self.prefix = IndentedBuffer()
self.wrapper_call = IndentedBuffer()
self.kernels = {}
self.lines = []
self.header.splice(
f"""
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from {codecache.__name__} import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
"""
)
if has_triton():
self.header.splice(
f"""
import triton
import triton.language as tl
from {config.inductor_import}.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
"""
)
if config.triton.convolution != "aten":
self.header.splice(
f"""
from {config.inductor_import}.triton_ops.conv_perf_model import early_config_prune
from {config.inductor_import}.triton_ops.conv_perf_model import estimate_conv_time
from {config.inductor_import}.triton_ops.autotune import conv_heuristics
"""
)
self.write_prefix()
for name, value in V.graph.constants.items():
            # include a hash so that the code cache puts different constant values into different files
hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
self.header.writeline(f"{name} = None # {hashed}")
self.allocated = set()
self.freed = set()
self.write_get_cuda_stream = functools.lru_cache(None)(
self.write_get_cuda_stream
)
@functools.lru_cache(None)
def add_import_once(line):
self.header.writeline(line)
self.add_import_once = add_import_once
self._metas = {}
def add_meta_once(self, meta):
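        """Write `meta` to the header once and return the name of the variable holding it."""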
meta = repr(meta)
if meta not in self._metas:
var = f"meta{len(self._metas)}"
self._metas[meta] = var
self.header.writeline(f"{var} = {meta}")
return self._metas[meta]
@cache_on_self
def get_output_refs(self):
return [x.codegen_reference() for x in V.graph.graph_outputs]
def write_prefix(self):
self.prefix.splice(
"""
async_compile.wait(globals())
del async_compile
def call(args):
"""
)
with self.wrapper_call.indent():
if config.triton.debug_sync_graph:
self.wrapper_call.writeline("torch.cuda.synchronize()")
inp_len = len(V.graph.graph_inputs.keys())
if inp_len != 0:
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
self.wrapper_call.writeline(f"{lhs} = args")
self.wrapper_call.writeline("args.clear()")
for name in V.graph.randomness_seeds:
self.wrapper_call.writeline(
f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
)
V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs)
def write_get_cuda_stream(self, index):
name = f"stream{index}"
self.writeline(f"{name} = get_cuda_stream({index})")
return name
def next_kernel_suffix(self):
return f"{next(self._names_iter)}"
def write_allocate_line(self, buffer):
self.writeline(AllocateLine(buffer))
def get_deferred_line(self, name, layout):
return DeferredLine(
name, f"{name} = {layout.view.codegen_reference()} # alias"
)
def codegen_allocation(self, buffer):
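        """
        Queue an allocation for `buffer` unless it was removed, is already
        allocated, is produced by an extern kernel / multi-output node,
        mutates an existing tensor, or merely aliases another buffer (in
        which case a deferred alias line is emitted instead).
        """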
name = buffer.get_name()
if name in V.graph.removed_buffers or name in self.allocated:
return
self.allocated.add(name)
if isinstance(
buffer,
(ir.ExternKernelAlloc, ir.MultiOutput),
):
return
layout = buffer.get_layout()
if isinstance(layout, ir.MutationLayout):
return
if isinstance(layout, ir.AliasedLayout):
assert isinstance(layout.view, ir.ReinterpretView)
if not layout.maybe_guard_aligned():
V.graph.unaligned_buffers.add(name)
self.codegen_allocation(layout.view.data)
allocation = self.get_deferred_line(name, layout)
self.writeline(allocation)
return
self.write_allocate_line(buffer)
def write_del_line(self, name):
self.writeline(f"del {name}")
def write_free_if_not_reused_line(self, buffer):
self.writeline(FreeIfNotReusedLine(buffer))
def codegen_free(self, buffer):
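        """
        Queue the free of `buffer`: graph inputs get a plain delete, aliased
        and multi-output buffers are deleted outright, and everything else
        becomes a FreeIfNotReusedLine so its storage can be recycled.
        """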
name = buffer.get_name()
# can be freed but not reused
if isinstance(buffer, ir.InputBuffer):
self.write_del_line(name)
return
if not self.can_reuse(buffer):
return
self.freed.add(name)
layout = buffer.get_layout()
if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
self.write_del_line(name)
return
self.write_free_if_not_reused_line(buffer)
def can_reuse(self, buffer):
name = buffer.get_name()
if (
name in V.graph.removed_buffers
or name in V.graph.graph_inputs
or name in V.graph.constants
or name in self.freed
):
return False
return True
def write_reuse_line(self, input_buffer, output_buffer):
self.writeline(ReuseLine(input_buffer, output_buffer))
def codegen_inplace_reuse(self, input_buffer, output_buffer):
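        """Mark `output_buffer` as an in-place reuse of `input_buffer` (the two must share a reuse key) and emit the reuse line."""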
assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
self.codegen_allocation(input_buffer)
self.freed.add(input_buffer.get_name())
self.allocated.add(output_buffer.get_name())
self.write_reuse_line(input_buffer, output_buffer)
def codegen_cuda_device_guard_enter(self, device_idx):
self.lines.append(EnterCudaDeviceContextManagerLine(device_idx))
def codegen_cuda_device_guard_exit(self):
self.lines.append(ExitCudaDeviceContextManagerLine())
def generate_return(self, output_refs):
if output_refs:
self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
else:
self.wrapper_call.writeline("return ()")
def generate_end(self, result):
return
def generate_extern_kernel_out(
self, output_view, codegen_reference, args, kernel, cpp_kernel
):
if output_view:
args.append(f"out={output_view.codegen_reference()}")
else:
args.append(f"out={codegen_reference}")
self.writeline(f"{kernel}({', '.join(args)})")
@dynamo_utils.dynamo_timed
def generate(self):
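        """
        Assemble the final source: header and prefix, then the body of
        `call(args)` produced by running a memory-planning pass over
        self.lines, the return of the graph outputs, and (optionally) a
        benchmark harness.
        """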
result = IndentedBuffer()
result.splice(self.header)
result.splice(self.prefix)
out_names = V.graph.get_output_names()
with contextlib.ExitStack() as stack:
stack.enter_context(self.wrapper_call.indent())
if config.profiler_mark_wrapper_call:
self.wrapper_call.writeline(
"from torch.profiler import record_function"
)
self.wrapper_call.writeline(
"with record_function('inductor_wrapper_call'):"
)
stack.enter_context(self.wrapper_call.indent())
while (
self.lines
and isinstance(self.lines[-1], MemoryPlanningLine)
and self.lines[-1].node.name not in out_names
):
# these lines will be pointless
self.lines.pop()
# codegen allocations in two passes
planning_state = MemoryPlanningState()
for i in range(len(self.lines)):
if isinstance(self.lines[i], MemoryPlanningLine):
self.lines[i] = self.lines[i].plan(planning_state)
device_cm_stack = contextlib.ExitStack()
for line in self.lines:
if isinstance(line, MemoryPlanningLine):
line.codegen(self.wrapper_call)
elif isinstance(line, EnterCudaDeviceContextManagerLine):
line.codegen(self.wrapper_call)
device_cm_stack.enter_context(self.wrapper_call.indent())
elif isinstance(line, ExitCudaDeviceContextManagerLine):
device_cm_stack.close()
else:
self.wrapper_call.writeline(line)
output_refs = self.get_output_refs()
if config.triton.debug_sync_graph:
self.wrapper_call.writeline("torch.cuda.synchronize()")
self.generate_return(output_refs)
with result.indent():
result.splice(self.wrapper_call)
self.generate_end(result)
self.add_benchmark_harness(result)
return result.getvalue()
def add_benchmark_harness(self, output):
"""
Append a benchmark harness to generated code for debugging
"""
if not config.benchmark_harness:
return
def add_fake_input(name, shape, stride, device, dtype):
output.writeline(
f"{name} = rand_strided("
f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, "
f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, "
f"device='{device}', dtype={dtype})"
)
output.writelines(["", "", 'if __name__ == "__main__":'])
with output.indent():
output.splice(
f"""
from {config.dynamo_import}.testing import rand_strided
from {config.inductor_import}.utils import print_performance
""",
strip=True,
)
for name, value in V.graph.constants.items():
add_fake_input(
name, value.size(), value.stride(), value.device, value.dtype
)
for name, value in V.graph.graph_inputs.items():
shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
add_fake_input(
name, shape, stride, value.get_device(), value.get_dtype()
)
output.writeline(
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
)
def define_kernel(self, name: str, kernel: str):
self.header.splice(f"\n\n{name} = {kernel}")
def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
return
def wrap_kernel_call(self, name, call_args):
return "{}({})".format(name, ", ".join(call_args))
def generate_kernel_call(self, name, call_args):
self.writeline(
self.wrap_kernel_call(name, call_args),
)
def call_kernel(self, name: str, kernel: Kernel):
tmp = IndentedBuffer()
kernel.call_kernel(self, tmp, name)
for line in tmp.getvalue().split("\n"):
line = line.strip()
if line:
self.writeline(line)
def writeline(self, line):
self.lines.append(line)
class CppWrapperCodeGen(WrapperCodeGen):
"""
    Generate the wrapper as a C++ `call_<id>` function that is compiled and
    loaded with torch.utils.cpp_extension.load_inline, then exposed as `call`.
"""
call_func_id = count()
def __init__(self):
self._call_func_id = next(CppWrapperCodeGen.call_func_id)
super().__init__()
@cache_on_self
def get_output_refs(self):
def has_cpp_codegen_func(x):
return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
x.cpp_wrapper_codegen_reference
)
return [
x.cpp_wrapper_codegen_reference()
if has_cpp_codegen_func(x)
else x.codegen_reference()
for x in V.graph.graph_outputs
]
def write_prefix(self):
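        """
        Emit the Python prelude plus the opening of the generated C++
        `call_<id>` function whose source is later compiled with load_inline.
        """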
self.prefix.splice(
"""
async_compile.wait(globals())
del async_compile
from torch.utils.cpp_extension import load_inline
wrapper = (
'''
#include <dlfcn.h>
#include <assert.h>
"""
)
with self.wrapper_call.indent():
inputs_len = len(V.graph.graph_inputs.keys())
output_refs = self.get_output_refs()
if output_refs:
if len(output_refs) == 1:
output_types = "at::Tensor"
else:
output_types = "std::vector<at::Tensor>"
else:
output_types = "void"
inputs_types = "std::vector<at::Tensor>"
self.wrapper_call.writeline(
f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
)
if inputs_len != 0:
inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
self.wrapper_call.writeline(f"{input_key} = args[{idx}];")
for name in V.graph.randomness_seeds:
self.wrapper_call.writeline(f"at::Tensor {name};")
self.wrapper_call.writeline(
f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
)
V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs)
def write_allocate_line(self, buffer):
self.writeline(CppAllocateLine(buffer))
def write_del_line(self, name):
self.writeline(f"{name}.reset();")
return
def write_free_if_not_reused_line(self, buffer):
self.writeline(CppFreeIfNotReusedLine(buffer))
return
def write_reuse_line(self, input_buffer, output_buffer):
self.writeline(CppReuseLine(input_buffer, output_buffer))
def get_deferred_line(self, name, layout):
return DeferredLine(
name, f"auto {name} = {layout.view.codegen_reference()}; // alias"
)
def get_kernel_path(self, code):
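        """Return the on-disk path of the kernel's compiled shared library, derived from the same hash the CodeCache uses."""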
from ..codecache import pick_vec_isa
picked_vec_isa = pick_vec_isa()
ext = "so"
extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa)
        # The leading "\n" is required to match the CodeCache behavior.
        # For reductions, the code string returned by code.getvalue() keeps a
        # backslash '\' at the end of each line for readability, e.g.:
        #     #pragma omp declare reduction(xxx :\
        #     omp_out.value = xxx,\
        # while the code string loaded at execution time has the backslash
        # resolved:
        #     #pragma omp declare reduction(xxx : omp_out.value = xxx,
        # Use code.getrawvalue() here so that the same code string is used at
        # compilation and execution time and therefore hashes to the same value.
source_code = "\n" + code.getrawvalue()
_, _, kernel_path = get_code_path(source_code, ext, extra)
return kernel_path
def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
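        """Emit C++ that dlopen()s the compiled kernel library and binds its exported "kernel" symbol to a function pointer named `name`."""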
kernel_path = self.get_kernel_path(kernel)
self.writeline(f'auto {name}_lib = dlopen("{kernel_path}", RTLD_NOW);')
self.writeline(f"assert({name}_lib != nullptr);")
self.writeline(f"void (*{name})({arg_types});")
self.writeline(f'*(void **) (&{name}) = dlsym({name}_lib, "kernel");')
def wrap_kernel_call(self, name, call_args):
return "{}({});".format(name, ", ".join(call_args))
def generate_return(self, output_refs):
if output_refs:
if len(output_refs) == 1:
self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
else:
self.wrapper_call.writeline(
"return std::vector<at::Tensor>({"
+ ", ".join(output_refs)
+ "}); }''' )"
)
else:
self.wrapper_call.writeline("return; }''' )")
def generate_end(self, result):
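        """
        Append the load_inline() call that compiles the generated C++ wrapper
        and a small shim that re-exposes `module.call_<id>` as `call`.
        """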
shared = codecache.get_shared()
warning_all_flag = codecache.get_warning_all_flag()
cpp_flags = codecache.cpp_flags()
ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths()
optimization_flags = codecache.optimization_flags()
use_custom_generated_macros = codecache.use_custom_generated_macros()
extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
extra_ldflags = f"{shared} {lpaths} {libs}"
extra_include_paths = f"{ipaths}"
# get the hash of the wrapper code to name the extension
wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
result.splice(
f"""
module = load_inline(
name='inline_extension_{wrapper_call_hash}',
cpp_sources=[wrapper],
functions=['call_{self._call_func_id}'],
extra_cflags=['{extra_cflags}'],
extra_ldflags=['{extra_ldflags}'],
extra_include_paths=['{extra_include_paths}'])
"""
)
# Wrap the func to support setting result._boxed_call = True
result.splice(
f"""
def _wrap_func(f):
def g(args):
return f(args)
return g
call = _wrap_func(module.call_{self._call_func_id})
"""
)
def generate_extern_kernel_out(
self, output_view, codegen_reference, args, kernel, cpp_kernel
):
if output_view:
output_as_strided = f"{output_view.codegen_reference()}"
output_name = f"{output_view.get_name()}_as_strided"
self.writeline(f"auto {output_name} = {output_as_strided};")
args.insert(0, output_name)
else:
args.insert(0, f"{codegen_reference}")
self.writeline(f"{cpp_kernel}({', '.join(args)});")