pytorch/torch/_inductor/codegen/common.py
PyTorch MergeBot 657d39e44c Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit 57108d9a49.

Reverted https://github.com/pytorch/pytorch/pull/126019 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it has a land race and failing in trunk 2ac33a9f66 ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2126016522))
2024-05-23 01:13:29 +00:00

1831 lines
63 KiB
Python

import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from itertools import chain
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Union,
)
import sympy
from sympy.printing.printer import Printer
import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .. import config, metrics
from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
def data_type_logger(msg):
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug("Data type propagation: %s", msg)
@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    # Size of the workspace in bytes (may be symbolic); KernelArgs.workspace
    # grows it as further requests are added.
    nbytes: sympy.Expr
    # True when any request asked for zero-fill; KernelArgs.workspace ORs the flags.
    zero_fill: bool
@dataclasses.dataclass
class TensorArg:
    """Description of one tensor (pointer) argument of a kernel."""

    # Parameter name inside the kernel (e.g. the "inner" name from KernelArgs).
    name: str
    # Name of the graph buffer passed for this parameter (the "outer" name).
    buffer: str
    # dtype of the buffer, as reported by V.graph.get_dtype in KernelArgs.python_argdefs.
    dtype: torch.dtype
    # NOTE(review): presumably an offset into the buffer; units are not visible here.
    offset: sympy.Expr = sympy.Integer(0)
@dataclasses.dataclass
class SizeArg:
    """Description of one size (scalar) argument of a kernel."""

    # Parameter name inside the kernel.
    name: str
    # Expression supplying the value at the call site
    # (KernelArgs.python_argdefs passes the sizevars key here).
    expr: sympy.Expr
@dataclasses.dataclass
class DeviceCodegen:
    """Code generation classes registered for one device.

    Instances are created by register_backend_for_device.
    """

    # Scheduling implementation used for kernel code generation.
    scheduling: Any
    # Class producing the Python wrapper code.
    wrapper_codegen: type
    # Class producing the C++ wrapper code; NoneType means none was registered.
    cpp_wrapper_codegen: type = type(None)
# Any of the argument descriptions a kernel signature can contain.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
# Registry mapping a device name to its code generation classes;
# filled by register_backend_for_device.
device_codegens: Dict[str, DeviceCodegen] = {}
class DeviceOpOverrides:
    """Interface for device-specific operations; every method must be overridden.

    NOTE(review): implementations presumably return source-code snippets for
    the generated wrapper - confirm against the cuda/xpu overrides.
    """

    def import_get_raw_stream_as(self, name):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError

    def set_device(self, device_idx):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError

    def synchronize(self):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError

    def device_guard(self, device_idx):
        """Not implemented in the base class; raises NotImplementedError."""
        raise NotImplementedError
# Registry mapping a device name to its DeviceOpOverrides instance;
# filled by register_device_op_overrides.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by different Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: type,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Register the code generation classes to use for ``device``.

    An existing registration for the same device is overwritten.
    """
    codegen = DeviceCodegen(
        scheduling=device_scheduling,
        wrapper_codegen=device_wrapper_codegen,
        cpp_wrapper_codegen=device_cpp_wrapper_codegen,
    )
    device_codegens[device] = codegen
def get_scheduling_for_device(device: str):
    """Return the scheduling registered for ``device``, or None if there is none."""
    if device not in device_codegens:
        return None
    return device_codegens[device].scheduling
def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the wrapper codegen class registered for ``device``.

    The C++ wrapper class is returned when ``cpp_wrapper`` is true, the Python
    one otherwise. Returns None for a device that was never registered.
    """
    if device not in device_codegens:
        return None

    registered: DeviceCodegen = device_codegens[device]
    if cpp_wrapper:
        return registered.cpp_wrapper_codegen
    return registered.wrapper_codegen
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` with one extra contiguous index expression appended."""
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    contiguous_strides = FlexibleLayout.contiguous_strides(sizes)
    contiguous_index = sympy_dot(index_vars, contiguous_strides)
    return [*index, contiguous_index]
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Register ``device_op_overrides`` for ``device``, replacing any earlier entry."""
    device_op_overrides_dict[device] = device_op_overrides
def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides registered for ``device``, or None.

    While the registry is still empty, the cuda and xpu override modules are
    imported first (NOTE(review): presumably they register themselves on
    import - confirm).
    """
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    if device not in device_op_overrides_dict:
        return None
    return device_op_overrides_dict[device]
@functools.lru_cache(None)
def boolean_ops():
    """Return the names of the ops that DataTypePropagation treats as bool-valued."""
    predicates = ("is_inf", "is_nan", "bitwise_xor", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return predicates + comparisons
DTYPE_TO_COMPUTATION_DTYPE = {
torch.bfloat16: torch.float,
torch.float16: torch.float,
**{
dtype: dtype
for dtype in [
torch.bool,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
]
},
}
class DataTypePropagation:
    """Deduce a dtype for every node in the FX graphs of a loop body.

    The result for each node is stored on
    ``node.meta[OptimizationContext.key].dtype``.
    """

    def __init__(self, body) -> None:
        self.body = body
        # The root block is stored under "root"; subblocks keep their own keys.
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        self.graphs.update(
            (name, subblock.graph) for name, subblock in body.subblocks.items()
        )

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of the non-placeholder inputs of ``node``.

        Returns None when there is no such input, or when any of them has no
        dtype recorded yet.
        """
        producers = [
            inp
            for inp in node.all_input_nodes
            if isinstance(inp, torch.fx.Node) and inp.op != "placeholder"
        ]
        if not producers:
            return None

        input_dtypes = []
        for inp in producers:
            if OptimizationContext.key not in inp.meta:
                return None
            recorded = inp.meta[OptimizationContext.key].dtype
            if recorded is None:
                return None
            input_dtypes.append(recorded)

        return functools.reduce(torch.promote_types, input_dtypes)

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the subgraph named by ``node.target`` and return its dtype."""
        subgraph_dtype = self.propagate_graph(self.graphs[node.target])
        assert subgraph_dtype
        return subgraph_dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype deduced for ``node``, or None if it cannot be determined."""
        target = node.target

        if target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        # we can infer output node if it only have 1 arg
        if target == "output" and len(node.args) != 1:
            return None

        # The dtype is passed explicitly as the last argument.
        if target in ("to_dtype", "index_expr"):
            return node.args[-1]

        if target in ("rand", "randn"):
            return torch.float

        # "index_expr" never reaches this test: it is handled by the branch above.
        if target in ("get_index", "index_expr", "randint64"):
            return torch.int64

        if target in ("load", "store", "store_reduction"):
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if target == "reduction":
            return node.args[1]

        if target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Record a dtype on every node of ``graph`` and return the output node's dtype."""
        assert graph.nodes
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        graph_dtype = None
        for node in graph.nodes:
            opt_ctx = (
                node.meta[OptimizationContext.key]
                if OptimizationContext.key in node.meta
                else OptimizationContext()
            )
            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Run the propagation starting from the root graph."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Build a propagator for ``body`` and run it."""
        propagator = cls(body)
        return propagator.propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run the propagation on the loop body of a SchedulerNode."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        loop_body = node._body
        assert isinstance(loop_body, LoopBody)
        DataTypePropagation.propagate_loopbody(loop_body)
class ExprPrinter(Printer):
    """Base sympy printer with the formatting rules shared by the code generators."""

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless it is already safely atomic.

        CSE variables, plain identifiers/numbers, a single flat parenthesised
        group, the empty string and strings entirely enclosed by one matching
        pair of parentheses are returned unchanged.
        """

        def all_in_parens(text):
            # True when the opening "(" at index 0 is closed by the final character.
            if text[0] != "(" or len(text) < 2:
                return False
            last = len(text) - 1
            depth = 1
            for pos, ch in enumerate(text[1:], start=1):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and pos != last:
                    return False
            assert depth == 0
            return True

        if isinstance(string, CSEVariable):
            return string
        if (
            re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        return string if all_in_parens(string) else f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        operands = (self.paren(self._print(arg)) for arg in expr.args)
        return f" {expr.rel_op} ".join(operands)

    def _print_Mul(self, expr):
        return "*".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Add(self, expr):
        return " + ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Mod(self, expr):
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        # CleanDiv is printed exactly like FloorDiv
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan: >=
        # StrictlyGreaterThan: >
        # Go figure...
        return " >= ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_align(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"align({inner})"
class PythonPrinter(ExprPrinter):
    """Print sympy expressions as Python source (``//``, ``math.*``, ``max`` ...)."""

    def _print_ModularIndexing(self, expr):
        x, div, mod = (self.paren(self.doprint(arg)) for arg in expr.args)
        # a divisor of 1 is a no-op, so the floor division is omitted
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_FloorDiv(self, expr):
        numerator, denominator = (self.paren(self.doprint(arg)) for arg in expr.args)
        return f"({numerator} // {denominator})"

    def _helper_sqrt(self, expr):
        inner = self._print(expr)
        return f"math.sqrt({inner})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation!  You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute.  Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        if exp == -0.5:
            return "1/" + self._helper_sqrt(base)

        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp == 0:
            return "1"

        # integer powers are expanded into repeated multiplication
        product = "*".join([self.paren(base)] * abs(exp))
        if exp > 0:
            return product
        return "1/" + self.paren(product)

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.floor({inner})"

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.trunc({inner})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.ceil({inner})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"abs({inner})"

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({operands})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({operands})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.cos({inner})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.cosh({inner})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.acos({inner})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.sin({inner})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.sinh({inner})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.asin({inner})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.tan({inner})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.tanh({inner})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"math.atan({inner})"

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        inner = self._print(expr.args[0])
        return f"round({inner})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        printed = self._print(number)
        return f"round({printed}, {ndigits})"
class OpOverrides:
    """Base class of the per-backend op overrides.

    Attributes that are not defined on the overrides class are looked up on
    the wrapped ``parent`` handler.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        """Delegate unknown attributes to the parent handler.

        ``__getattr__`` only runs when normal lookup fails, so a lookup of
        ``_parent`` arriving here means the instance was never initialised
        (e.g. ``copy.copy`` probing a freshly created object). Delegating in
        that case would recurse forever, so AttributeError is raised instead.
        """
        if item == "_parent":
            raise AttributeError(item)
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        """Return ``repr(value)``; ``dtype`` is ignored here."""
        return repr(value)

    @staticmethod
    def reciprocal(x):
        """Express 1/x through ``ops.truediv`` with an int32 constant one."""
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        """Express x*x through ``ops.mul``."""
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        """Build a remainder from ``ops.mod``.

        ``b`` is added to the ``ops.mod`` result when that result is non-zero
        and its sign bit differs from the sign bit of ``b``.
        """
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        """Load from buffer ``name`` at the constant index ``offset``."""
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install the ``target`` implementations from pointwise_overrides_data on ``cls``.

        Entries that have no implementation for ``target`` (None) are skipped.
        """
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))
@dataclasses.dataclass
class OverridesData:
    """Per-backend code-string generators for one pointwise op."""

    # Name associated with the op (e.g. "special_bessel_j0").
    name: str
    # Generator of the scalar C++ expression; always required.
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    # Type promotion rule of the op.
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
# Keys are the op names installed by OpOverrides._initialize_pointwise_overrides.
pointwise_overrides_data: Dict[str, OverridesData] = {
    "airy_ai": OverridesData(
        name="special_airy_ai",
        cpp=lambda x: f"airy_ai_forward({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "bessel_j0": OverridesData(
        name="special_bessel_j0",
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "bessel_j1": OverridesData(
        name="special_bessel_j1",
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "bessel_y0": OverridesData(
        name="special_bessel_y0",
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "bessel_y1": OverridesData(
        name="special_bessel_y1",
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "digamma": OverridesData(
        name="digamma",
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # not listed here: erf, erfc
    "erfcx": OverridesData(
        name="special_erfcx",
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "fma": OverridesData(
        name="fma",
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # not listed here: erfinv, exp2, expit, gammaln
    "igamma": OverridesData(
        name="igamma",
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "igammac": OverridesData(
        name="igammac",
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "gammainc": OverridesData(
        name="special_gammainc",
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "gammaincc": OverridesData(
        name="special_gammaincc",
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "i0": OverridesData(
        name="i0",
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "i0e": OverridesData(
        name="special_i0e",
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "i1": OverridesData(
        name="special_i1",
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "i1e": OverridesData(
        name="special_i1e",
        cpp=lambda x: f"calc_i1e({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "log_ndtr": OverridesData(
        name="special_log_ndtr",
        cpp=lambda x: f"calc_log_ndtr({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # not listed here: logit
    "modified_bessel_i0": OverridesData(
        name="special_modified_bessel_i0",
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "modified_bessel_i1": OverridesData(
        name="special_modified_bessel_i1",
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "modified_bessel_k0": OverridesData(
        name="special_modified_bessel_k0",
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "modified_bessel_k1": OverridesData(
        name="special_modified_bessel_k1",
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # not listed here: multigamma
    "ndtr": OverridesData(
        name="special_ndtr",
        cpp=lambda x: f"calc_ndtr({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "ndtri": OverridesData(
        name="special_ndtri",
        cpp=lambda x: f"calc_ndtri({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # note: the two arguments are passed to calc_polygamma in swapped order
    "polygamma": OverridesData(
        name="polygamma",
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # psi - alias to digamma
    # not listed here: round
    "scaled_modified_bessel_k0": OverridesData(
        name="special_scaled_modified_bessel_k0",
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "scaled_modified_bessel_k1": OverridesData(
        name="special_scaled_modified_bessel_k1",
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    # not listed here: sinc
    "spherical_bessel_j0": OverridesData(
        name="special_spherical_bessel_j0",
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "zeta": OverridesData(
        name="special_zeta",
        cpp=lambda x, y: f"zeta({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "chebyshev_polynomial_t": OverridesData(
        name="special_chebyshev_polynomial_t",
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "chebyshev_polynomial_u": OverridesData(
        name="special_chebyshev_polynomial_u",
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "chebyshev_polynomial_v": OverridesData(
        name="special_chebyshev_polynomial_v",
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "chebyshev_polynomial_w": OverridesData(
        name="special_chebyshev_polynomial_w",
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "legendre_polynomial_p": OverridesData(
        name="special_legendre_polynomial_p",
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "shifted_chebyshev_polynomial_t": OverridesData(
        name="special_shifted_chebyshev_polynomial_t",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "shifted_chebyshev_polynomial_u": OverridesData(
        name="special_shifted_chebyshev_polynomial_u",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "shifted_chebyshev_polynomial_v": OverridesData(
        name="special_shifted_chebyshev_polynomial_v",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "shifted_chebyshev_polynomial_w": OverridesData(
        name="special_shifted_chebyshev_polynomial_w",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "hermite_polynomial_h": OverridesData(
        name="special_hermite_polynomial_h",
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "hermite_polynomial_he": OverridesData(
        name="special_hermite_polynomial_he",
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    "laguerre_polynomial_l": OverridesData(
        name="special_laguerre_polynomial_l",
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
}
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    """Return ``h`` unchanged; the function exists only for its type annotations."""
    return h
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        """Return the line, or None once ``name`` appears in any removal set."""
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in removed for removed in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        """Build a DeferredLine for ``line`` that keeps this instance's name."""
        return DeferredLine(self.name, line)
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose indentation levels are delimited by ``{`` and ``}`` lines."""

    def indent(self, offset=1):
        """Return a context manager shifting the indentation by ``offset`` levels.

        A positive offset opens that many braces on entry and closes them on
        exit; a negative offset closes braces on entry and reopens them on exit.
        """

        def open_braces(levels):
            for _ in range(levels):
                self.writeline("{")
                self._indent += 1

        def close_braces(levels):
            for _ in range(levels):
                self._indent -= 1
                self.writeline("}")

        @contextlib.contextmanager
        def ctx():
            # range() of a negative count is empty, so only one call of each
            # pair below actually does anything
            open_braces(offset)
            close_braces(-offset)
            yield
            open_braces(-offset)
            close_braces(offset)

        return ctx()
class InplacedBuffer(NamedTuple):
    """One in/out kernel argument shared by several buffers (see KernelArgs.make_inplace)."""

    # Parameter name inside the kernel (e.g. "in_out_ptr0").
    inner_name: str
    # Names of the buffers sharing this argument; make_inplace appends to this
    # list and the last entry is the one passed at the call site.
    other_names: List[str]
class KernelArgs:
    """Track the arguments of one kernel and the names they have inside it.

    Four dicts map an outer name (buffer name or size expression) to the
    inner name used in the kernel: inputs, outputs, in-place buffers and
    size variables. An optional workspace argument is tracked separately.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        """Return the inner name of ``name`` in ``odict``, allocating ``prefix<N>`` if new."""
        assert isinstance(name, (str, sympy.Symbol))
        if name in odict:
            return odict[name]
        inner = f"{prefix}{len(odict)}"
        odict[name] = inner
        return inner

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        self.sizevars = sizevars if sizevars else {}
        self.workspace_arg = None

    def __repr__(self):
        tables = (
            self.input_buffers,
            self.output_buffers,
            self.inplace_buffers,
            self.sizevars,
        )
        rendered = ", ".join(repr(table) for table in tables)
        return f"KernelArgs({rendered})"

    def _buffer_is_marked_removed(self, name):
        """True when ``name`` is a string starting with the "REMOVED" marker."""
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Return the inner name to read buffer ``name`` through."""
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        # a buffer already known as output / in-place argument is reused
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        prefix = "seed" if name.startswith("seed") else "in_ptr"
        return self._lookup(prefix, self.input_buffers, name)

    def output(self, name):
        """Return the inner name to write buffer ``name`` through."""
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Make ``output_name`` share a single in/out argument with ``input_name``."""
        assert output_name not in self.inplace_buffers
        if input_name not in self.inplace_buffers:
            fresh = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = fresh
            self.inplace_buffers[output_name] = fresh
            return
        # the input is already in-place: extend its existing chain
        existing = self.inplace_buffers[input_name]
        existing.other_names.append(output_name)
        self.inplace_buffers[output_name] = existing

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` of workspace; return ``("ws_ptr", offset)`` of the region."""
        current = self.workspace_arg
        if current is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        # later requests are placed after what has been reserved so far
        offset = current.nbytes
        self.workspace_arg = WorkspaceArg(
            offset + nbytes, zero_fill or current.zero_fill
        )
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Return the size-variable name for ``value``, registering ``name`` if new.

        When ``name`` is already taken, the number of existing names starting
        with it is appended.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        taken = self.sizevars.values()
        if name in taken:
            clashes = sum(1 for existing in taken if existing.startswith(name))
            name = f"{name}{clashes}"
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the inner name of size variable ``name`` ("seed" maps to itself)."""
        if str(name) != "seed":
            return self._lookup("ks", self.sizevars, name)
        self.sizevars["seed"] = "seed"
        return "seed"

    def call_names(self):
        """Iterate over the outer names of inputs, outputs and size variables."""
        return chain(self.input_buffers, self.output_buffers, self.sizevars)

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Return ``(arg_defs, call_args, arg_types)`` for a C++ kernel signature."""
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []

        def add_pointer(outer, inner, qualifier=""):
            # one pointer parameter: definition, call-site argument, bare type
            dtype = V.graph.get_dtype(outer)
            pointer_type = f"{qualifier}{DTYPE_TO_CPP[dtype]}*"
            arg_defs.append(f"{pointer_type} {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(pointer_type)

        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            add_pointer(inplaced.other_names[-1], inplaced.inner_name)

        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            add_pointer(outer, inner, "const ")

        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            add_pointer(outer, inner)

        for outer, inner in self.sizevars.items():
            size_type = f"const {INDEX_TYPE}"
            arg_defs.append(f"{size_type} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(size_type)
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)

        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Return ``(arg_defs, call_args, precompile_args, arg_types)`` for a Python kernel."""
        arg_defs = []
        call_args = []
        arg_types = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []

        def add_tensor(inner, outer):
            dtype = V.graph.get_dtype(outer)
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(dtype)
            precompile_args.append(TensorArg(name=inner, buffer=outer, dtype=dtype))

        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            add_tensor(inplaced.inner_name, inplaced.other_names[-1])

        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            add_tensor(inner, outer)

        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)

        # the workspace gets no entry in arg_types
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        """Yield ``(plain_inner_name, inplace_inner_name)`` pairs naming the same buffer."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            shared_name = inplaced.inner_name
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], shared_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], shared_name

    def is_removed(self, name):
        """True when ``name`` is absent or marked removed in both outputs and in-place buffers."""
        for table in (self.output_buffers, self.inplace_buffers):
            if name in table and not self._buffer_is_marked_removed(table[name]):
                return False
        return True

    def live_output_buffers(self):
        """Return the set of buffers holding updated data after the kernel ran.

        Includes inplace buffers, excludes removed buffers. Modeled off of
        python_argdefs.
        """
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if not self._buffer_is_marked_removed(inplaced):
                live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
class CSEVariable:
"""A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
To do so, the backends can simply overload `Kernel.create_cse_var`
The "CSEVariable.update_on_args" method gives you a hook for annotations
See example of TritonCSEVariable in triton.py
"""
def __init__(self, name, bounds: ValueRanges[Any]):
assert isinstance(bounds, ValueRanges)
self.name = name
self.bounds = bounds
self.use_count = 1 # track how many tims this expression is used
def __str__(self):
return self.name
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other) -> bool:
return type(other) == type(self) and other.name == self.name
def update_on_args(self, name, args, kwargs):
pass
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant that formats pointer and size arguments as C++ source text."""

    def wrap_ptr_arg(self, buf, dtype):
        """Return the call-site expression for buffer ``buf`` of type ``dtype``."""
        from .cpp_utils import DTYPE_TO_CPP

        if not config.abi_compatible:
            cpp_type = DTYPE_TO_CPP[dtype]
            return f"({cpp_type}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        """Return ``size`` formatted as a string."""
        return f"{size}"
class CSE:
"""Common subexpression elimination"""
def __init__(
self,
prefix="",
suffix="",
name_prefix="tmp",
iter_buffers=None,
store_cache=None,
reduction_cache=None,
varname_map=None,
):
self.prefix = prefix
self.suffix = suffix
self.cache = {}
self.name_prefix = name_prefix
self.store_cache = store_cache or {}
self.reduction_cache = reduction_cache or {}
self.iter_buffer_ids = iter_buffers or itertools.count()
self.invalidated_stores = set()
self.varname_map = varname_map or {}
def invalidate(self, keep_vars: Set[str]):
for name, tmp in list(self.store_cache.items()):
if tmp not in keep_vars:
del self.store_cache[name]
self.invalidated_stores.add(name)
self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
def clone(self):
# Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
return CSE(
prefix=self.prefix,
suffix=self.suffix,
name_prefix=self.name_prefix,
iter_buffers=self.iter_buffer_ids,
store_cache=self.store_cache,
varname_map=self.varname_map,
)
def generate(
self,
buffer: IndentedBuffer,
expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
*,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
write=True,
assignment=True,
) -> CSEVariable:
if isinstance(expr, OpsValue):
expr = expr.value
assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
assert write or assignment
if isinstance(expr, CSEVariable):
# If the expressions were always created with all the information, we could
# assert expr.bounds == bounds, but sometimes the expression is created
# with the loose ValueRanges.unknown(), so we need to tighten the bounds
expr.bounds = expr.bounds.tighten(bounds)
expr.use_count += 1
return expr
cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
var = self.cache.get(cache_key, None)
if not var:
var = self.newvar(bounds) if assignment else None
self.cache[cache_key] = var
if write:
if V.kernel.current_node:
V.kernel.current_node.codegen_originating_info(
buffer, only_once=True
)
if isinstance(expr, IndentedBuffer):
if assignment:
buffer.writeline(f"{self.prefix}{var} =")
buffer.splice(expr)
buffer.writeline(self.suffix)
else:
if assignment:
line = f"{self.prefix}{var} = {expr}{self.suffix}"
else:
line = f"{expr}{self.suffix}"
buffer.writeline(line)
else:
var.bounds = var.bounds.tighten(bounds)
var.use_count += 1
return var
def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
var = V.kernel.create_cse_var(var_name, bounds)
self.varname_map[var_name] = var
return var
class IndirectAssertLine(DeferredLineBase):
    """Deferred line holding a bounds assertion for an indirect index.

    The assertion text is only built when the line is called, using the size
    stored in ``size_map`` under the key ``(var, mask)`` at that time.
    """

    def __init__(self, line, indirect_assert, var, mask, size_map):
        super().__init__(line)
        self.var = var
        self.mask = mask
        self.indirect_assert = indirect_assert
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]
        bounds = self.var.bounds

        # We assert if we've not been able to prove the bound
        need_lower_check = (bounds.lower >= 0) != sympy.true
        need_upper_check = (bounds.upper < size) != sympy.true
        if not (need_lower_check or need_upper_check):
            # Both bounds are proven: nothing to assert.
            return None

        lower = "0" if need_lower_check else None
        upper = size_str if need_upper_check else None
        assert_text = self.indirect_assert(self.var, lower, upper, self.mask)
        return self.line.format(assert_line=assert_text)

    def _new_line(self, line):
        # Rebuild with the same assertion context around the new line text.
        return IndirectAssertLine(
            line, self.indirect_assert, self.var, self.mask, self.size_map
        )
class CodeGen:
    """Context-manager base class that owns a ``contextlib.ExitStack``.

    Entering the object enters the stack; leaving it unwinds the stack.
    ``__exit__`` returns ``None``, so an exception raised inside the ``with``
    body is never suppressed by this class itself.
    """

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        stack = self.exit_stack
        stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        exc_details = (exc_type, exc_val, exc_tb)
        self.exit_stack.__exit__(*exc_details)
class ScopedDict:
    """Write-through-free overlay on top of an existing mapping.

    Reads look at ``new_items`` first and fall back to ``original_dict``;
    writes only ever touch ``new_items``, leaving the wrapped mapping intact.
    """

    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        overlay = self.new_items
        source = overlay if key in overlay else self.original_dict
        return source[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return any(key in layer for layer in (self.new_items, self.original_dict))

    def get(self, key, default=None):
        overlay = self.new_items
        if key not in overlay:
            return self.original_dict.get(key, default)
        # An overlay entry wins even when its value is falsy / None.
        return overlay[key]
class Kernel(CodeGen):
    """Base class for kernel code generators.

    Holds the ``loads`` / ``compute`` / ``stores`` buffers, a ``CSE`` instance
    and per-kernel buffer bookkeeping.  Entering the kernel as a context
    manager installs a ``CSEProxy`` ops handler (built around
    ``self.overrides``) and registers the kernel as the kernel handler on ``V``.
    Backends implement the methods that raise ``NotImplementedError`` here.
    """

    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        """
        args: the KernelArgs to use; a new ``KernelArgs()`` is created when
            ``args`` is falsy.
        increase_kernel_count: whether to increment
            ``metrics.generated_kernel_count``.
        """
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()

        self.num_load = 0
        self.num_reduction = 0

        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        # Upper bounds for indirect_indexing and their str representation
        # NB: None, None is never stored in map, but it is the assumed
        # "not set" value for the dict
        self.indirect_max_sizes: Dict[
            Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
        ] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #       the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        """Temporarily make ``node`` the current node.

        Also sets ``node_to_bounds`` from ``node._body.bounds().get_bounds()``.
        NOTE(review): only ``current_node`` is restored on exit;
        ``node_to_bounds`` keeps the value computed for ``node``.
        """
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """Temporarily redirect loads/compute/stores to ``lb``/``cb``/``sb``.

        ``cb`` defaults to ``lb``.  ``sb`` has no such fallback, so
        ``self.stores`` is ``None`` inside the context when ``sb`` is omitted.
        The CSE is replaced by a scoped clone for the duration; the original
        buffers and CSE are restored on exit.
        """

        def scope_cse(cse):
            # Clone the CSE and wrap its caches in ScopedDicts, so entries
            # added through the clone land in the overlay and do not modify
            # the original dicts.
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = scope_cse(cse)
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """Backend hook: emit a load of buffer ``name`` at ``index``."""
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        """Backend hook: emit the store of a reduction result."""
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Backend hook: emit a store of ``value`` to buffer ``name`` at ``index``."""
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        """Backend hook: emit a reduction of ``value``."""
        raise NotImplementedError

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        """Backend hook: emit a scan of ``values`` using ``combine_fn``."""
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        """Backend hook: name of the assert function used by ``indirect_assert``."""
        raise NotImplementedError

    def indirect_assert(self, var, lower, upper, mask=None):
        """Build the source text of a bounds assertion on ``var``.

        ``lower`` / ``upper`` are the bound strings to check (either may be
        falsy to skip that side, but not both).  When ``mask`` is truthy the
        condition is OR-ed with ``~mask``.
        """
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            cond = f"({cond}) | ~{mask}"
        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    def index_to_str(self, index: sympy.Expr) -> str:
        """Backend hook: render an index expression as source text."""
        raise NotImplementedError

    def __enter__(self):
        """Install the CSEProxy ops handler and this kernel as kernel handler.

        Both handlers are entered on ``self.exit_stack`` so they are undone in
        ``__exit__``.  Requires ``self.overrides`` to be set.
        """

        # TODO: hoist this to top level
        class CSEProxy:
            # NOTE(review): this statement runs in the class body, where `self`
            # is the enclosing Kernel instance, so it assigns `name` on the
            # kernel rather than defining a CSEProxy attribute -- confirm intent.
            self.name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                # Fallback for ops without an explicit method below: call the
                # parent handler, then run every leaf of the result through CSE.
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed.
                Otherwise, if the variable was created while codegen'ing another op, we try to compute its bounds.
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds at the ops
                    # We will also likely not get much from computing VRs on these nodes
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: sympy.Expr, check: bool = True
            ):
                # Skip CSE since this doesn't return an expression

                # A possibly-negative index is wrapped by adding `size`
                # (only where var < 0 when the range also has a non-negative part).
                if var.bounds.lower < 0:  # type: ignore[operator]
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, 0)
                        stm = ops.where(lt, stm, var)

                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        # The line is deferred: it reads indirect_max_sizes when
                        # it is finally called, so later updates are picked up.
                        self.compute.writeline(
                            IndirectAssertLine(
                                "{assert_line}",
                                self.indirect_assert,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return parent_handler.indirect_indexing(var, size, check)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    # Reuse the value last stored to this buffer in the kernel.
                    return store_cache[name]
                out = self.load(name, index)
                # count load that is not in the store_cache, and also not in the
                # cse cache.
                if out.use_count == 1:
                    self.num_load += 1
                return out

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    # Remember the stored value (also under the names returned
                    # by the current node's get_mutations()) for later loads.
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                self.num_reduction += 1
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        # `parent_handler` is captured by the CSEProxy methods defined above.
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        """Return whether an indirect-indexing assert should be emitted.

        True when ``config.assert_indirect_indexing`` is enabled and either
        ``check`` or ``config.debug_index_asserts`` is truthy.
        """
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var) -> str:
        """Return the mask for ``var``; the base implementation has none."""
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        """Rewrite size symbols in ``index`` to kernel argument names.

        Lists and tuples are processed element-wise and returned as a list.
        """
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        # Symbols are visited in name order, which fixes the order of the
        # self.args.size(...) calls.
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        """Construct the CSEVariable for this kernel; backends may override it
        to return an annotated subclass."""
        return CSEVariable(*args, **kwargs)
# Small record with an optional dtype and an op name, both defaulting to
# "unset" values.
@dataclasses.dataclass
class OptimizationContext:
    # Class-level constant shared by all instances (not an __init__ field).
    # NOTE(review): presumably the key under which the context is stored on a
    # node -- confirm against the users of this class.
    key: ClassVar[str] = "opt_ctx"

    # None means no dtype has been recorded.
    dtype: Optional[torch.dtype] = None
    # Empty string means no op name has been recorded.
    ops_name: str = ""
@functools.lru_cache(None)
def jinja2_env():
    """Return a jinja2 ``Environment`` using ``StrictUndefined``, or ``None``
    when jinja2 cannot be imported.

    The result is cached, so every caller receives the same object.
    """
    try:
        from jinja2 import Environment, StrictUndefined

        return Environment(undefined=StrictUndefined)
    except ImportError:
        # jinja2 is optional; callers must handle a missing environment.
        return None
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        """Indent every line of ``source`` except the first.

        Each following line is prefixed with ``num_indents * indents_spacing``
        spaces; line endings are preserved.
        """
        pieces = source.splitlines(True)
        head, tail = pieces[:1], pieces[1:]
        if tail:
            padding = " " * indents_spacing * num_indents
            tail = [padding + piece for piece in tail]
        return "".join(head + tail)

    @staticmethod
    def _template_from_string(source):
        """Compile ``source`` into a jinja2 template, or return ``None`` when
        no jinja2 environment is available."""
        env = jinja2_env()
        if env is None:
            return None
        # Expose the indentation helper to templates as a filter.
        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
        return env.from_string(source)

    @staticmethod
    def _fake_get_dtype(fake_out):
        """Return a ``get_dtype`` function that answers for ``fake_out`` itself
        and defers to ``V.graph.get_dtype`` for every other name."""
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            is_fake = name == fake_out.get_name()
            return fake_out.get_dtype() if is_fake else real_get_dtype(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """
        # A template that cannot handle these kwargs signals it by raising
        # NotImplementedError; in that case nothing is appended.
        with contextlib.suppress(NotImplementedError):
            choices.append(self.generate(**kwargs))

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError