mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
When running the TIMM ```convit_base``` model in the dynamic-shape case, an AssertionError is always raised; see https://github.com/pytorch/pytorch/issues/97877. A simple reproducer is:
```
import torch
import torch._dynamo
import torch._dynamo.config as config

config.dynamic_shapes=True
torch._dynamo.config.assume_static_by_default=False

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        B, N, C = x.shape
        return self.get_rel_indices(N)

    def get_rel_indices(self, num_patches: int) -> torch.Tensor:
        img_size = int(num_patches ** .5)
        #rel_indices = torch.zeros(1, num_patches, num_patches, 3)
        ind = torch.arange(img_size)
        return ind

model = Model().eval()
opt_model = torch._dynamo.optimize('inductor')(model)
x = torch.randn(8, 8, 8)
ref = model(x)
with torch.no_grad():
    for i in range(3):
        out = opt_model(x)
```
With this change, the generated code for the reproducer looks like this:
```
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/x5/cx5442c6dcuxsrrlnqi476yzjlgc6g53ukppuaettiyp6dszhmr4.h"
extern "C" void kernel(long* out_ptr0,
                       const long ks0)
{
    {
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(std::floor(std::sqrt(ks0))); i0+=static_cast<long>(1L))
        {
            auto tmp0 = static_cast<long>(i0);
            out_ptr0[static_cast<long>(i0)] = tmp0;
        }
    }
}
''')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99514
Approved by: https://github.com/jansel, https://github.com/jgong5
812 lines
26 KiB
Python
812 lines
26 KiB
Python
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
import re
|
|
import typing
|
|
from collections import namedtuple
|
|
from itertools import chain
|
|
|
|
import sympy
|
|
from sympy.printing.printer import Printer
|
|
|
|
import torch
|
|
|
|
from .. import metrics
|
|
from ..utils import (
|
|
DeferredLineBase,
|
|
free_symbol_startswith,
|
|
IndentedBuffer,
|
|
sympy_dot,
|
|
sympy_subs,
|
|
sympy_symbol,
|
|
unique,
|
|
)
|
|
from ..virtualized import ops, V
|
|
|
|
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def data_type_logger(msg):
    """Log ``msg`` at DEBUG level on the module logger, prefixed as a
    data-type-propagation message.

    Nothing is emitted when DEBUG logging is disabled for the module logger.
    """
    if not log.isEnabledFor(logging.DEBUG):
        return
    log.debug("Data type propagation: %s", msg)
|
|
|
|
|
|
# Descriptor of a tensor kernel argument. ``KernelArgs.python_argdefs`` fills
# it with (kernel-side argument name, graph-side buffer name, buffer dtype).
TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
# Descriptor of a size kernel argument. ``KernelArgs.python_argdefs`` fills
# it with (kernel-side argument name, the size symbol/expression it stands for).
SizeArg = namedtuple("SizeArg", ["name", "expr"])
|
|
|
|
|
|
def index_prevent_reordering(index: typing.List[sympy.Expr], index_vars, sizes):
    """Return a copy of ``index`` with one extra, contiguous index appended.

    The extra entry is the dot product of ``index_vars`` with the contiguous
    strides that ``FlexibleLayout`` computes for ``sizes``; having it in the
    list prevents the indices from being reordered.
    """
    from ..ir import FlexibleLayout

    guarded = list(index)
    guarded.append(sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes)))
    return guarded
|
|
|
|
|
|
def _data_type_propagation(sub_graph: torch.fx.Graph):
    """Annotate the nodes of ``sub_graph`` with a dtype, iterating to a fixed point.

    Each node gets an ``OptimizationContext`` stored in ``node.meta`` under
    ``OptimizationContext.key``. The dtype is either derived from the node's
    own target/arguments, or, failing that, promoted from the dtypes of its
    input nodes once all of them are known. Passes over the graph are repeated
    until a full pass annotates no new node.
    """
    ctx_key = OptimizationContext.key
    # Targets whose result is always boolean.
    bool_result_ops = (
        "is_inf",
        "is_nan",
        "bitwise_xor",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )
    # Targets that carry their dtype as the last positional argument.
    explicit_dtype_ops = ("constant", "to_dtype", "rand", "randn")
    # Reduction kinds whose result dtype does not depend on the inputs.
    fixed_reduction_dtypes = {
        "any": torch.bool,
        "argmin": torch.int64,
        "argmax": torch.int64,
    }
    # Targets that never receive a dtype and are ignored as inputs.
    untyped_ops = ("ops", "get_index")

    def dtype_from_target(fx_node):
        """Return the dtype implied by the node itself, or None if unknown."""
        target = fx_node.target
        if target in bool_result_ops:
            return torch.bool
        if target in explicit_dtype_ops:
            return fx_node.args[-1]
        if target == "reduction":
            return fixed_reduction_dtypes.get(fx_node.args[4])
        if target == "load":
            return V.graph.get_dtype(fx_node.args[1])
        return None

    def propagate_node(node: torch.fx.Node):
        """Try to annotate ``node``; return True only if a dtype was newly set."""
        if node.target in untyped_ops:
            return False

        meta = node.meta
        opt_ctx = meta[ctx_key] if ctx_key in meta else OptimizationContext()
        if opt_ctx.dtype is not None:
            # Already annotated on an earlier pass.
            return False

        direct = dtype_from_target(node)
        if direct is not None:
            opt_ctx.dtype = direct
            data_type_logger(
                f"for node.target = {node.target}, dtype is propagated to {opt_ctx.dtype}"
            )
            meta[ctx_key] = opt_ctx
            return True

        # The target alone does not determine the dtype: derive it from the
        # inputs, which must all be annotated already.
        producers = [
            n
            for n in node.all_input_nodes
            if isinstance(n, torch.fx.node.Node) and n.target not in untyped_ops
        ]
        if not producers:
            return False
        for n in producers:
            if ctx_key not in n.meta or n.meta[ctx_key].dtype is None:
                return False

        # Promote across all input dtypes.
        producer_dtypes = [n.meta[ctx_key].dtype for n in producers]
        opt_ctx.dtype = functools.reduce(torch.promote_types, producer_dtypes)

        summary = f"for node.target = {node.target}, dtype is propagated to {opt_ctx.dtype}, "
        details = "inputs dtypes: " + "".join(
            f"input {n.name}.dtype = {n.meta[ctx_key].dtype}" for n in producers
        )
        data_type_logger(summary + details)
        meta[ctx_key] = opt_ctx
        return True

    while True:
        progressed = False
        for graph_node in sub_graph.nodes:
            progressed = propagate_node(graph_node) or progressed
        if not progressed:
            break
|
|
|
|
|
|
def data_type_propagation(node):
    """Run dtype propagation over every FX graph of a scheduler node's body.

    ``node`` must be a ``SchedulerNode``. When its ``_body`` is a ``LoopBody``,
    the root block's graph and each sub-block's graph are processed in turn;
    otherwise nothing happens.
    """
    from ..ir import LoopBody
    from ..scheduler import SchedulerNode

    assert isinstance(node, SchedulerNode)
    loop_body = node._body
    if not isinstance(loop_body, LoopBody):
        return
    # Snapshot the blocks up front, root block first.
    blocks = [loop_body.root_block, *loop_body.subblocks.values()]
    for block in blocks:
        _data_type_propagation(block.graph)
|
|
|
|
|
|
class ExprPrinter(Printer):
    """Sympy printer with the expression handling shared by the code printers."""

    @staticmethod
    def paren(string):
        """Return ``string`` wrapped in parentheses unless that is unnecessary.

        CSE variables, plain identifiers/numbers, the empty string and text
        that is already enclosed by one matching pair of parentheses are
        returned unchanged.
        """

        def all_in_parens(text):
            # True when the opening paren at position 0 is closed only by the
            # very last character of the text.
            if text[0] != "(" or len(text) < 2:
                return False
            depth = 1
            last_offset = len(text) - 2
            for offset, ch in enumerate(text[1:]):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and offset != last_offset:
                    return False
            assert depth == 0
            return True

        if isinstance(string, CSEVariable):
            return string
        if re.match(r"^[a-z0-9_.]+$", string, re.I):
            return string
        if re.match(r"^\([^)]*\)$", string, re.I):
            return string
        if string == "":
            return string
        # No second pair of parens for text that is already fully wrapped.
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Pow(self, expr):
        """Print a power as a square root, a repeated product or a reciprocal."""
        # Pow() confuses triton, so it is never emitted as such.
        base_expr, exponent = expr.args
        base_str = self._print(base_expr)
        # NB: this is sizevar computation, where floating point math with
        # exponents is not normally expected. Rather than extending this to
        # general floating point pow, upstream should translate the sympy
        # expression into tensor expressions earlier.
        if exponent == 0.5:
            return f"math.sqrt({base_str})"
        assert exponent.is_integer
        power = int(exponent)
        if power == 0:
            return "1"
        product = "*".join([self.paren(base_str)] * abs(power))
        if power > 0:
            return product
        return "1/" + self.paren(product)

    def _print_Mul(self, expr):
        """Print a product with each factor parenthesized as needed."""
        return "*".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Add(self, expr):
        """Print a sum with each term parenthesized as needed."""
        return " + ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Mod(self, expr):
        """Print a modulo with each operand parenthesized as needed."""
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_CleanDiv(self, expr):
        """Print CleanDiv exactly like FloorDiv (which this class does not define)."""
        return self._print_FloorDiv(expr)
|
|
|
|
|
|
class PythonPrinter(ExprPrinter):
    """Expression printer that emits Python source text."""

    def _print_ModularIndexing(self, expr):
        """Print ``ModularIndexing(x, div, mod)`` as ``(x // div) % mod``.

        The floor division is left out when the divisor prints as ``1``.
        """

        def operand(arg):
            return self.paren(self.doprint(arg))

        value, divisor, modulus = expr.args
        value_text = operand(value)
        divisor_text = operand(divisor)
        modulus_text = operand(modulus)
        if divisor_text != "1":
            value_text = f"({value_text} // {divisor_text})"
        return f"{value_text} % {modulus_text}"

    def _print_FloorDiv(self, expr):
        """Print ``FloorDiv(x, div)`` as a parenthesized ``//`` expression."""
        numerator, denominator = expr.args
        numerator_text = self.paren(self.doprint(numerator))
        denominator_text = self.paren(self.doprint(denominator))
        return f"({numerator_text} // {denominator_text})"

    def _print_floor(self, expr):
        """Print sympy ``floor`` as a ``math.floor`` call."""
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.floor({inner})"

    def _print_ceiling(self, expr):
        """Print sympy ``ceiling`` as a ``math.ceil`` call."""
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.ceil({inner})"
|
|
|
|
|
|
class OpOverrides:
    """Op implementations layered on top of a parent ops handler.

    Attributes not defined here are looked up on the wrapped parent. The
    static methods either return a source-text expression or compose other
    ops through the global ``ops`` handler.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Only reached when normal lookup fails: delegate to the parent.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        """Return ``value`` unchanged; used to trigger cse."""
        return value

    @staticmethod
    def constant(value, dtype):
        """Return the ``repr`` of ``value``; ``dtype`` is ignored here."""
        return repr(value)

    @staticmethod
    def reciprocal(x):
        """Express ``1 / x`` through ``ops.div``."""
        numerator = "1"
        return ops.div(numerator, x)

    @staticmethod
    def square(x):
        """Express ``x * x`` through ``ops.mul``."""
        return ops.mul(x, x)

    @staticmethod
    def sign(x):
        """Express sign(x) as ``(0 < x) - (x < 0)`` using where/lt/sub ops."""
        is_positive = ops.lt("0", x)
        positive_part = ops.where(is_positive, "1", "0")
        is_negative = ops.lt(x, "0")
        negative_part = ops.where(is_negative, "1", "0")
        return ops.sub(positive_part, negative_part)

    @staticmethod
    def bitwise_not(x):
        """Return the text of ``~x``."""
        operand = ExprPrinter.paren(x)
        return f"~{operand}"

    @staticmethod
    def logical_not(a):
        """Return the text of ``a == 0``."""
        operand = ExprPrinter.paren(a)
        return f"{operand} == 0"

    @staticmethod
    def bitwise_and(x, y):
        """Return the text of ``x & y``."""
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} & {rhs}"

    @staticmethod
    def bitwise_or(x, y):
        """Return the text of ``x | y``."""
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} | {rhs}"

    @staticmethod
    def bitwise_xor(x, y):
        """Return the text of ``x ^ y``."""
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} ^ {rhs}"

    @staticmethod
    def bitwise_left_shift(x, y):
        """Return the text of ``x << y``."""
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} << {rhs}"

    # TODO(fdrocha): this is currently not being used anywhere,
    # pending on moving triton pin past 972b761
    @staticmethod
    def bitwise_right_shift(x, y):
        """Return the text of ``x >> y``."""
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} >> {rhs}"

    @staticmethod
    def remainder(a, b):
        """Express a remainder that takes the sign of the divisor ``b``.

        ``ops.mod`` is applied first; when its result is non-zero and its
        sign differs from that of ``b``, ``b`` is added to it.
        """
        rem = ops.mod(a, b)
        needs_adjustment = f"(({rem} != 0) & (({rem} < 0) != ({b} < 0)))"
        adjusted = ops.add(rem, b)
        return ops.where(needs_adjustment, adjusted, rem)
|
|
|
|
|
|
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name

    def __call__(self):
        """Return the line text, or None once the named buffer was dropped.

        A buffer counts as dropped when its name is in
        ``V.graph.removed_buffers`` or in ``V.graph.inplaced_to_remove``.
        """
        if (
            self.name in V.graph.removed_buffers
            or self.name in V.graph.inplaced_to_remove
        ):
            return None
        return self.line

    def _new_line(self, line):
        """Build a DeferredLine for ``line`` bound to the same buffer name."""
        return DeferredLine(self.name, line)
|
|
|
|
|
|
class BracesBuffer(IndentedBuffer):
    """Indented buffer whose indentation levels are delimited by braces."""

    def indent(self, offset=1):
        """Return a context manager that shifts the brace nesting by ``offset``.

        A positive ``offset`` writes that many ``{`` lines (indenting after
        each) on entry and the matching ``}`` lines on exit. A negative
        ``offset`` does the opposite: it closes ``-offset`` levels on entry
        and re-opens them on exit.
        """

        def open_levels(count):
            for _ in range(count):
                self.writeline("{")
                self._indent += 1

        def close_levels(count):
            for _ in range(count):
                self._indent -= 1
                self.writeline("}")

        @contextlib.contextmanager
        def scope():
            # range() of a negative count is empty, so only one call of each
            # pair does any work for a given sign of ``offset``.
            open_levels(offset)
            close_levels(-offset)
            yield
            open_levels(-offset)
            close_levels(offset)

        return scope()
|
|
|
|
|
|
class InplacedBuffer(typing.NamedTuple):
    """One in/out kernel argument shared by several graph buffers."""

    # Kernel-side argument name (``KernelArgs.make_inplace`` uses "in_out_ptr<N>").
    inner_name: str
    # Graph-side buffer names that alias this argument, in the order they
    # were registered; the list is extended in place as more are added.
    other_names: typing.List[str]
|
|
|
|
|
|
class KernelArgs:
    """Bookkeeping for the arguments of one generated kernel.

    Maps graph-side names to kernel-side argument names for input buffers,
    output buffers, in-place (in/out) buffers and size variables, and renders
    the resulting argument lists for C++ and Python kernels.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        """Return the kernel-side name for ``name``, allocating one if needed.

        New names are ``prefix`` followed by the current size of ``odict``.
        """
        assert isinstance(name, (str, sympy.Symbol))
        if name in odict:
            return odict[name]
        allocated = f"{prefix}{len(odict)}"
        odict[name] = allocated
        return allocated

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        self.sizevars = sizevars if sizevars else {}

    def __repr__(self):
        tables = (
            self.input_buffers,
            self.output_buffers,
            self.inplace_buffers,
            self.sizevars,
        )
        return "KernelArgs(" + ", ".join(repr(table) for table in tables) + ")"

    def input(self, name):
        """Return the kernel-side name under which buffer ``name`` is read."""
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        # A buffer this kernel also writes is read through that argument.
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        prefix = "seed" if name.startswith("seed") else "in_ptr"
        return self._lookup(prefix, self.input_buffers, name)

    def output(self, name):
        """Return the kernel-side name under which buffer ``name`` is written."""
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Make ``output_name`` share one in/out argument with ``input_name``."""
        assert output_name not in self.inplace_buffers
        if input_name not in self.inplace_buffers:
            # First alias of this input: allocate a new in/out argument.
            ordinal = len(unique(self.inplace_buffers.values()))
            shared = InplacedBuffer(
                f"in_out_ptr{ordinal}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = shared
        else:
            shared = self.inplace_buffers[input_name]
            shared.other_names.append(output_name)
        self.inplace_buffers[output_name] = shared

    def size(self, name):
        """Return the kernel-side name of size variable ``name``.

        The name "seed" is special-cased and maps to itself.
        """
        if str(name) != "seed":
            return self._lookup("ks", self.sizevars, name)
        self.sizevars["seed"] = "seed"
        return "seed"

    def call_names(self):
        """Iterate graph-side names of inputs, then outputs, then size vars."""
        return chain(self.input_buffers, self.output_buffers, self.sizevars)

    def wrap_ptr_arg(self, buf, dtype):
        """Return the call-site text passing ``buf``'s data pointer."""
        return f"c_void_p({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        """Return the call-site text passing ``size``."""
        return f"c_long({size})"

    def cpp_argdefs(self):
        """Build the C++ view of the arguments.

        Returns three parallel lists ``(arg_defs, call_args, arg_types)``
        ordered as: in/out buffers, inputs, outputs, size variables.
        """
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        # TODO(jansel): replace this with data from scheduler
        buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
        for name, val in V.graph.graph_inputs.items():
            if not isinstance(val, sympy.Expr):
                buffer_types[name] = val.get_dtype()
            else:
                # Symbolic graph inputs are typed by their integrality.
                buffer_types[name] = torch.int64 if val.is_integer else torch.float64
        buffer_types.update(
            {name: val.dtype for name, val in V.graph.constants.items()}
        )

        call_args = []
        arg_defs = []
        arg_types = []

        def add_pointer(outer, inner, qualifier=""):
            dtype = buffer_types[outer]
            pointer_type = f"{qualifier}{DTYPE_TO_CPP[dtype]}*"
            arg_defs.append(f"{pointer_type} {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(pointer_type)

        for inplaced in unique(self.inplace_buffers.values()):
            # The most recently added alias is the name used at the call site.
            add_pointer(inplaced.other_names[-1], inplaced.inner_name)
        for outer, inner in self.input_buffers.items():
            if outer not in self.inplace_buffers:
                add_pointer(outer, inner, "const ")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or inner == "REMOVED":
                continue
            add_pointer(outer, inner)
        size_type = f"const {INDEX_TYPE}"
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"{size_type} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(size_type)
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build the Python view of the arguments.

        Returns three parallel lists ``(arg_defs, call_args, precompile_args)``
        ordered as: in/out buffers, inputs, outputs, size variables.
        ``precompile_args`` holds ``TensorArg``/``SizeArg`` descriptors.
        """
        arg_defs = []
        call_args = []
        precompile_args = []

        def add_tensor(inner, outer):
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))

        for inplaced in unique(self.inplace_buffers.values()):
            add_tensor(inplaced.inner_name, inplaced.other_names[-1])
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or inner == "REMOVED":
                continue
            add_tensor(inner, outer)
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(str(outer))
            precompile_args.append(SizeArg(inner, outer))

        return arg_defs, call_args, precompile_args

    def aliases(self):
        """Yield ``(plain_arg_name, in_out_arg_name)`` pairs for in-place buffers.

        For every alias of an in/out buffer that is also registered as a
        plain input and/or output, the plain name is paired with the in/out
        name. Aliases listed in ``V.graph.inplaced_to_remove`` are skipped.
        """
        for inplaced in unique(self.inplace_buffers.values()):
            for other in inplaced.other_names:
                if other in V.graph.inplaced_to_remove:
                    continue
                for table in (self.input_buffers, self.output_buffers):
                    if other in table:
                        yield table[other], inplaced.inner_name

    def is_removed(self, name):
        """Return True when ``name`` is neither a live output nor an in/out buffer.

        In each table, a name that is absent or mapped to "REMOVED" counts
        as removed.
        """
        return all(
            name not in table or table[name] == "REMOVED"
            for table in (self.output_buffers, self.inplace_buffers)
        )
|
|
|
|
|
|
class CSEVariable:
    """A named handle for an expression produced by common subexpression elimination.

    A CSEVariable is just a name, but backends can annotate it: they may
    subclass it and override the kernel's "create_cse_var" method to return
    their own type. The "update_on_args" method is the hook for such
    annotations; see TritonCSEVariable in triton.py for an example.
    """

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        # Hash by name, consistent with __eq__.
        return hash(self.name)

    def __eq__(self, other) -> bool:
        # Only variables of exactly the same class can be equal.
        if type(other) != type(self):
            return False
        return other.name == self.name

    def update_on_args(self, name, args, kwargs):
        """Annotation hook; the base implementation does nothing."""
        return None
|
|
|
|
|
|
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant whose call-site arguments are rendered as C++ text."""

    def wrap_ptr_arg(self, buf, dtype):
        """Return ``buf``'s data pointer cast to the C++ pointer type of ``dtype``."""
        from .cpp import DTYPE_TO_CPP

        cpp_type = DTYPE_TO_CPP[dtype]
        return f"({cpp_type}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        """Return ``size`` formatted as plain text, without any wrapper."""
        return "{}".format(size)
|
|
|
|
|
|
class CSE:
    """Common subexpression elimination.

    Caches generated expressions by their text so each distinct expression
    is assigned to a temporary variable only once. Also tracks the value most
    recently stored to each buffer (``store_cache``), buffers whose cached
    store was invalidated, and a map from variable names to variables.
    """

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        """Create a CSE context.

        Args:
            prefix: text emitted before every generated assignment.
            suffix: text emitted after every generated line.
            name_prefix: prefix of newly allocated variable names.
            iter_buffers: iterator supplying variable numbers; a fresh
                ``itertools.count()`` when None.
            store_cache: mapping of buffer name -> stored value to share;
                a fresh dict when None.
            reduction_cache: reduction cache to share; a fresh dict when None.
            varname_map: mapping of variable name -> variable to share;
                a fresh dict when None.
        """
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        # Test against None, not truthiness: an *empty* container passed by
        # the caller (as clone() does) must still be shared, otherwise
        # sharing would depend on whether it happened to be empty.
        self.store_cache = {} if store_cache is None else store_cache
        self.reduction_cache = {} if reduction_cache is None else reduction_cache
        self.iter_buffer_ids = (
            itertools.count() if iter_buffers is None else iter_buffers
        )
        self.invalidated_stores = set()
        self.varname_map = {} if varname_map is None else varname_map

    def invalidate(self, keep_vars: typing.Set[str]):
        """Forget every cached value that is not in ``keep_vars``.

        Buffers whose cached store is dropped are recorded in
        ``invalidated_stores``.
        """
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        """Return a CSE with an empty expression cache that shares this
        instance's variable counter, ``store_cache`` and ``varname_map``.
        """
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: "IndentedBuffer",
        expr: "typing.Union[str, CSEVariable]",
        write=True,
        assignment=True,
    ) -> "CSEVariable":
        """Return the variable holding ``expr``, emitting code on first use.

        Args:
            buffer: buffer that receives the generated line.
            expr: expression text, or an existing variable (returned as is).
            write: whether to write a line to ``buffer`` on a cache miss.
            assignment: whether to allocate a variable and emit an
                assignment; when False the bare expression is written and
                None is cached and returned.

        At least one of ``write`` and ``assignment`` must be true.
        """
        assert isinstance(expr, (str, CSEVariable)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            return expr
        cache_key = expr
        if cache_key not in self.cache:
            var = self.newvar() if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if assignment:
                    line = f"{self.prefix}{var} = {expr}{self.suffix}"
                else:
                    line = f"{expr}{self.suffix}"
                buffer.writeline(line)

        return self.cache[cache_key]

    def newvar(self) -> "CSEVariable":
        """Allocate the next variable via the current kernel and register it
        in ``varname_map``.
        """
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name)
        self.varname_map[var_name] = var
        return var
|
|
|
|
|
|
class CodeGen:
    """Context-manager base class that owns a ``contextlib.ExitStack``.

    Entering the object enters the stack; leaving it unwinds everything that
    was registered on ``exit_stack`` in the meantime.
    """

    def __init__(self):
        super().__init__()
        stack = contextlib.ExitStack()
        self.exit_stack = stack

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The stack's result is deliberately not returned, so an exception
        # raised inside the ``with`` body is never suppressed here.
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
|
class Kernel(CodeGen):
    """Base class for a kernel under construction.

    Holds separate code buffers for loads, compute and stores, a CSE context
    and the kernel's arguments. ``load``, ``store`` and ``reduction`` raise
    NotImplementedError here and must be provided by subclasses. Used as a
    context manager, it installs itself as the active kernel and wraps the
    current ops handler so op results go through CSE.
    """

    # Passed to CSE as the prefix/suffix of every generated line.
    newvar_prefix = ""
    suffix = ""
    # Called in __enter__ with the current ops handler to build the parent
    # handler; the None default is not callable, so a value must be supplied
    # before the kernel is entered.
    overrides = None
    # Not used within this class.
    load_format = None
    store_format = None

    def __init__(self, args=None):
        """Create the kernel state; ``args`` defaults to a fresh KernelArgs."""
        super().__init__()
        # Every constructed Kernel bumps the global counter.
        metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.current_node = None
        self.store_buffer_names = set()

    @contextlib.contextmanager
    def set_current_node(self, node):
        """Temporarily set ``current_node`` to ``node``; restored on exit."""
        prior = self.current_node
        self.current_node = node
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """Temporarily redirect code generation to other buffers.

        ``lb`` becomes the loads buffer, ``cb`` (defaulting to ``lb``) the
        compute buffer and ``sb`` the stores buffer; the CSE context is
        replaced by a clone. Everything is restored on exit.
        """
        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        # NOTE: sb has no fallback, so self.stores is None inside the context
        # unless the caller passes one.
        self.stores = sb
        self.cse = cse.clone()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr):
        """Generate a load of buffer ``name`` at ``index`` (subclass hook)."""
        raise NotImplementedError()

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store(self, name, index, value, mode=None):
        """Generate a store of ``value`` to buffer ``name`` (subclass hook)."""
        raise NotImplementedError()

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        """Generate a reduction into buffer ``name`` (subclass hook)."""
        raise NotImplementedError()

    def __enter__(self):
        """Install this kernel and a CSE-wrapping ops handler; return self."""

        class CSEProxy:
            # NOTE(review): this statement runs in the class body, where
            # ``self`` is the enclosing Kernel instance, so it sets ``name``
            # on the kernel rather than on CSEProxy. Confirm whether that is
            # intended before changing it.
            self.name = "CSEProxy"

            @staticmethod
            def __getattr__(name):
                # Any op not defined on the proxy: call it on the parent
                # handler and route the result through CSE into the compute
                # buffer. ``parent_handler`` is assigned further down, before
                # the proxy is installed.
                def inner(*args, **kwargs):
                    csevar = self.cse.generate(
                        self.compute, getattr(parent_handler, name)(*args, **kwargs)
                    )
                    csevar.update_on_args(name, args, kwargs)
                    return csevar

                return inner

            @staticmethod
            def indirect_indexing(index_var, size):
                # ``size`` is not used here.
                return sympy_symbol(str(index_var))

            @staticmethod
            def load(name: str, index: sympy.Expr):
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                # An index containing a "tmp" symbol depends on a generated value.
                if free_symbol_startswith(index, "tmp"):
                    return self.indirect_load(name, index)
                # Reuse the value last stored to this buffer, if still cached.
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(name, index, value, mode=None):
                self.store_buffer_names.add(name)
                # Only mode-less stores are cached for later loads, under the
                # buffer's name and under every mutation name reported by the
                # current node.
                if mode is None:
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                # No store code is generated for removed buffers (implicitly
                # returns None in that case).
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)

            @staticmethod
            def reduction(name, dtype, src_dtype, reduction_type, index, value):
                self.store_buffer_names.add(name)
                return self.reduction(
                    name, dtype, src_dtype, reduction_type, index, value
                )

        super().__enter__()
        # Wrap the handler that is active right now; the proxy forwards to it.
        parent_handler = self.overrides(V.get_ops_handler())
        # Both handlers are registered on the exit stack, so they are undone
        # when the kernel context exits.
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Ask the scheduler (if any) to remove kernel-local buffers, then
        unwind the exit stack.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def rename_indexing(self, index) -> sympy.Expr:
        """Rewrite size symbols in ``index`` to kernel argument names.

        A list or tuple is processed element-wise and returned as a list.
        """
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]
        index = V.graph.sizevars.simplify(index)
        # Sorted by name, so symbols are registered in a fixed order
        # (presumably to keep argument numbering deterministic).
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        # Only symbols whose name starts with "s" or "ps" are registered as
        # size arguments and substituted.
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if x.name.startswith("s") or x.name.startswith("ps")
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        """Build a CSE variable; subclasses may return their own type."""
        return CSEVariable(*args, **kwargs)
|
|
|
|
|
|
@dataclasses.dataclass
class OptimizationContext:
    """Per-node optimization hints, stored in an FX node's ``meta`` mapping
    under ``OptimizationContext.key``.
    """

    # Key under which an instance is stored in ``node.meta``.
    key: typing.ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = False
    # Load bfloat16 value as float32
    is_load_bf16_as_fp32: bool = False
    # Store float32 value as bfloat16
    is_store_fp32_as_bf16: bool = False
    # No type cast is needed for a memory-copy-only node
    # (bf16 load -> bf16 store)
    is_bf16_mem_copy: bool = False

    # dtype of the node's value; None means it has not been determined yet
    # (dtype propagation in this file tests ``dtype is not None``).
    dtype: typing.Optional[torch.dtype] = None
    ops_name: str = ""
    is_most_inner_loop_irrevelant: bool = False
|