mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is enabled by setting the `specialize_float` config variable to False.
The generated graph looks like this for the test `test_unspec_float_output`:
```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
l_x_ = L_x_
l_y_ = L_y_
# File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
add: "f32[3]" = l_x_ + 1; l_x_ = None
item: "Sym(zf0)" = l_y_.item(); l_y_ = None
mul: "Sym(2*zf0)" = item * 2; item = None
scalar_tensor: "f32[]" = torch.scalar_tensor(mul); mul = None
return (add, scalar_tensor)
```
The ingredients:
* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, with which we do the rest of the tracing. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy-paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special-case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
1812 lines
62 KiB
Python
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
import operator
|
|
import re
|
|
from itertools import chain
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
ClassVar,
|
|
Dict,
|
|
List,
|
|
NamedTuple,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
import sympy
|
|
from sympy.printing.printer import Printer
|
|
|
|
import torch
|
|
import torch.fx
|
|
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
|
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
|
|
|
from .. import config, metrics
|
|
from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
|
|
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
|
|
|
|
|
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
|
|
|
|
|
|
def data_type_logger(msg):
|
|
if schedule_log.isEnabledFor(logging.DEBUG):
|
|
schedule_log.debug("Data type propagation: %s", msg)
|
|
|
|
|
|
@dataclasses.dataclass
class WorkspaceArg:
    """Scratch buffer used by one kernel and then thrown away.

    It is deliberately not registered as an ordinary buffer: it has no users,
    so dead-code elimination would remove it.
    """

    # Total size; KernelArgs.workspace appends new regions at this offset.
    nbytes: sympy.Expr
    # Merged with `or` when workspace regions are combined.
    # Presumably "must be zero-initialised" (confirm with consumers).
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    """A tensor-valued kernel argument."""

    name: str
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.Integer(0)


@dataclasses.dataclass
class SizeArg:
    """A symbolic-size kernel argument."""

    name: str
    expr: sympy.Expr


@dataclasses.dataclass
class DeviceCodegen:
    """Codegen classes registered for one device type."""

    scheduling: Any
    wrapper_codegen: type
    # NoneType is used as the placeholder when no C++ wrapper class is given.
    cpp_wrapper_codegen: type = type(None)


KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

# Device type string -> DeviceCodegen; filled in by register_backend_for_device.
device_codegens: Dict[str, DeviceCodegen] = {}
|
|
|
|
|
|
class DeviceOpOverrides:
    """Hooks that emit device-specific code snippets.

    Every hook here raises NotImplementedError. Device backends presumably
    subclass this and register an instance via register_device_op_overrides.
    """

    def import_get_raw_stream_as(self, name):
        raise NotImplementedError

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError

    def device_guard(self, device_idx):
        raise NotImplementedError


# Device type string -> DeviceOpOverrides instance registered for it.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
|
|
|
|
|
|
# Code generated by Inductor has two main parts: kernel code and wrapper code.
# A new backend that wants to integrate with Inductor customizes both parts to
# produce its own code.
#
# Kernel code generation is driven by a Scheduling class, so a new backend
# provides a custom Scheduling for its kernels. Currently CppScheduling and
# TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the wrapper, Inductor provides a WrapperCodeGen class that generates the
# Python wrapper code bridging kernels. Out-of-tree backends can inherit from
# WrapperCodeGen and override specific member functions to create
# backend-specific Python wrapper code.
#
# Other codegen classes, such as CppKernel and TritonKernel, are typically part
# of either the Scheduling or the WrapperCodeGen logic. The Scheduling and
# WrapperCodeGen interfaces therefore give a backend flexibility: it can
# implement these classes from scratch or extend and override existing ones.
# Inductor provides the registration API, register_backend_for_device, to
# install a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs,
# built on these interfaces. It can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: type,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Record the codegen classes for *device*.

    Registering the same device again overwrites the previous entry.
    """
    entry = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )
    device_codegens[device] = entry


def get_scheduling_for_device(device: str):
    """Return the Scheduling registered for *device*, or None if none is."""
    entry = device_codegens.get(device)
    return None if entry is None else entry.scheduling


def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the wrapper codegen class registered for *device*.

    When *cpp_wrapper* is true, the C++ wrapper class is returned instead of
    the Python one. Returns None when *device* has no registration.
    """
    if device not in device_codegens:
        return None
    entry: DeviceCodegen = device_codegens[device]
    if cpp_wrapper:
        return entry.cpp_wrapper_codegen
    return entry.wrapper_codegen
|
|
|
|
|
|
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return *index* extended with one extra contiguous linear index.

    The appended term is the dot product of *index_vars* with the contiguous
    strides of *sizes*. It is added so the loop order is not reordered.
    """
    from ..ir import FlexibleLayout

    contiguous_strides = FlexibleLayout.contiguous_strides(sizes)
    extra = sympy_dot(index_vars, contiguous_strides)
    return list(index) + [extra]
|
|
|
|
|
|
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Register *device_op_overrides* as the op overrides for *device*.

    An existing registration for the same device is replaced.
    """
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides registered for *device*, or None.

    On the first call, when the registry is still empty, the CUDA and XPU
    override modules are imported. They are imported only for their side
    effect, which is expected to populate device_op_overrides_dict.
    """
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    return device_op_overrides_dict.get(device)
|
|
|
|
|
|
@functools.lru_cache(None)
def boolean_ops():
    """Return the names of ops whose result dtype is always torch.bool.

    The result is cached, so every call returns the same tuple object.
    """
    predicates = ("is_inf", "is_nan", "bitwise_xor", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return predicates + comparisons
|
|
|
|
|
|
DTYPE_TO_COMPUTATION_DTYPE = {
|
|
torch.bfloat16: torch.float,
|
|
torch.float16: torch.float,
|
|
**{
|
|
dtype: dtype
|
|
for dtype in [
|
|
torch.bool,
|
|
torch.float32,
|
|
torch.float64,
|
|
torch.int8,
|
|
torch.int16,
|
|
torch.int32,
|
|
torch.int64,
|
|
torch.uint8,
|
|
torch.uint16,
|
|
torch.uint32,
|
|
torch.uint64,
|
|
]
|
|
},
|
|
}
|
|
|
|
|
|
class DataTypePropagation:
    """Infer a dtype for every node of a LoopBody's FX graphs.

    The result is stored on ``node.meta[OptimizationContext.key].dtype``.
    An existing OptimizationContext is reused; otherwise a new one is created.
    """

    def __init__(self, body) -> None:
        self.body = body
        # "root" is the body's main graph; sub-block graphs are keyed by the
        # keys of ``body.subblocks``, which masked_subblock nodes use as targets.
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of *node*'s non-placeholder inputs.

        Returns None when there are no such inputs, or when any of them has
        not been assigned a dtype yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the sub-graph named by ``node.target``.

        Returns that sub-graph's output dtype, which must be non-None.
        """
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype produced by *node*, or None if it cannot be deduced."""
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # The output node can only be inferred when it has exactly one
            # argument; in that case fall through to the input-based rules.
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # These ops carry their result dtype as the last argument.
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        # "index_expr" is already handled above, so it is not listed here.
        if node.target in (
            "get_index",
            "randint64",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            # The buffer name is the second argument; use that buffer's dtype.
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Assign a dtype to every node of *graph*, in node order.

        Returns the dtype recorded for the "output" node. This is how a
        masked_subblock's dtype is determined; it may be None for other graphs.
        """
        assert graph.nodes
        graph_dtype = None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Propagate dtypes through the root graph (returns None)."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Run dtype propagation over *body*'s root graph."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run dtype propagation over the LoopBody of a SchedulerNode."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        # Use cls (not a hard-coded class name) so subclass overrides apply.
        cls.propagate_loopbody(node._body)
|
|
|
|
|
|
class ExprPrinter(Printer):
    """Base sympy printer that renders expressions as source text.

    Subclasses override the ``_print_<Type>`` hooks that need target-specific
    syntax. ``_print_FloorDiv`` in particular must be provided by a subclass.
    """

    @staticmethod
    def paren(string):
        """Wrap *string* in parentheses unless that is clearly unnecessary.

        The following are returned unchanged:
        - CSE variables;
        - simple identifiers, numbers and dotted names;
        - strings of the form ``(...)`` with no inner ``)``;
        - the empty string;
        - strings already enclosed as a whole by one matching pair of
          parentheses.
        """

        def all_in_parens(string):
            # Check the length before indexing so that an empty string returns
            # False instead of raising IndexError.
            if len(string) < 2 or string[0] != "(":
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                # The opening "(" closed before the last character, as in
                # "(a) + (b)", so the string is not wrapped as a whole.
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # In sympy, GreaterThan means >= and StrictGreaterThan means >.
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"
|
|
|
|
|
|
class PythonPrinter(ExprPrinter):
    """Render sympy expressions as Python source using ``math`` and builtins."""

    def _format_unary(self, fn_name, expr):
        # Shared body for the single-argument hooks below. Deliberately not
        # named ``_print_*`` so that sympy's Printer never dispatches to it.
        assert len(expr.args) == 1
        return f"{fn_name}({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        numerator = self.paren(self.doprint(x))
        divisor = self.paren(self.doprint(div))
        modulus = self.paren(self.doprint(mod))
        if divisor != "1":
            numerator = f"({numerator} // {divisor})"
        return f"{numerator} % {modulus}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        lhs = self.paren(self.doprint(x))
        rhs = self.paren(self.doprint(div))
        return f"({lhs} // {rhs})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_Pow(self, expr):
        # Pow() confuses triton.
        # NB: this is sizevar computation, which is not normally expected to
        # involve floating-point exponents. Rather than supporting floating
        # point pow here, have upstream translate the sympy expression into
        # Tensor expressions earlier.
        base, exponent = expr.args
        if exponent == 0.5:
            return self._helper_sqrt(base)
        if exponent == -0.5:
            return "1/" + self._helper_sqrt(base)
        printed_base = self._print(base)
        assert exponent == int(exponent), exponent
        power = int(exponent)
        if power == 0:
            return "1"
        product = "*".join([self.paren(printed_base)] * abs(power))
        return product if power > 0 else "1/" + self.paren(product)

    def _print_floor(self, expr):
        return self._format_unary("math.floor", expr)

    def _print_Trunc(self, expr):
        return self._format_unary("math.trunc", expr)

    def _print_ceiling(self, expr):
        return self._format_unary("math.ceil", expr)

    def _print_Abs(self, expr):
        return self._format_unary("abs", expr)

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({operands})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({operands})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        return self._format_unary("math.cos", expr)

    def _print_OpaqueUnaryFn_cosh(self, expr):
        return self._format_unary("math.cosh", expr)

    def _print_OpaqueUnaryFn_acos(self, expr):
        return self._format_unary("math.acos", expr)

    def _print_OpaqueUnaryFn_sin(self, expr):
        return self._format_unary("math.sin", expr)

    def _print_OpaqueUnaryFn_sinh(self, expr):
        return self._format_unary("math.sinh", expr)

    def _print_OpaqueUnaryFn_asin(self, expr):
        return self._format_unary("math.asin", expr)

    def _print_OpaqueUnaryFn_tan(self, expr):
        return self._format_unary("math.tan", expr)

    def _print_OpaqueUnaryFn_tanh(self, expr):
        return self._format_unary("math.tanh", expr)

    def _print_OpaqueUnaryFn_atan(self, expr):
        return self._format_unary("math.atan", expr)

    def _print_Round(self, expr):
        return self._format_unary("round", expr)

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
|
|
|
|
|
|
class OpOverrides:
    """Default string-producing implementations of common ops.

    Any attribute not defined here is looked up on the wrapped *parent*.
    Extra pointwise implementations are installed per target by
    ``_initialize_pointwise_overrides``.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If "_parent" itself
        # is missing (e.g. copy/pickle created the instance with __new__ and
        # __init__ never ran), delegating would re-enter __getattr__("_parent")
        # and recurse until RecursionError. Raising AttributeError instead
        # keeps hasattr()/getattr(obj, name, default) working.
        if item == "_parent":
            raise AttributeError(item)
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Start from ops.mod(a, b). When that result is non-zero and its sign
        # bit differs from b's, return r + b; otherwise return r unchanged.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install each pointwise_overrides_data entry for *target* on *cls*.

        Entries with no implementation for *target* are skipped.
        """
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))
|
|
|
|
|
|
@dataclasses.dataclass
class OverridesData:
    """Per-target code templates for one pointwise op.

    Each template is a callable that takes the operand expressions and returns
    the generated expression string.
    """

    name: str
    cpp: Callable[..., str]
    # None when the op has no libdevice/Triton implementation.
    triton: Optional[Callable[..., str]] = None
    # None when the op has no aten/.../vec implementation.
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


def _int_to_float_override(name, cpp, triton=None, cppvec=None):
    """Build an OverridesData entry that uses INT_TO_FLOAT type promotion."""
    return OverridesData(
        name=name,
        cpp=cpp,
        triton=triton,
        cppvec=cppvec,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )


# Pointwise op name -> per-target templates. These are installed on an
# OpOverrides subclass by OpOverrides._initialize_pointwise_overrides.
pointwise_overrides_data: Dict[str, OverridesData] = {
    "airy_ai": _int_to_float_override(
        "special_airy_ai",
        cpp=lambda x: f"airy_ai_forward({x})",
    ),
    "bessel_j0": _int_to_float_override(
        "special_bessel_j0",
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
    ),
    "bessel_j1": _int_to_float_override(
        "special_bessel_j1",
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
    ),
    "bessel_y0": _int_to_float_override(
        "special_bessel_y0",
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
    ),
    "bessel_y1": _int_to_float_override(
        "special_bessel_y1",
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
    ),
    "digamma": _int_to_float_override(
        "digamma",
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    "erfcx": _int_to_float_override(
        "special_erfcx",
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
    ),
    "fma": _int_to_float_override(
        "fma",
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
    ),
    # erfinv, exp2, expit, gammaln
    "igamma": _int_to_float_override(
        "igamma",
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
    ),
    "igammac": _int_to_float_override(
        "igammac",
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
    ),
    "gammainc": _int_to_float_override(
        "special_gammainc",
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
    ),
    "gammaincc": _int_to_float_override(
        "special_gammaincc",
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
    ),
    "i0": _int_to_float_override(
        "i0",
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
    ),
    "i0e": _int_to_float_override(
        "special_i0e",
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
    ),
    "i1": _int_to_float_override(
        "special_i1",
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
    ),
    "i1e": _int_to_float_override(
        "special_i1e",
        cpp=lambda x: f"calc_i1e({x})",
    ),
    "log_ndtr": _int_to_float_override(
        "special_log_ndtr",
        cpp=lambda x: f"calc_log_ndtr({x})",
    ),
    # logit
    "modified_bessel_i0": _int_to_float_override(
        "special_modified_bessel_i0",
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
    ),
    "modified_bessel_i1": _int_to_float_override(
        "special_modified_bessel_i1",
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
    ),
    "modified_bessel_k0": _int_to_float_override(
        "special_modified_bessel_k0",
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
    ),
    "modified_bessel_k1": _int_to_float_override(
        "special_modified_bessel_k1",
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
    ),
    # multigamma
    "ndtr": _int_to_float_override(
        "special_ndtr",
        cpp=lambda x: f"calc_ndtr({x})",
    ),
    "ndtri": _int_to_float_override(
        "special_ndtri",
        cpp=lambda x: f"calc_ndtri({x})",
    ),
    "polygamma": _int_to_float_override(
        "polygamma",
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
    ),
    # psi - alias to digamma
    # round
    "scaled_modified_bessel_k0": _int_to_float_override(
        "special_scaled_modified_bessel_k0",
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
    ),
    "scaled_modified_bessel_k1": _int_to_float_override(
        "special_scaled_modified_bessel_k1",
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
    ),
    # sinc
    "spherical_bessel_j0": _int_to_float_override(
        "special_spherical_bessel_j0",
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
    ),
    "zeta": _int_to_float_override(
        "special_zeta",
        cpp=lambda x, y: f"zeta({x}, {y})",
    ),
    "chebyshev_polynomial_t": _int_to_float_override(
        "special_chebyshev_polynomial_t",
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
    ),
    "chebyshev_polynomial_u": _int_to_float_override(
        "special_chebyshev_polynomial_u",
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
    ),
    "chebyshev_polynomial_v": _int_to_float_override(
        "special_chebyshev_polynomial_v",
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
    ),
    "chebyshev_polynomial_w": _int_to_float_override(
        "special_chebyshev_polynomial_w",
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
    ),
    "legendre_polynomial_p": _int_to_float_override(
        "special_legendre_polynomial_p",
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_t": _int_to_float_override(
        "special_shifted_chebyshev_polynomial_t",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_u": _int_to_float_override(
        "special_shifted_chebyshev_polynomial_u",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_v": _int_to_float_override(
        "special_shifted_chebyshev_polynomial_v",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_w": _int_to_float_override(
        "special_shifted_chebyshev_polynomial_w",
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
    ),
    "hermite_polynomial_h": _int_to_float_override(
        "special_hermite_polynomial_h",
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
    ),
    "hermite_polynomial_he": _int_to_float_override(
        "special_hermite_polynomial_he",
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
    ),
    "laguerre_polynomial_l": _int_to_float_override(
        "special_laguerre_polynomial_l",
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
    ),
}
|
|
|
|
|
|
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    """Return *h* unchanged.

    The signature lets mypy confirm that OpOverrides satisfies the
    OpsHandler[str] protocol.
    """
    return h
|
|
|
|
|
|
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        # The line is dropped if its buffer name was removed or marked for
        # in-place removal, either graph-wide or in the current kernel.
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in removed for removed in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        # Derived lines keep the same buffer name.
        return DeferredLine(self.name, line)
|
|
|
|
|
|
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer that also writes C-style braces when the indent changes."""

    def indent(self, offset=1):
        """Return a context manager that changes the indent by *offset* levels.

        A positive *offset* writes ``{`` and indents once per level on entry,
        then dedents and writes ``}`` once per level on exit. A negative
        *offset* does the mirror image: ``}`` on entry, ``{`` on exit.

        The exit half runs in ``finally``, so braces and the indent level are
        restored even if the body raises.
        """

        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            try:
                yield
            finally:
                for _ in range(-offset):
                    self.writeline("{")
                    self._indent += 1
                for _ in range(offset):
                    self._indent -= 1
                    self.writeline("}")

        return ctx()
|
|
|
|
|
|
class InplacedBuffer(NamedTuple):
    """One in/out kernel argument shared by several buffer names."""

    # Kernel-local argument name, e.g. "in_out_ptr0".
    inner_name: str
    # Every buffer name aliased to this argument; make_inplace appends to it.
    other_names: List[str]
|
|
|
|
|
|
class KernelArgs:
    """Bookkeeping for the arguments a generated kernel takes.

    Each table maps an outer (graph-level) name to the parameter name used
    inside the kernel source:

    * ``input_buffers``   -> ``in_ptrN`` (``seedN`` for names starting with "seed")
    * ``output_buffers``  -> ``out_ptrN``
    * ``inplace_buffers`` -> an ``InplacedBuffer`` (``in_out_ptrN``) shared by
      every aliased outer name
    * ``sizevars``        -> ``ksN`` (plus the special ``"seed"`` entry)

    Any entry whose value is a string beginning with ``"REMOVED"`` is treated
    as removed and skipped when argument lists are built.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        """Return the inner name for ``name``; allocate ``{prefix}{len(odict)}`` on first use."""
        assert isinstance(name, (str, sympy.Symbol))
        # The candidate is formatted from the size *before* insertion, which is
        # exactly what a check-then-insert would produce.
        return odict.setdefault(name, f"{prefix}{len(odict)}")

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        # A falsy (e.g. empty) mapping is replaced by a fresh dict.
        self.sizevars = sizevars or {}
        self.workspace_arg = None

    def __repr__(self):
        tables = (
            self.input_buffers,
            self.output_buffers,
            self.inplace_buffers,
            self.sizevars,
        )
        return f"KernelArgs({', '.join(repr(table) for table in tables)})"

    def _buffer_is_marked_removed(self, name):
        """True when ``name`` is a string starting with "REMOVED"."""
        if not isinstance(name, str):
            return False
        return name.startswith("REMOVED")

    def input(self, name):
        """Return the kernel-side name for reading buffer ``name``.

        A buffer the kernel already writes (as output or in place) is read
        through that same parameter; otherwise a new input slot is allocated.
        """
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        prefix = "seed" if name.startswith("seed") else "in_ptr"
        return self._lookup(prefix, self.input_buffers, name)

    def output(self, name):
        """Return the kernel-side name for writing buffer ``name``."""
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Let ``output_name`` reuse the kernel parameter of ``input_name``."""
        assert output_name not in self.inplace_buffers
        if input_name not in self.inplace_buffers:
            # First alias: allocate the next in_out_ptr slot, numbered by the
            # count of distinct entries already present.
            fresh = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = fresh
            self.inplace_buffers[output_name] = fresh
            return
        shared = self.inplace_buffers[input_name]
        shared.other_names.append(output_name)
        self.inplace_buffers[output_name] = shared

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` of the shared workspace.

        Returns ``("ws_ptr", offset)``, where ``offset`` is where the
        reservation starts. Zero-filling is requested if any reservation
        asked for it.
        """
        current = self.workspace_arg
        if current is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0
        start = current.nbytes
        self.workspace_arg = WorkspaceArg(start + nbytes, zero_fill or current.zero_fill)
        return "ws_ptr", start

    def seed_offset(self, name, value):
        """Register ``value`` as a size argument named ``name``.

        If ``name`` is already taken, a numeric suffix is appended. The suffix
        is the number of existing names starting with ``name``.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        taken = self.sizevars.values()
        if name in taken:
            suffix = sum(1 for existing in taken if existing.startswith(name))
            name = f"{name}{suffix}"
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the kernel-side name for size symbol ``name`` ("seed" stays "seed")."""
        if str(name) != "seed":
            return self._lookup("ks", self.sizevars, name)
        self.sizevars["seed"] = "seed"
        return "seed"

    def call_names(self):
        """Iterate outer names of inputs, outputs and size variables, in that order."""
        return chain(self.input_buffers, self.output_buffers, self.sizevars)

    def wrap_ptr_arg(self, buf, dtype):
        # Subclasses may decorate pointer call arguments; the default is identity.
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Build C++ parameter declarations, call arguments and parameter types.

        Order: in-place buffers, read-only inputs, outputs, size variables.
        """
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        arg_defs, call_args, arg_types = [], [], []

        def add(decl, call, ctype):
            arg_defs.append(decl)
            call_args.append(call)
            arg_types.append(ctype)

        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            dtype = V.graph.get_dtype(outer)
            ctype = DTYPE_TO_CPP[dtype]
            add(
                f"{ctype}* {inplaced.inner_name}",
                self.wrap_ptr_arg(outer, dtype),
                f"{ctype}*",
            )

        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            ctype = DTYPE_TO_CPP[dtype]
            add(
                f"const {ctype}* {inner}",
                self.wrap_ptr_arg(outer, dtype),
                f"const {ctype}*",
            )

        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            ctype = DTYPE_TO_CPP[dtype]
            add(f"{ctype}* {inner}", self.wrap_ptr_arg(outer, dtype), f"{ctype}*")

        for outer, inner in self.sizevars.items():
            add(
                f"const {INDEX_TYPE} {inner}",
                self.wrap_size_arg(outer),
                f"const {INDEX_TYPE}",
            )
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)

        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build Python/Triton parameter names, call arguments and precompile arg descriptors.

        Order: in-place buffers, inputs and outputs, size variables, then the
        workspace (if any).
        """
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []

        def add_tensor(inner, outer):
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(
                TensorArg(name=inner, buffer=outer, dtype=V.graph.get_dtype(outer))
            )

        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            add_tensor(inplaced.inner_name, inplaced.other_names[-1])

        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            add_tensor(inner, outer)

        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)

        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)

        return arg_defs, call_args, precompile_args

    def aliases(self):
        """Yield ``(outer_param, inplace_param)`` pairs for live in-place aliases."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            inner = inplaced.inner_name
            for other in inplaced.other_names:
                dropped = (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                )
                if dropped:
                    continue
                for table in (self.input_buffers, self.output_buffers):
                    if other in table:
                        yield table[other], inner

    def is_removed(self, name):
        """True unless ``name`` is a live output or in-place buffer."""
        for table in (self.output_buffers, self.inplace_buffers):
            if name in table and not self._buffer_is_marked_removed(table[name]):
                return False
        return True

    # Includes inplace buffers, excludes removed buffers. Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data? Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = {
            inplaced.other_names[-1]
            for inplaced in unique(self.inplace_buffers.values())
            if not self._buffer_is_marked_removed(inplaced)
        }
        live_outs.update(
            outer
            for outer, inner in self.output_buffers.items()
            if outer not in self.inplace_buffers
            and not self._buffer_is_marked_removed(inner)
        )
        return live_outs
|
|
|
|
|
|
class CSEVariable:
|
|
"""A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
|
|
To do so, the backends can simply overload `Kernel.create_cse_var`
|
|
The "CSEVariable.update_on_args" method gives you a hook for annotations
|
|
See example of TritonCSEVariable in triton.py
|
|
"""
|
|
|
|
def __init__(self, name, bounds: ValueRanges[Any]):
|
|
assert isinstance(bounds, ValueRanges)
|
|
self.name = name
|
|
self.bounds = bounds
|
|
|
|
def __str__(self):
|
|
return self.name
|
|
|
|
def __hash__(self) -> int:
|
|
return hash(self.name)
|
|
|
|
def __eq__(self, other) -> bool:
|
|
return type(other) == type(self) and other.name == self.name
|
|
|
|
def update_on_args(self, name, args, kwargs):
|
|
pass
|
|
|
|
|
|
class CppWrapperKernelArgs(KernelArgs):
    """``KernelArgs`` whose call arguments are written for the C++ wrapper."""

    def wrap_ptr_arg(self, buf, dtype):
        """Return a typed raw-pointer expression for ``buf``, or ``buf`` itself in ABI-compatible mode."""
        from .cpp_utils import DTYPE_TO_CPP

        if not config.abi_compatible:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible mode the buffer name is passed through
        # unchanged; the real call args are formed later in
        # wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        return f"{size}"
|
|
|
|
|
|
class CSE:
|
|
"""Common subexpression elimination"""
|
|
|
|
def __init__(
|
|
self,
|
|
prefix="",
|
|
suffix="",
|
|
name_prefix="tmp",
|
|
iter_buffers=None,
|
|
store_cache=None,
|
|
reduction_cache=None,
|
|
varname_map=None,
|
|
):
|
|
self.prefix = prefix
|
|
self.suffix = suffix
|
|
self.cache = {}
|
|
self.name_prefix = name_prefix
|
|
self.store_cache = store_cache or {}
|
|
self.reduction_cache = reduction_cache or {}
|
|
self.iter_buffer_ids = iter_buffers or itertools.count()
|
|
self.invalidated_stores = set()
|
|
self.varname_map = varname_map or {}
|
|
|
|
def invalidate(self, keep_vars: Set[str]):
|
|
for name, tmp in list(self.store_cache.items()):
|
|
if tmp not in keep_vars:
|
|
del self.store_cache[name]
|
|
self.invalidated_stores.add(name)
|
|
self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
|
|
|
|
def clone(self):
|
|
# Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
|
|
return CSE(
|
|
prefix=self.prefix,
|
|
suffix=self.suffix,
|
|
name_prefix=self.name_prefix,
|
|
iter_buffers=self.iter_buffer_ids,
|
|
store_cache=self.store_cache,
|
|
varname_map=self.varname_map,
|
|
)
|
|
|
|
def generate(
|
|
self,
|
|
buffer: IndentedBuffer,
|
|
expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
|
|
*,
|
|
bounds: ValueRanges[Any] = ValueRanges.unknown(),
|
|
write=True,
|
|
assignment=True,
|
|
) -> CSEVariable:
|
|
if isinstance(expr, OpsValue):
|
|
expr = expr.value
|
|
|
|
assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
|
|
assert write or assignment
|
|
if isinstance(expr, CSEVariable):
|
|
# If the expressions were always created with all the information, we could
|
|
# assert expr.bounds == bounds, but sometimes the expression is created
|
|
# with the loose ValueRanges.unknown(), so we need to tighten the bounds
|
|
expr.bounds = expr.bounds.tighten(bounds)
|
|
return expr
|
|
cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
|
|
var = self.cache.get(cache_key, None)
|
|
if not var:
|
|
var = self.newvar(bounds) if assignment else None
|
|
self.cache[cache_key] = var
|
|
if write:
|
|
if V.kernel.current_node:
|
|
V.kernel.current_node.codegen_originating_info(
|
|
buffer, only_once=True
|
|
)
|
|
if isinstance(expr, IndentedBuffer):
|
|
if assignment:
|
|
buffer.writeline(f"{self.prefix}{var} =")
|
|
buffer.splice(expr)
|
|
buffer.writeline(self.suffix)
|
|
else:
|
|
if assignment:
|
|
line = f"{self.prefix}{var} = {expr}{self.suffix}"
|
|
else:
|
|
line = f"{expr}{self.suffix}"
|
|
buffer.writeline(line)
|
|
else:
|
|
var.bounds = var.bounds.tighten(bounds)
|
|
|
|
return var
|
|
|
|
def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
|
|
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
|
|
var = V.kernel.create_cse_var(var_name, bounds)
|
|
self.varname_map[var_name] = var
|
|
return var
|
|
|
|
|
|
class IndirectAssertLine(DeferredLineBase):
    """Deferred line that renders a bounds assertion for an indirect index.

    The size is read from ``size_map[(var, mask)]`` only when the line is
    rendered, so updates to that shared map made after this line was
    written are still taken into account.
    """

    def __init__(self, line, indirect_assert, var, mask, size_map):
        super().__init__(line)
        self.indirect_assert = indirect_assert
        self.var = var
        self.mask = mask
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]
        var_bounds = self.var.bounds

        # A side is checked at runtime unless its bound can be proven statically.
        check_lower = (var_bounds.lower >= 0) != sympy.true
        check_upper = (var_bounds.upper < size) != sympy.true

        if not check_lower and not check_upper:
            # Nothing to assert: the line renders to nothing.
            return None

        assert_text = self.indirect_assert(
            self.var,
            "0" if check_lower else None,
            size_str if check_upper else None,
            self.mask,
        )
        return self.line.format(assert_line=assert_text)

    def _new_line(self, line):
        return IndirectAssertLine(
            line,
            indirect_assert=self.indirect_assert,
            var=self.var,
            mask=self.mask,
            size_map=self.size_map,
        )
|
|
|
|
|
|
class CodeGen:
    """Base class for code generators that are also context managers.

    Contexts pushed onto ``self.exit_stack`` while the generator is active
    are unwound when its ``with`` block exits.
    """

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        stack = self.exit_stack
        stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The ExitStack's return value is discarded (this returns None), so an
        # exception raised inside the block always propagates.
        stack = self.exit_stack
        stack.__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
|
class ScopedDict:
    """A write-local, layered view over another mapping.

    Reads fall through to ``original_dict`` unless the key was written in
    this scope. Writes go only to ``new_items``, so ``original_dict`` is
    never modified. Scopes can be nested: ``original_dict`` may itself be a
    ScopedDict.

    Iteration (``__iter__``, ``keys``, ``values``, ``items``) is supported so
    the view can stand in where dict-style traversal is needed. No
    ``__len__``/``__bool__`` is defined, so an instance stays truthy even
    when empty; some callers may rely on that.
    """

    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

    def __iter__(self):
        """Yield every visible key once.

        Keys of ``original_dict`` come first in their order, followed by keys
        that exist only in this scope, in insertion order.
        """
        yield from self.original_dict
        for key in self.new_items:
            if key not in self.original_dict:
                yield key

    def keys(self):
        """Return a snapshot list of the visible keys."""
        return list(self)

    def values(self):
        """Return a snapshot list of the visible values (scope overrides win)."""
        return [self[key] for key in self]

    def items(self):
        """Return a snapshot list of ``(key, value)`` pairs (scope overrides win)."""
        return [(key, self[key]) for key in self]
|
|
|
|
|
|
class Kernel(CodeGen):
    """Base class for per-backend kernel code generators.

    Generated code accumulates in three buffers: ``loads``, ``compute`` and
    ``stores``. Expressions are de-duplicated through ``self.cse``, and
    kernel arguments are tracked in ``self.args``. Using the kernel as a
    context manager installs a ``CSEProxy`` ops handler. That handler routes
    every ``ops.*`` call through CSE and value-range bookkeeping before
    delegating to the handler built by ``self.overrides``.
    """

    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        """
        Args:
            args: the ``KernelArgs`` to use; a fresh ``KernelArgs()`` if falsy.
            increase_kernel_count: whether to bump
                ``metrics.generated_kernel_count``.
        """
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        # Upper bounds for indirect_indexing and their str representation
        # NB: None, None is never stored in map, but it is the assumed
        # "not set" value for the dict
        self.indirect_max_sizes: Dict[
            Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
        ] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #   the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        """Make ``node`` the node being codegen'd for the duration of the context.

        Also loads the value-range bounds computed for ``node._body``. On
        exit only ``current_node`` is restored; ``node_to_bounds`` keeps the
        bounds of ``node``.
        """
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """Temporarily redirect ``loads``/``compute``/``stores`` to ``lb``/``cb``/``sb``.

        ``cb`` defaults to ``lb``; ``sb`` is used as given. Inside the
        context, ``self.cse`` is a clone whose ``cache``, ``reduction_cache``
        and ``store_cache`` are ScopedDicts layered over the outer ones, so
        entries added inside do not leak out. The variable counter is still
        shared. Everything is restored on exit.
        """

        def scope_cse(cse):
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = scope_cse(cse)
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """Emit a load of buffer ``name`` at ``index``; implemented by backends."""
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        """Store a reduction result; implemented by backends."""
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Store ``value`` to buffer ``name`` at ``index``; implemented by backends."""
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        """Emit a reduction; implemented by backends."""
        raise NotImplementedError

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        """Emit a scan; implemented by backends."""
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        """Name of the backend's device-side assert function."""
        raise NotImplementedError

    def indirect_assert(self, var, lower, upper, mask=None):
        """Return the assert statement checking ``lower <= var < upper``.

        Either bound may be omitted (falsy), but not both. When ``mask`` is
        given, the condition is OR'ed with ``~mask``, so masked-out lanes pass.
        """
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            cond = f"({cond}) | ~{mask}"

        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    def index_to_str(self, index: sympy.Expr) -> str:
        """Render an index expression as backend source text; implemented by backends."""
        raise NotImplementedError

    def __enter__(self):
        """Install the CSE ops handler and register this kernel as ``V.kernel``.

        Every ``ops.<name>(...)`` call made while the kernel is active goes
        through ``CSEProxy``. Bounds are computed, the call is forwarded to
        the handler built by ``self.overrides``, and the resulting
        expression(s) are CSE'd into ``self.compute``.
        """

        # TODO: hoist this to top level
        class CSEProxy:
            # Plain class attribute. Writing ``self.name = ...`` here instead
            # would resolve ``self`` to the enclosing Kernel (class bodies can
            # read enclosing function locals) and overwrite the kernel's own
            # ``name`` attribute on every ``__enter__``.
            name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                # Generic path for every op without a dedicated method below.
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    # Ops may return a pytree of values (e.g. tuples); CSE each leaf.
                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed.
                Otherwise, if the variable is created while codegen'ing another op, we try to
                compute its bounds.
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds at the ops
                    # We will also likely not get much from computing VRs on these nodes
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: sympy.Expr, check: bool = True
            ):
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    # The index may be negative: wrap it by adding ``size``.
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, 0)
                        stm = ops.where(lt, stm, var)
                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        self.compute.writeline(
                            IndirectAssertLine(
                                "{assert_line}",
                                self.indirect_assert,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return parent_handler.indirect_indexing(var, size, check)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                # Reuse the value most recently stored to this buffer, if any.
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    # Only plain stores are forwarded to later loads.
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        # Read by CSEProxy's methods through the closure; it is assigned
        # before any op can be dispatched through the proxy.
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        """Whether an indirect-indexing bounds assert should be emitted."""
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var) -> str:
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        # Sorted so kernel args are allocated in a deterministic order.
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        """Factory for CSE variables; backends override to return a subclass."""
        return CSEVariable(*args, **kwargs)
|
|
|
|
|
|
@dataclasses.dataclass
class OptimizationContext:
    """Small record of a dtype and an ops name, identified by ``key``."""

    # Class-level constant (ClassVar), so it is not a dataclass field.
    key: ClassVar[str] = "opt_ctx"

    dtype: Optional[torch.dtype] = dataclasses.field(default=None)
    ops_name: str = dataclasses.field(default="")
|
|
|
|
|
|
@functools.lru_cache(None)
def jinja2_env():
    """Return a shared ``jinja2.Environment`` using ``StrictUndefined``.

    Returns None when importing jinja2 (or building the environment) raises
    ImportError. The result is cached, so every call returns the same object.
    """
    try:
        import jinja2

        env = jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        env = None
    return env
|
|
|
|
|
|
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        """Indent every line of ``source`` except the first.

        Each following line gets ``indents_spacing * num_indents`` spaces.
        Line endings are preserved.
        """
        first, *rest = source.splitlines(True) or [""]
        pad = " " * (indents_spacing * num_indents)
        return first + "".join(pad + line for line in rest)

    @staticmethod
    def _template_from_string(source):
        """Compile ``source`` as a jinja2 template, or return None without jinja2.

        Registers ``indent_except_first`` as a filter on the shared environment.
        """
        env = jinja2_env()
        if env is None:
            return None
        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
        return env.from_string(source)

    @staticmethod
    def _fake_get_dtype(fake_out):
        """Wrap ``V.graph.get_dtype`` so that ``fake_out``'s name reports ``fake_out``'s dtype."""
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            if name != fake_out.get_name():
                return real_get_dtype(name)
            return fake_out.get_dtype()

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """
        try:
            new_choice = self.generate(**kwargs)
            choices.append(new_choice)
        except NotImplementedError:
            # A template that cannot handle these inputs opts out by raising.
            pass

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.
        """
        raise NotImplementedError
|