import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .. import config, metrics
from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V

# Artifact logger used for scheduling-related debug output.
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    """Log *msg* as a dtype-propagation note on the schedule artifact logger.

    Does nothing unless that logger is enabled for DEBUG.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
    """Scratch buffer that exists for a single kernel and is then thrown away.

    It is intentionally not registered as an ordinary buffer: it has no users
    after the kernel, so it would otherwise be removed as dead code.
    """

    # Size of the scratch allocation in bytes.
    nbytes: sympy.Expr
    # Whether the workspace is requested to be zero-filled.
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    """Kernel argument that refers to a tensor buffer."""

    # Argument name as used inside the generated kernel.
    name: str
    # Name of the graph buffer passed for this argument.
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = dataclasses.field(default=sympy.Integer(0))


@dataclasses.dataclass
class SizeArg:
    """Kernel argument carrying a symbolic size expression."""

    name: str
    expr: sympy.Expr


@dataclasses.dataclass
class DeviceCodegen:
    """Codegen classes registered for one device type."""

    scheduling: Any
    wrapper_codegen: type
    # ``type(None)`` marks "no C++ wrapper codegen registered".
    cpp_wrapper_codegen: type = dataclasses.field(default=type(None))


KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

# Registry: device string -> codegen classes for that device.
device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    """Device-specific code snippets; every hook must be overridden by a backend."""

    def import_get_raw_stream_as(self, name):
        raise NotImplementedError()

    def set_device(self, device_idx):
        raise NotImplementedError()

    def synchronize(self):
        raise NotImplementedError()

    def device_guard(self, device_idx):
        raise NotImplementedError()


# Registry: device string -> DeviceOpOverrides instance for that device.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by the backend's Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary.
# And Inductor provides the registration API, register_backend_for_device, to
# equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: type,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Register (or overwrite) the codegen classes used for ``device``."""
    device_codegens[device] = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )


def get_scheduling_for_device(device: str):
    """Return the registered Scheduling class for ``device``, or None."""
    return device_codegens[device].scheduling if device in device_codegens else None


def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the registered (C++ or Python) wrapper codegen class, or None.

    With ``cpp_wrapper=True`` the result may be ``type(None)`` when no C++
    wrapper was registered (that is the DeviceCodegen default).
    """
    if device in device_codegens:
        wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
        return (
            wrapper_codegen_obj.cpp_wrapper_codegen
            if cpp_wrapper
            else wrapper_codegen_obj.wrapper_codegen
        )
    else:
        return None


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` with an extra contiguous-layout index appended."""
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Register (or overwrite) the DeviceOpOverrides for ``device``."""
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides registered for ``device``, or None if absent."""
    assert isinstance(device, str)

    # When the registry is empty, import the bundled .cuda/.xpu modules.
    # Presumably importing them registers their overrides (TODO confirm).
    if not device_op_overrides_dict.keys():
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    if device in device_op_overrides_dict.keys():
        return device_op_overrides_dict[device]


@functools.lru_cache(None)
def boolean_ops():
    """Names of ops whose result dtype is always torch.bool (see deduce_node_dtype)."""
    return (
        "is_inf",
        "is_nan",
        "bitwise_xor",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )


# dtype used for computation: bfloat16/float16 are upcast to float32,
# every other listed dtype computes in itself.
DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


class DataTypePropagation:
    """Annotate every FX node of a loop body with a deduced dtype.

    The result is stored in ``node.meta[OptimizationContext.key].dtype``.
    """

    def __init__(self, body) -> None:
        self.body = body
        # "root" maps to the body's main graph; subblock names map to their graphs.
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of the node's non-placeholder inputs.

        Returns None if there are no such inputs or any of them has no dtype yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the subgraph named by ``node.target``; return its output dtype."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype of ``node``, or None when it cannot be deduced."""
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can infer output node if it only has 1 arg; a single-arg
            # output falls through to the checks below (ultimately to
            # deduce_node_dtype_by_inputs)
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # the target dtype is passed as the last argument
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        # NOTE(review): "index_expr" is already handled by the branch above,
        # so its entry here is unreachable.
        if node.target in (
            "get_index",
            "index_expr",
            "randint64",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            # args[1] is the buffer name; its dtype comes from the graph
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Annotate each node of ``graph`` in order; return the output node's dtype (may be None)."""
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Annotate the root graph (subgraphs are reached via masked_subblock nodes); returns None."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run dtype propagation on a SchedulerNode's LoopBody."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


class ExprPrinter(Printer):
    """Base sympy printer emitting infix source text with conservative parenthesisation."""

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless it is clearly atomic or already wrapped."""

        def all_in_parens(string):
            # True iff the opening "(" at index 0 is closed by the final character.
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloorDiv(self, expr):
        # Subclasses must provide a backend-specific floor division.
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"


class PythonPrinter(ExprPrinter):
    """Print sympy index/size expressions as Python source using ``math`` and builtins."""

    def _print_ModularIndexing(self, expr):
        # ModularIndexing(x, div, mod) -> (x // div) % mod, omitting "// 1".
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation!  You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute.  Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        elif exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        # Other exponents must be integral; they are expanded into repeated products.
        assert exp == int(exp), exp
        exp = int(exp)
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        elif exp < 0:
            return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
        else:  # exp == 0
            return "1"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


class OpOverrides:
    """Default string-producing op implementations shared by backends.

    Attributes not defined here are forwarded to ``parent``.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        # dtype is ignored here; the value is emitted via repr()
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Adjust mod() so that a non-zero result takes the sign of the divisor b.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install each entry of pointwise_overrides_data for ``target`` as a staticmethod.

        Entries whose implementation for ``target`` is None are skipped.
        """
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))


@dataclasses.dataclass
class OverridesData:
    """Per-backend code templates for one pointwise special function."""

    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        # note: arguments are emitted in swapped order (y, x)
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that is 'unwritten' (rendered as None) once its buffer name
    appears in removed_buffers or inplaced_to_remove of V.graph or V.kernel."""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose indent() also emits C-style braces."""

    def indent(self, offset=1):
        """Return a context manager that opens and closes braces around a scope.

        A positive ``offset`` writes that many "{" lines (indenting after each)
        on entry and closes them on exit. A negative ``offset`` closes braces on
        entry and reopens them on exit.
        """

        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()


class InplacedBuffer(NamedTuple):
    """One kernel-side in/out pointer shared by several graph buffer names."""

    inner_name: str
    other_names: List[str]


class KernelArgs:
    """Map graph buffer names and size symbols to kernel-side argument names.

    The three buffer dicts map outer (graph) name -> inner kernel name
    (``in_ptrN``/``out_ptrN``) or -> InplacedBuffer. ``sizevars`` maps a size
    expression -> its kernel name (``ksN``).
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        """Return odict[name], first assigning ``f"{prefix}{len(odict)}"`` if absent."""
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = dict()
        self.output_buffers = dict()
        self.inplace_buffers = dict()
        self.sizevars = sizevars or dict()
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        # Only str values starting with "REMOVED" count; InplacedBuffer values
        # always return False here.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Return the kernel name for reading buffer ``name``.

        An existing output or in-place name is reused before a new input is allocated.
        """
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Return the kernel name for writing buffer ``name``."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Make ``output_name`` share a single ``in_out_ptrN`` argument with ``input_name``."""
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` in the single kernel workspace.

        Returns ``("ws_ptr", byte_offset)``; zero_fill requests are OR-ed together.
        """
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Register ``value`` as a size var named ``name``.

        The name is uniquified with a numeric suffix if already taken; a
        previously registered value returns its existing name.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the kernel name for size ``name``; "seed" keeps its own name."""
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Build C++ parameter declarations, call arguments and parameter types.

        Order: in-place, inputs, outputs, then sizes. Workspaces are not supported.
        """
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Return (arg_defs, call_args, precompile_args, arg_types) for Python/Triton kernels."""
        arg_defs = []
        call_args = []
        arg_types = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            # NOTE(review): no arg_types entry is appended for the workspace,
            # so arg_types is then one shorter than call_args — confirm that
            # callers pairing the two lists tolerate this.
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)

        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        """Yield (input/output kernel name, in-place kernel name) pairs for live in-place buffers."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        """True if ``name`` is neither a live output nor a live in-place buffer."""

        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs


class CSEVariable:
    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`
    The "CSEVariable.update_on_args" method gives you a hook for annotations
    See example of TritonCSEVariable in triton.py
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    # Identity is the name (and exact type); bounds do not affect equality or hashing.
    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass


class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant whose call arguments are C++ expressions for the cpp wrapper."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp_utils import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        else:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"


class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        # expression text (or buffer contents) -> CSEVariable
        self.cache = {}
        self.name_prefix = name_prefix
        # NOTE(review): `x or {}` also replaces an *empty* dict that was passed
        # in, so clone() shares store_cache/varname_map with the original only
        # when they are non-empty — confirm whether sharing is intended.
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: Set[str]):
        """Drop cached stores and expressions whose variable is not in ``keep_vars``.

        Dropped store names are recorded in invalidated_stores.
        """
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if
write: if V.kernel.current_node: V.kernel.current_node.codegen_originating_info( buffer, only_once=True ) if isinstance(expr, IndentedBuffer): if assignment: buffer.writeline(f"{self.prefix}{var} =") buffer.splice(expr) buffer.writeline(self.suffix) else: if assignment: line = f"{self.prefix}{var} = {expr}{self.suffix}" else: line = f"{expr}{self.suffix}" buffer.writeline(line) else: var.bounds = var.bounds.tighten(bounds) return var def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" var = V.kernel.create_cse_var(var_name, bounds) self.varname_map[var_name] = var return var class IndirectAssertLine(DeferredLineBase): def __init__(self, line, indirect_assert, var, mask, size_map): super().__init__(line) self.var = var self.mask = mask self.indirect_assert = indirect_assert self.size_map = size_map def __call__(self): size, size_str = self.size_map[(self.var, self.mask)] # We assert if we've not been able to prove the bound assert_min = (self.var.bounds.lower >= 0) != sympy.true assert_max = (self.var.bounds.upper < size) != sympy.true lower = None upper = None if not (assert_min or assert_max): return None elif assert_min and assert_max: lower = "0" upper = size_str elif assert_min: lower = "0" else: assert assert_max upper = size_str return self.line.format( assert_line=self.indirect_assert(self.var, lower, upper, self.mask) ) def _new_line(self, line): return IndirectAssertLine( line, self.indirect_assert, self.var, self.mask, self.size_map ) class CodeGen: def __init__(self): super().__init__() self.exit_stack = contextlib.ExitStack() def __enter__(self): self.exit_stack.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): self.exit_stack.__exit__(exc_type, exc_val, exc_tb) class ScopedDict: def __init__(self, original_dict): self.original_dict = original_dict self.new_items = {} def __getitem__(self, key): if key in self.new_items: return 
self.new_items[key] return self.original_dict[key] def __setitem__(self, key, value): self.new_items[key] = value def __contains__(self, key): return key in self.new_items or key in self.original_dict def get(self, key, default=None): if key in self.new_items: return self.new_items[key] return self.original_dict.get(key, default) class Kernel(CodeGen): newvar_prefix = "" suffix = "" overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None # TODO: these look dead, but with all the getattr it's hard to tell... load_format: None = None store_format: None = None def __init__(self, args=None, increase_kernel_count=True): super().__init__() if increase_kernel_count: metrics.generated_kernel_count += 1 self.args = args or KernelArgs() self.loads = IndentedBuffer() self.compute = IndentedBuffer() self.stores = IndentedBuffer() self.cse: CSE = CSE(self.newvar_prefix, self.suffix) self.must_keep_buffers = set() self.store_buffer_names = set() self._load_mask = None # set in set_current_node self.current_node = None self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None # Upper bounds for indirect_indexing and their str representation # NB: None, None is never stored in map, but it is the assumed # "not set" value for the dict self.indirect_max_sizes: Dict[ Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]] ] = {} self.removed_buffers = set() self.inplaced_to_remove = set() # key: the buffer to write # value: the buffer to read and whose memory can be reused for # the buffer specified by key self.inplace_update_buffers = dict() # Set minimum number of elements processed per thread. 
self.min_elem_per_thread = 1 self.kernel_name = None @contextlib.contextmanager def set_current_node(self, node): prior = self.current_node self.current_node = node self.node_to_bounds = node._body.bounds().get_bounds() try: yield finally: self.current_node = prior @contextlib.contextmanager def swap_buffers(self, lb, cb=None, sb=None): def scope_cse(cse): new_cse = cse.clone() new_cse.cache = ScopedDict(cse.cache) new_cse.reduction_cache = ScopedDict(cse.reduction_cache) new_cse.store_cache = ScopedDict(cse.store_cache) return new_cse if cb is None: cb = lb loads = self.loads compute = self.compute stores = self.stores cse = self.cse self.loads = lb self.compute = cb self.stores = sb self.cse = scope_cse(cse) try: yield finally: self.loads = loads self.compute = compute self.stores = stores self.cse = cse def load(self, name: str, index: sympy.Expr) -> CSEVariable: raise NotImplementedError def indirect_load(self, name: str, index: sympy.Expr): """A load the depends on an index we have read""" prior = self.loads try: # put the load in the compute section as it might have deps self.loads = self.compute return self.load(name, index) finally: self.loads = prior def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): raise NotImplementedError def store( self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None ) -> None: raise NotImplementedError def reduction( self, dtype: torch.dtype, src_dtype: torch.dtype, reduction_type: ReductionType, value: Union[CSEVariable, Tuple[CSEVariable, ...]], ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: raise NotImplementedError def scan( self, dtypes: Tuple[torch.dtype, ...], combine_fn: Callable[ [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...] 
], values: Tuple[CSEVariable, ...], ) -> Tuple[CSEVariable, ...]: raise NotImplementedError def bucketize( self, values: CSEVariable, offsets_name: str, offsets_size: sympy.Expr, indexing_dtype: torch.dtype, right: bool, ) -> CSEVariable: """ See [Note: Inductor bucketize op] """ raise NotImplementedError @property def assert_function(self) -> str: raise NotImplementedError def indirect_assert(self, var, lower, upper, mask=None): if lower and upper: # The conditions need to be in parens because of Python's operator precedence. # It'd be less error-prone to use and/or/not, which is suported by triton cond = f"({lower} <= {var}) & ({var} < {upper})" cond_print = f"{lower} <= {var} < {upper}" elif lower: cond = f"{lower} <= {var}" cond_print = cond else: assert upper cond = f"{var} < {upper}" cond_print = cond if mask: cond = f"({cond}) | ~{mask}" return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' def index_to_str(self, index: sympy.Expr) -> str: raise NotImplementedError def __enter__(self): # TODO: hoist this to top level class CSEProxy: self.name = "CSEProxy" vr_analysis = ValueRangeAnalysis() @staticmethod def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): bounds = CSEProxy._bound_variable(name, *args, **kwargs) value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] def do_cse(v): csevar = self.cse.generate(self.compute, v, bounds=bounds) csevar.update_on_args(name, args, kwargs) return csevar return pytree.tree_map(do_cse, value) return inner @staticmethod def _bound_variable(name, *args, **kwargs): """ If the variable comes from an FX node, we forward the bound we have already computed Else, if the variable when codegen'ing another op, we try to compute its bounds """ from ..select_algorithm import TritonTemplateKernel if isinstance(V.kernel, TritonTemplateKernel): return ValueRanges.unknown() fx_node = V.interpreter.current_node if fx_node.target == name: 
assert isinstance(self.node_to_bounds, dict) return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): # These create lots of inner strings. We would need to compute the bounds at the ops # We will also likely not get much from computing VRs on these nodes if any( s in fx_node.target for s in ("set_indirect", "reduction", "scan") ): return ValueRanges.unknown() # We assume that the inputs come from `ops.` and are not strings. If you want to generate # intermediary strings, wrap them in CSE variables with properly initialised bounds. # If there is no FX bound but we know how to compute one we do so assert not kwargs def arg_to_bound(x): if isinstance(x, CSEVariable): return x.bounds elif isinstance(x, sympy.Expr): return bound_sympy(x) else: return x arg_bounds = list(map(arg_to_bound, args)) return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) else: return ValueRanges.unknown() @staticmethod def indirect_indexing( var: CSEVariable, size: sympy.Expr, check: bool = True ): # Skip CSE since this doesn't return an expression if var.bounds.lower < 0: # type: ignore[operator] new_bounds = ValueRanges.unknown() if var.bounds != ValueRanges.unknown() and isinstance( size, sympy.Number ): # Take the negative part of the bound and add size to it # Then take union of that and the positive part # This is a tighter bound than that of a generic ops.where, as we have info on the cond neg = var.bounds & ValueRanges(-sympy.oo, -1) new_bounds = ValueRanges(neg.lower + size, neg.upper + size) # We don't have a good way of representing the empty range if var.bounds.upper >= 0: # type: ignore[operator] pos = var.bounds & ValueRanges(0, sympy.oo) new_bounds = new_bounds | pos stm = ops.add(var, self.rename_indexing(size)) # Mixed negative and non-negative if var.bounds.upper >= 0: # type: ignore[operator] lt = ops.lt(var, 0) stm = ops.where(lt, stm, var) new_var = self.cse.generate(self.compute, stm, 
bounds=new_bounds) new_var.update_on_args("index_wrap", (var,), {}) var = new_var if self.generate_assert(check): mask = self.load_mask(var) # An assertion line may have been written already, if so just # update the max size. map_key = (var, mask) existing_size, _ = self.indirect_max_sizes.get( map_key, (None, None) ) if existing_size is not None: size = sympy.Min(size, existing_size) else: self.compute.writeline( IndirectAssertLine( "{assert_line}", self.indirect_assert, var, mask, self.indirect_max_sizes, ) ) self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) return parent_handler.indirect_indexing(var, size, check) @staticmethod def load(name: str, index: sympy.Expr) -> CSEVariable: if name in self.cse.invalidated_stores: # A load from an invalidated store requires us to # keep the actual buffer around V.kernel.must_keep_buffers.add(name) if free_symbol_is_type(index, SymT.TMP): return self.indirect_load(name, index) store_cache = self.cse.store_cache if name in store_cache: return store_cache[name] return self.load(name, index) @staticmethod def store( name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None ) -> None: self.store_buffer_names.add(name) if mode is None: self.cse.store_cache[name] = value if self.current_node: for other_name in self.current_node.get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: return self.store(name, index, value, mode=mode) else: return None # type: ignore[return-value] @staticmethod def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): self.store_buffer_names.add(name) self.cse.store_cache[name] = value if self.current_node: for other_name in self.current_node.get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: return self.store_reduction(name, index, value) @staticmethod def reduction( dtype: torch.dtype, src_dtype: torch.dtype, reduction_type: ReductionType, value: Union[CSEVariable, 
Tuple[CSEVariable, ...]], ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: return self.reduction(dtype, src_dtype, reduction_type, value) @staticmethod def scan( dtypes: Tuple[torch.dtype, ...], combine_fn: Callable[ [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...], ], values: Tuple[CSEVariable, ...], ) -> Tuple[CSEVariable, ...]: return self.scan(dtypes, combine_fn, values) @staticmethod def bucketize( values: CSEVariable, offsets_name: str, offsets_size: sympy.Expr, indexing_dtype: torch.dtype, right: bool, ) -> CSEVariable: """ [Note: Inductor bucketize op] Given values (tensor) and offsets_name (reference to the name of a 1D tensor), calculate the bucket that each value belongs to. e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True return = [ 0, 1, 1, 1, 1, 3, 3, 4]. When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. When right == True, bucket i refers to range [offsets[i], offsets[i+1]). Offsets must be non-decreasing or the result is undefined. """ return self.bucketize( values, offsets_name, offsets_size, indexing_dtype, right ) # Use mypy to check protocol implemented correctly def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: return h super().__enter__() assert self.overrides parent_handler = self.overrides(V.get_ops_handler()) self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) return self def __exit__(self, exc_type, exc_val, exc_tb): """ Note that V.graph.scheduler can be None when codegening triton template kernels. 
""" if V.graph.scheduler: V.graph.scheduler.remove_kernel_local_buffers() super().__exit__(exc_type, exc_val, exc_tb) def generate_assert(self, check): return (check or config.debug_index_asserts) and config.assert_indirect_indexing def load_mask(self, var) -> str: # only the triton kernel requires mask return "" def rename_indexing(self, index) -> sympy.Expr: # adds the necessary kernel args for index expressions # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): return [self.rename_indexing(x) for x in index] # type: ignore[return-value] index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { x: self.args.size(x) for x in sorted_symbols if symbol_is_type( x, ( SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, ), ) } return sympy_subs(index, replacements) def create_cse_var(self, *args, **kwargs): return CSEVariable(*args, **kwargs) @dataclasses.dataclass class OptimizationContext: key: ClassVar[str] = "opt_ctx" dtype: Optional[torch.dtype] = None ops_name: str = "" @functools.lru_cache(None) def jinja2_env(): try: import jinja2 return jinja2.Environment( undefined=jinja2.StrictUndefined, ) except ImportError: return None class KernelTemplate: """ Base class for defining kernel templates. 
Children classes: TritonTemplate, CUDATemplate """ @staticmethod def indent_except_first(source: str, num_indents: int, indents_spacing=4): lines = source.splitlines(True) if len(lines) > 1: lines[1:] = [ (" " * indents_spacing * num_indents) + line for line in lines[1:] ] return "".join(lines) @staticmethod def _template_from_string(source): env = jinja2_env() if env is not None: env.filters["indent_except_first"] = KernelTemplate.indent_except_first return env.from_string(source) return None @staticmethod def _fake_get_dtype(fake_out): _get_dtype_real = V.graph.get_dtype def get_dtype(name): if name == fake_out.get_name(): return fake_out.get_dtype() return _get_dtype_real(name) return get_dtype def __init__(self, name: str): self.name = name def maybe_append_choice(self, choices, **kwargs): """ Maybe generates a new ChoiceCaller and appends it into existing choices. choices: A list of ChoiceCallers. kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. """ try: choices.append(self.generate(**kwargs)) except NotImplementedError: pass def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller": """ Generates a ChoiceCaller instance from the given arguments. """ raise NotImplementedError