pytorch/torch/_inductor/codegen/common.py
Jez Ng f64a97c6f8 [inductor] Memory planning (#112178)
This was originally @jansel's PR:
https://github.com/pytorch/pytorch/pull/102625, which I've built upon.

This diff implements static memory planning. It's disabled by default
while we examine its performance.

We use a greedy-by-size approach. For dynamic shapes, the sizes of the
example inputs are used as estimates when making planning decisions. We
generate expressions to calculate the actual memory offsets and sizes at
runtime when the values of the dynamic shapes are known. In order to
simplify these calculations, we have organized the allocations into a
tree that branches on space (address offsets) and time (live ranges).
Finally, these offsets must be aligned, so we have added an `align`
sympy Expr to express that rounding.
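
To illustrate the approach (a sketch only, not the actual implementation;
all names below are made up), greedy-by-size planning with aligned offsets
looks roughly like this:

    ALIGN = 64  # assumed allocator alignment, in bytes

    def align(nbytes):
        # round up to the next multiple of ALIGN
        return (nbytes + ALIGN - 1) // ALIGN * ALIGN

    def plan(allocs):
        # allocs: list of (size_estimate, live_begin, live_end);
        # returns {alloc index: offset} plus the total pool size
        placed, offsets = [], {}
        # greedy-by-size: place the largest allocations first
        for i, (size, begin, end) in sorted(
            enumerate(allocs), key=lambda kv: -kv[1][0]
        ):
            offset = 0
            # place above every placed block whose live range overlaps
            for p_off, p_size, p_begin, p_end in placed:
                if begin < p_end and p_begin < end:
                    offset = max(offset, align(p_off + p_size))
            placed.append((offset, size, begin, end))
            offsets[i] = offset
        pool = max((align(o + s) for o, s, _, _ in placed), default=0)
        return offsets, pool

In the real planner, sizes are sympy expressions under dynamic shapes, so
the offsets are computed by generated runtime code rather than eagerly.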

Some limitations:

1. It is only enabled during inference for now. Enabling it for training
   increases peak memory usage as we allocate all the memory needed for
   training upfront, before freeing the memory allocated during
   inference. We can probably address this by doing planning for both
   the inference and training passes together.
2. It doesn't work with PyTorch Distributed, because kernels like
   AllGatherIntoTensor codegen raw strings that perform memory
   operations. We can fix this down the line by having them emit
   MemoryPlanningLines instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112178
Approved by: https://github.com/desertfire, https://github.com/jansel
2023-10-31 20:02:30 +00:00

import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from collections import namedtuple
from itertools import chain
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Union,
)
import sympy
from sympy.printing.printer import Printer
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from .. import config, metrics
from ..utils import (
DeferredLineBase,
do_bench,
free_symbol_startswith,
IndentedBuffer,
sympy_dot,
sympy_subs,
sympy_symbol,
unique,
)
from ..virtualized import ops, OpsValue, V
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
def data_type_logger(msg):
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug("Data type propagation: %s", msg)
TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
SizeArg = namedtuple("SizeArg", ["name", "expr"])
DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"])
device_codegens: Dict[str, DeviceCodegen] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# Any new backend looking to integrate with Inductor must customize these two main
# parts to generate its own code.
#
# Kernel code generation is determined by the Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes used for code generation, such as CppKernel and TritonKernel, typically form part
# of the logic of either Scheduling or WrapperCodeGen, so these two interfaces give a backend
# flexibility: it can implement such classes from scratch, or reuse them by extending and
# overriding as necessary. Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
device: str, device_scheduling: type, device_wrapper_codegen: type
):
device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen)
def get_scheduling_for_device(device: str):
return device_codegens[device].scheduling if device in device_codegens else None
def get_wrapper_codegen_for_device(device: str):
return (
device_codegens[device].wrapper_codegen if device in device_codegens else None
)
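# Example registration from a hypothetical out-of-tree backend (the backend
# names below are illustrative only):
#
#   from my_backend.codegen import MyScheduling, MyWrapperCodeGen
#   register_backend_for_device("my_device", MyScheduling, MyWrapperCodeGen)
#   assert get_scheduling_for_device("my_device") is MyScheduling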
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
from ..ir import FlexibleLayout
# adding a contiguous index prevents reordering
return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
@functools.lru_cache(None)
def boolean_ops():
return (
"is_inf",
"is_nan",
"bitwise_xor",
"logical_not",
"signbit",
"le",
"lt",
"ge",
"gt",
"eq",
"ne",
)
DTYPE_TO_COMPUTATION_DTYPE = {
torch.bfloat16: torch.float,
torch.float16: torch.float,
**{
dtype: dtype
for dtype in [
torch.bool,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
]
},
}
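# In the mapping above, reduced-precision float dtypes (bfloat16, float16)
# are promoted to float32 for computation; every other listed dtype computes
# in its own type.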
class DataTypePropagation:
def __init__(self, body) -> None:
self.body = body
self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
"root": body.root_block.graph
}
for k, v in body.subblocks.items():
self.graphs[k] = v.graph
def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
inputs = node.all_input_nodes
input_nodes = [
n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
]
if len(input_nodes) == 0:
return None
all_input_nodes_propagated = all(
OptimizationContext.key in n.meta
and n.meta[OptimizationContext.key].dtype is not None
for n in input_nodes
)
if not all_input_nodes_propagated:
return None
return functools.reduce(
torch.promote_types,
[n.meta[OptimizationContext.key].dtype for n in input_nodes],
)
def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
sub_graph = self.graphs[node.target]
dtype = self.propagate_graph(sub_graph)
assert dtype
return dtype
def deduce_node_dtype(self, node: torch.fx.Node):
if node.target in boolean_ops():
return torch.bool
if node.op == "placeholder":
return None
if node.target == "output":
# we can only infer the output node's dtype if it has exactly 1 arg
if len(node.args) != 1:
return None
node = node.args[0]  # type: ignore[arg-type]
if node.target in (
"to_dtype",
"index_expr",
):
return node.args[-1]
if node.target in (
"rand",
"randn",
):
return torch.float
if node.target in (
"get_index",
"index_expr",
):
return torch.int64
if node.target in (
"load",
"store",
"store_reduction",
):
buf_name = node.args[1]
return V.graph.get_dtype(buf_name)
if node.target == operator.getitem:
return self.deduce_node_dtype(node.args[0])
assert isinstance(node.target, str)
if node.target == "reduction":
return node.args[1]
if node.target == "constant":
return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]
if node.target.startswith("masked_subblock"):
return self.deduce_node_dtype_by_subgraph(node)
return self.deduce_node_dtype_by_inputs(node)
def propagate_graph(self, graph: torch.fx.Graph):
assert graph.nodes
graph_dtype = None
# For masked_subblock, we use the output's dtype to represent
# the dtype of this subgraph. For other cases, graph_dtype
# might be None.
for node in graph.nodes:
if OptimizationContext.key in node.meta:
opt_ctx = node.meta[OptimizationContext.key]
else:
opt_ctx = OptimizationContext()
opt_ctx.dtype = self.deduce_node_dtype(node)
node.meta[OptimizationContext.key] = opt_ctx
if node.target == "output":
graph_dtype = opt_ctx.dtype
return graph_dtype
def propagate(self):
self.propagate_graph(self.graphs["root"])
@classmethod
def propagate_loopbody(cls, body):
return cls(body).propagate()
@classmethod
def propagate_scheduler_node(cls, node):
from ..ir import LoopBody
from ..scheduler import SchedulerNode
assert isinstance(node, SchedulerNode)
assert isinstance(node._body, LoopBody)
DataTypePropagation.propagate_loopbody(node._body)
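# Usage (illustrative): DataTypePropagation.propagate_scheduler_node(node)
# walks the node's LoopBody graphs and records the inferred dtype of each FX
# node in its meta dict under OptimizationContext.key.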
class ExprPrinter(Printer):
@staticmethod
def paren(string):
def all_in_parens(string):
if string[0] != "(" or len(string) < 2:
return False
count = 1
for i, char in enumerate(string[1:]):
if char == "(":
count += 1
elif char == ")":
count -= 1
if count == 0 and i != len(string) - 2:
return False
assert count == 0
return True
if (
isinstance(string, CSEVariable)
or re.match(r"^[a-z0-9_.]+$", string, re.I)
or re.match(r"^\([^)]*\)$", string, re.I)
or string == ""
):
return string
# don't put extra parens for strings that are already wrapped in parens
if all_in_parens(string):
return string
return f"({string})"
def _print_Pow(self, expr):
# Pow() confuses triton
base, exp = expr.args
# NB: Remember this is sizevar computation! You don't typically
# expect to have to do floating point computation including exponents
# in sizevar compute. Instead of adding support for floating
# point pow, you should make upstream retranslate the Sympy expression
# into Tensor expressions earlier and do that instead.
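# e.g. s0**3 prints as "s0*s0*s0" and s0**-2 prints as "1/(s0*s0)".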
if exp == 0.5:
return self._helper_sqrt(base) # type: ignore[attr-defined]
elif exp == -0.5:
return "1/" + self._helper_sqrt(base) # type: ignore[attr-defined]
base = self._print(base)
assert exp == int(exp), exp
exp = int(exp)
if exp > 0:
return "*".join([self.paren(base)] * exp)
elif exp < 0:
return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
else: # exp == 0
return "1"
def _print_Unequality(self, expr):
return " != ".join(map(self.paren, map(self._print, expr.args)))
def _print_Mul(self, expr):
return "*".join(map(self.paren, map(self._print, expr.args)))
def _print_Add(self, expr):
return " + ".join(map(self.paren, map(self._print, expr.args)))
def _print_Mod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_CleanDiv(self, expr):
return self._print_FloorDiv(expr) # type: ignore[attr-defined]
def _print_GreaterThan(self, expr):
# GreaterThan: >=
# StrictlyGreaterThan: >
# Go figure...
return " >= ".join(map(self.paren, map(self._print, expr.args)))
def _print_align(self, expr):
assert len(expr.args) == 1
return f"align({self._print(expr.args[0])})"
class PythonPrinter(ExprPrinter):
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
mod = self.paren(self.doprint(mod))
if div != "1":
x = f"({x} // {div})"
return f"{x} % {mod}"
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} // {div})"
def _helper_sqrt(self, expr):
return f"math.sqrt({self._print(expr)})"
def _print_floor(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"abs({self._print(expr.args[0])})"
def _print_Max(self, expr):
assert len(expr.args) >= 2
return f"max({', '.join(map(self._print, expr.args))})"
class OpOverrides:
def __init__(self, parent):
super().__init__()
self._parent = parent
def __getattr__(self, item):
return getattr(self._parent, item)
@staticmethod
def identity(value):
# used to trigger cse
return value
@staticmethod
def constant(value, dtype):
return repr(value)
@staticmethod
def reciprocal(x):
return ops.div("1", x)
@staticmethod
def square(x):
return ops.mul(x, x)
@staticmethod
def bitwise_not(x):
return f"~{ExprPrinter.paren(x)}"
@staticmethod
def logical_not(a):
return f"{ExprPrinter.paren(a)} == 0"
@staticmethod
def bitwise_and(x, y):
return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
@staticmethod
def bitwise_or(x, y):
return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
@staticmethod
def bitwise_xor(x, y):
return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
@staticmethod
def bitwise_left_shift(x, y):
return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
# TODO(fdrocha): this is currently not being used anywhere,
# pending moving the triton pin past 972b761
@staticmethod
def bitwise_right_shift(x, y):
return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
@staticmethod
def remainder(a, b):
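# Python-style (floored) modulo: the result takes the sign of b. Start
# from C-style mod and shift by b when the signs differ and r != 0.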
r = ops.mod(a, b)
return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
@staticmethod
def load_seed(name, offset):
return ops.load(name, sympy.Integer(offset))
class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
def __init__(self, name, line):
super().__init__(line)
self.name = name
def __call__(self):
if all(
self.name not in x
for x in (
V.graph.removed_buffers,
V.kernel.removed_buffers,
V.graph.inplaced_to_remove,
V.kernel.inplaced_to_remove,
)
):
return self.line
return None
def _new_line(self, line):
return DeferredLine(self.name, line)
class BracesBuffer(IndentedBuffer):
def indent(self, offset=1):
@contextlib.contextmanager
def ctx():
for _ in range(offset):
self.writeline("{")
self._indent += 1
for _ in range(-offset):
self._indent -= 1
self.writeline("}")
yield
for _ in range(-offset):
self.writeline("{")
self._indent += 1
for _ in range(offset):
self._indent -= 1
self.writeline("}")
return ctx()
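# BracesBuffer usage (illustrative):
#   code = BracesBuffer()
#   with code.indent():
#       code.writeline("x += 1;")
#   # emits "{", then an indented "x += 1;", then "}"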
class InplacedBuffer(NamedTuple):
inner_name: str
other_names: List[str]
class KernelArgs:
@staticmethod
def _lookup(prefix, odict, name):
assert isinstance(name, (str, sympy.Symbol))
if name not in odict:
odict[name] = f"{prefix}{len(odict)}"
return odict[name]
def __init__(self, sizevars=None):
self.input_buffers = dict()
self.output_buffers = dict()
self.inplace_buffers = dict()
self.sizevars = sizevars or dict()
def __repr__(self):
return "KernelArgs({})".format(
", ".join(
map(
repr,
[
self.input_buffers,
self.output_buffers,
self.inplace_buffers,
self.sizevars,
],
)
)
)
def _buffer_is_marked_removed(self, name):
return isinstance(name, str) and name.startswith("REMOVED")
def input(self, name):
if V.graph.scheduler:
name = V.graph.scheduler.mutation_real_name.get(name, name)
assert name not in V.graph.removed_buffers, name
if name in self.output_buffers:
return self.output_buffers[name]
if name in self.inplace_buffers:
return self.inplace_buffers[name].inner_name
if name.startswith("seed"):
return self._lookup("seed", self.input_buffers, name)
return self._lookup("in_ptr", self.input_buffers, name)
def output(self, name):
if V.graph.scheduler:
name = V.graph.scheduler.mutation_real_name.get(name, name)
assert name not in V.graph.removed_buffers, name
if name in self.inplace_buffers:
return self.inplace_buffers[name].inner_name
return self._lookup("out_ptr", self.output_buffers, name)
def make_inplace(self, input_name, output_name):
assert output_name not in self.inplace_buffers
if input_name in self.inplace_buffers:
buf = self.inplace_buffers[input_name]
buf.other_names.append(output_name)
self.inplace_buffers[output_name] = buf
else:
buf = InplacedBuffer(
f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
[input_name, output_name],
)
self.inplace_buffers[input_name] = buf
self.inplace_buffers[output_name] = buf
def seed_offset(self, name, value):
if value in self.sizevars:
return self.sizevars[value]
if name in self.sizevars.values():
name = (
f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
)
self.sizevars[value] = name
return name
def size(self, name):
if str(name) == "seed":
self.sizevars["seed"] = "seed"
return "seed"
return self._lookup("ks", self.sizevars, name)
def call_names(self):
return chain(
self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
)
def wrap_ptr_arg(self, buf, dtype):
return f"c_void_p({buf}.data_ptr())"
def wrap_size_arg(self, size):
return f"c_long({size})"
def cpp_argdefs(self):
from .cpp import DTYPE_TO_CPP, INDEX_TYPE
call_args = []
arg_defs = []
arg_types = []
for inplaced in unique(self.inplace_buffers.values()):
if self._buffer_is_marked_removed(inplaced):
continue
outer = inplaced.other_names[-1]
inner = inplaced.inner_name
dtype = V.graph.get_dtype(outer)
cpp_dtype = DTYPE_TO_CPP[dtype]
arg_defs.append(f"{cpp_dtype}* {inner}")
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"{cpp_dtype}*")
for outer, inner in self.input_buffers.items():
if outer in self.inplace_buffers:
continue
dtype = V.graph.get_dtype(outer)
cpp_dtype = DTYPE_TO_CPP[dtype]
arg_defs.append(f"const {cpp_dtype}* {inner}")
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"const {cpp_dtype}*")
for outer, inner in self.output_buffers.items():
if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
continue
dtype = V.graph.get_dtype(outer)
cpp_dtype = DTYPE_TO_CPP[dtype]
arg_defs.append(f"{cpp_dtype}* {inner}")
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"{cpp_dtype}*")
for outer, inner in self.sizevars.items():
arg_defs.append(f"const {INDEX_TYPE} {inner}")
call_args.append(self.wrap_size_arg(outer))
arg_types.append(f"const {INDEX_TYPE}")
return arg_defs, call_args, arg_types
def python_argdefs(self):
arg_defs = []
call_args = []
precompile_args: List[Union[TensorArg, SizeArg]] = []
for inplaced in unique(self.inplace_buffers.values()):
if self._buffer_is_marked_removed(inplaced):
continue
arg_defs.append(inplaced.inner_name)
call_args.append(inplaced.other_names[-1])
precompile_args.append(
TensorArg(
inplaced.inner_name,
inplaced.other_names[-1],
V.graph.get_dtype(inplaced.other_names[-1]),
)
)
for outer, inner in chain(
self.input_buffers.items(), self.output_buffers.items()
):
if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
continue
arg_defs.append(inner)
call_args.append(outer)
precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))
for outer, inner in self.sizevars.items():
arg_defs.append(inner)
call_args.append(outer)
precompile_args.append(SizeArg(inner, outer))
return arg_defs, call_args, precompile_args
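# Example (hypothetical): for a kernel that reads buf0, writes buf1, and uses
# one size variable s0, python_argdefs() would return something like
#   arg_defs  = ["in_ptr0", "out_ptr0", "ks0"]
#   call_args = ["buf0", "buf1", "s0"]
# plus matching TensorArg/SizeArg entries in precompile_args.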
def aliases(self):
for inplaced in unique(self.inplace_buffers.values()):
if self._buffer_is_marked_removed(inplaced):
continue
for other in inplaced.other_names:
if (
other in V.graph.inplaced_to_remove
or other in V.kernel.inplaced_to_remove
):
continue
if other in self.input_buffers:
yield self.input_buffers[other], inplaced.inner_name
if other in self.output_buffers:
yield self.output_buffers[other], inplaced.inner_name
def is_removed(self, name):
def _is_removed(name, buffers):
return name not in buffers or self._buffer_is_marked_removed(buffers[name])
return _is_removed(name, self.output_buffers) and _is_removed(
name, self.inplace_buffers
)
# Includes inplace buffers, excludes removed buffers. Essentially,
# after you do a call into this kernel, which buffers actually contain
# updated data? Modeled off of python_argdefs.
def live_output_buffers(self):
live_outs = set()
for inplaced in unique(self.inplace_buffers.values()):
if self._buffer_is_marked_removed(inplaced):
continue
live_outs.add(inplaced.other_names[-1])
for outer, inner in self.output_buffers.items():
if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
continue
live_outs.add(outer)
return live_outs
class CSEVariable:
"""A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
To do so, the backends can simply overload `Kernel.create_cse_var`
The "CSEVariable.update_on_args" method gives you a hook for annotations
See example of TritonCSEVariable in triton.py
"""
def __init__(self, name, bounds: ValueRanges):
assert isinstance(bounds, ValueRanges)
self.name = name
self.bounds = bounds
def __str__(self):
return self.name
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other) -> bool:
return type(other) == type(self) and other.name == self.name
def update_on_args(self, name, args, kwargs):
pass
class CppWrapperKernelArgs(KernelArgs):
def wrap_ptr_arg(self, buf, dtype):
from .cpp import DTYPE_TO_CPP
if config.aot_inductor.abi_compatible:
# In abi_compatible mode, we just return the buf here.
# We will form the correct call args later in wrapper.generate_kernel_all.
return buf
else:
return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
def wrap_size_arg(self, size):
return f"{size}"
class CSE:
"""Common subexpression elimination"""
def __init__(
self,
prefix="",
suffix="",
name_prefix="tmp",
iter_buffers=None,
store_cache=None,
reduction_cache=None,
varname_map=None,
):
self.prefix = prefix
self.suffix = suffix
self.cache = {}
self.name_prefix = name_prefix
self.store_cache = store_cache or {}
self.reduction_cache = reduction_cache or {}
self.iter_buffer_ids = iter_buffers or itertools.count()
self.invalidated_stores = set()
self.varname_map = varname_map or {}
def invalidate(self, keep_vars: Set[str]):
for name, tmp in list(self.store_cache.items()):
if tmp not in keep_vars:
del self.store_cache[name]
self.invalidated_stores.add(name)
self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
def clone(self):
# Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
return CSE(
prefix=self.prefix,
suffix=self.suffix,
name_prefix=self.name_prefix,
iter_buffers=self.iter_buffer_ids,
store_cache=self.store_cache,
varname_map=self.varname_map,
)
def generate(
self,
buffer: IndentedBuffer,
expr: Union[str, CSEVariable, OpsValue],
*,
bounds: ValueRanges = ValueRanges.unknown(),
write=True,
assignment=True,
) -> CSEVariable:
if isinstance(expr, OpsValue):
expr = expr.value
assert isinstance(expr, (str, CSEVariable)), type(expr)
assert write or assignment
if isinstance(expr, CSEVariable):
# If the expressions were always created with all the information, we could
# assert expr.bounds == bounds, but sometimes the expression is created
# with the loose ValueRanges.unknown(), so we need to tighten the bounds
expr.bounds = expr.bounds.tighten(bounds)
return expr
cache_key = expr
var = self.cache.get(cache_key, None)
if not var:
var = self.newvar(bounds) if assignment else None
self.cache[cache_key] = var
if write:
if V.kernel.current_node:
V.kernel.current_node.codegen_originating_info(
buffer, only_once=True
)
if assignment:
line = f"{self.prefix}{var} = {expr}{self.suffix}"
else:
line = f"{expr}{self.suffix}"
buffer.writeline(line)
else:
var.bounds = var.bounds.tighten(bounds)
return var
def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable:
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
var = V.kernel.create_cse_var(var_name, bounds)
self.varname_map[var_name] = var
return var
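# CSE usage (illustrative, inside an active kernel context):
#   v1 = cse.generate(buf, "in_ptr0[x0] + 1")  # emits "tmp0 = in_ptr0[x0] + 1"
#   v2 = cse.generate(buf, "in_ptr0[x0] + 1")  # cache hit: returns tmp0, emits nothing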
class IndirectAssertLine(DeferredLineBase):
def __init__(self, line, assert_fn, var, mask, size_map):
self.var = var
self.mask = mask
self.line = line
self.assert_fn = assert_fn
self.size_map = size_map
def __call__(self):
size, size_str = self.size_map[(self.var, self.mask)]
# We assert if we've not been able to prove the bound
assert_min = (self.var.bounds.lower >= 0) != sympy.true
assert_max = (self.var.bounds.upper < size) != sympy.true
# FooBar interview question
if not (assert_min or assert_max):
return None
elif assert_min and assert_max:
# The conditions need to be in parens because of Python's operator precedence.
# It'd be less error-prone to use and/or/not, which is supported by triton
cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
cond_print = f"0 <= {self.var} < {size_str}"
elif assert_min:
cond = f"0 <= {self.var}"
cond_print = cond
else:
assert assert_max
cond = f"{self.var} < {size_str}"
cond_print = cond
if self.mask:
cond = f"({cond}) | ~{self.mask}"
return self.line.format(
assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
)
def _new_line(self, line):
return IndirectAssertLine(
line, self.assert_fn, self.var, self.mask, self.size_map
)
class CodeGen:
def __init__(self):
super().__init__()
self.exit_stack = contextlib.ExitStack()
def __enter__(self):
self.exit_stack.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
class Kernel(CodeGen):
newvar_prefix = ""
suffix = ""
overrides = None
load_format = None
store_format = None
def __init__(self, args=None, increase_kernel_count=True):
super().__init__()
if increase_kernel_count:
metrics.generated_kernel_count += 1
self.args = args or KernelArgs()
self.loads = IndentedBuffer()
self.compute = IndentedBuffer()
self.stores = IndentedBuffer()
self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
self.must_keep_buffers = set()
self.store_buffer_names = set()
self._load_mask = None
# set in set_current_node
self.current_node = None
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None
# Upper bounds for indirect_indexing and their str representation
self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {}
self.removed_buffers = set()
self.inplaced_to_remove = set()
# key: the buffer to write
# value: the buffer to read and whose memory can be reused for
# the buffer specified by key
self.inplace_update_buffers = dict()
# Set minimum number of elements processed per thread.
self.min_elem_per_thread = 1
@contextlib.contextmanager
def set_current_node(self, node):
prior = self.current_node
self.current_node = node
self.node_to_bounds = node._body.bounds().get_bounds()
try:
yield
finally:
self.current_node = prior
@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
if cb is None:
cb = lb
loads = self.loads
compute = self.compute
stores = self.stores
cse = self.cse
self.loads = lb
self.compute = cb
self.stores = sb
self.cse = cse.clone()
try:
yield
finally:
self.loads = loads
self.compute = compute
self.stores = stores
self.cse = cse
def load(self, name: str, index: sympy.Expr):
raise NotImplementedError()
def indirect_load(self, name: str, index: sympy.Expr):
"""A load the depends on an index we have read"""
prior = self.loads
try:
# put the load in the compute section as it might have deps
self.loads = self.compute
return self.load(name, index)
finally:
self.loads = prior
def store_reduction(self, name, index, value):
raise NotImplementedError()
def store(self, name, index, value, mode=None):
raise NotImplementedError()
def reduction(self, dtype, src_dtype, reduction_type, value):
raise NotImplementedError()
def bucketize(
self,
values,
offsets_name: str,
offsets_size: sympy.Expr,
indexing_dtype: torch.dtype,
right: bool,
):
"""
See [Note: Inductor bucketize op]
"""
raise NotImplementedError()
def __enter__(self):
class CSEProxy:
self.name = "CSEProxy"
@staticmethod
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args, **kwargs):
# TritonTemplateKernel has no current_node
buf_bounds = ValueRanges.unknown()
if hasattr(V.interpreter, "current_node"):
fx_node = V.interpreter.current_node
assert isinstance(self.node_to_bounds, dict)
buf_bounds = self.node_to_bounds.get(
fx_node, ValueRanges.unknown()
)
csevar = self.cse.generate(
self.compute,
getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type]
bounds=buf_bounds,
)
csevar.update_on_args(name, args, kwargs)
return csevar
return inner
@staticmethod
def indirect_indexing(var, size, check=True):
# Skip CSE since this doesn't return an expression
if var.bounds.lower < 0:
new_bounds = ValueRanges.unknown()
if var.bounds != ValueRanges.unknown() and isinstance(
size, sympy.Number
):
# Take the negative part of the bound and add size to it
# Then take union of that and the positive part
# This is a tighter bound than that of a generic ops.where, as we have info on the cond
neg = var.bounds & ValueRanges(-sympy.oo, -1)
new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
# We don't have a good way of representing the empty range
if var.bounds.upper >= 0:
pos = var.bounds & ValueRanges(0, sympy.oo)
new_bounds = new_bounds | pos
stm = ops.add(var, self.rename_indexing(size))
# Mixed negative and non-negative
if var.bounds.upper >= 0:
lt = ops.lt(var, "0")
stm = ops.where(lt, stm, var)
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
new_var.update_on_args("index_wrap", (var,), {})
var = new_var
if self.generate_assert(check):
mask = self.load_mask(var)
# An assertion line may have been written already, if so just
# update the max size.
map_key = (var, mask)
existing_size, _ = self.indirect_max_sizes.get(
map_key, (None, None)
)
if existing_size is not None:
size = sympy.Min(size, existing_size)
else:
line = (
'{assert_fn}({cond}, "index out of bounds: {cond_print}")'
)
self.compute.writeline(
IndirectAssertLine(
line,
self.assert_function, # type: ignore[attr-defined]
var,
mask,
self.indirect_max_sizes,
)
)
self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) # type: ignore[attr-defined]
return sympy_symbol(str(var))
@staticmethod
def load(name: str, index: sympy.Expr):
if name in self.cse.invalidated_stores:
# A load from an invalidated store requires us to
# keep the actual buffer around
V.kernel.must_keep_buffers.add(name)
if free_symbol_startswith(index, "tmp"):
return self.indirect_load(name, index)
store_cache = self.cse.store_cache
if name in store_cache:
return store_cache[name]
return self.load(name, index)
@staticmethod
def store(name, index, value, mode=None):
self.store_buffer_names.add(name)
if mode is None:
self.cse.store_cache[name] = value
if self.current_node:
for other_name in self.current_node.get_mutations():
self.cse.store_cache[other_name] = value
if name not in V.graph.removed_buffers:
return self.store(name, index, value, mode=mode)
@staticmethod
def store_reduction(name, index, value):
self.store_buffer_names.add(name)
self.cse.store_cache[name] = value
if self.current_node:
for other_name in self.current_node.get_mutations():
self.cse.store_cache[other_name] = value
if name not in V.graph.removed_buffers:
return self.store_reduction(name, index, value)
@staticmethod
def reduction(dtype, src_dtype, reduction_type, value):
return self.reduction(dtype, src_dtype, reduction_type, value)
@staticmethod
def bucketize(
values,
offsets_name: str,
offsets_size: sympy.Expr,
indexing_dtype: torch.dtype,
right: bool,
):
"""
[Note: Inductor bucketize op]
Given values (tensor) and offsets_name (reference to the name of a 1D
tensor), calculate the bucket that each value belongs to.
e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
return = [ 0, 1, 1, 1, 1, 3, 3, 4].
When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
Offsets must be non-decreasing or the result is undefined.
"""
return self.bucketize(
values, offsets_name, offsets_size, indexing_dtype, right
)
super().__enter__()
assert self.overrides
parent_handler = self.overrides(V.get_ops_handler())
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Note that V.graph.scheduler can be None when codegening triton template
kernels.
"""
if V.graph.scheduler:
V.graph.scheduler.remove_kernel_local_buffers()
super().__exit__(exc_type, exc_val, exc_tb)
def generate_assert(self, check):
return (check or config.debug_index_asserts) and config.assert_indirect_indexing
def load_mask(self, var):
# only the triton kernel requires mask
return ""
def rename_indexing(self, index) -> sympy.Expr:
# adds the necessary kernel args for index expressions
# and renames variables in index expressions to kernel arg names
if isinstance(index, (list, tuple)):
return [self.rename_indexing(x) for x in index]
index = V.graph.sizevars.simplify(index)
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
replacements = {
x: self.args.size(x)
for x in sorted_symbols
if x.name.startswith("s")
or x.name.startswith("ps")
or (x.name.startswith("i") and not x.name.startswith("idx"))
}
return sympy_subs(index, replacements)
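# Example (illustrative): rename_indexing(4*s0 + x0) registers s0 as a kernel
# size arg and rewrites the expression to 4*ks0 + x0; iteration variables
# like x0 are left untouched.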
def create_cse_var(self, *args, **kwargs):
return CSEVariable(*args, **kwargs)
@dataclasses.dataclass
class OptimizationContext:
key: ClassVar[str] = "opt_ctx"
# Load value as mask
is_load_as_mask: bool = False
dtype: Optional[torch.dtype] = None
ops_name: str = ""
is_most_inner_loop_irrevelant: bool = False
# Load uint8 value as float32
is_load_uint8_as_float: bool = False
@functools.lru_cache(None)
def jinja2_env():
try:
import jinja2
return jinja2.Environment(
undefined=jinja2.StrictUndefined,
)
except ImportError:
return None
class ChoiceCaller:
"""
Represents a possible choice used in autotune_process.py.
During autotuning, self.benchmark() is first called to get the benchmark result,
and if this choice is selected, self.output_node() is called to get the output node.
Child classes: TritonTemplateCaller, CUDATemplateCaller.
"""
def __init__(self, name, input_nodes, layout):
super().__init__()
self.name = name
self.layout = layout
self.input_nodes = input_nodes
def benchmark(self, *args, out) -> float:
algo = self.to_callable()
return do_bench(lambda: algo(*args, out=out))
def call_name(self) -> str:
raise NotImplementedError()
def to_callable(self):
raise NotImplementedError()
def hash_key(self) -> str:
raise NotImplementedError()
def output_node(self) -> "TensorBox": # type: ignore[name-defined]
raise NotImplementedError()
class KernelTemplate:
"""
Base class for defining kernel templates.
Child classes: TritonTemplate, CUDATemplate.
"""
@staticmethod
def _template_from_string(source):
env = jinja2_env()
if env is not None:
return env.from_string(source)
return None
@staticmethod
def _fake_get_dtype(fake_out):
_get_dtype_real = V.graph.get_dtype
def get_dtype(name):
if name == fake_out.get_name():
return fake_out.get_dtype()
return _get_dtype_real(name)
return get_dtype
def __init__(self, name: str):
self.name = name
def maybe_append_choice(self, choices, **kwargs):
"""
Maybe generates a new ChoiceCaller and appends it to the given list of choices.
choices: A list of ChoiceCallers.
kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
"""
try:
choices.append(self.generate(**kwargs))
except NotImplementedError:
pass
def generate(self, **kwargs) -> ChoiceCaller:
"""
Generates a ChoiceCaller instance from the given arguments.
"""
raise NotImplementedError()