# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import functools
import itertools
import logging
import re
from collections import defaultdict
from math import inf
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union

import sympy

import torch
import torch._logging

from ..._prims_common import is_integer_dtype
from ...utils._ordered_set import OrderedSet
from ...utils._sympy.functions import FloorDiv, ModularIndexing
from ...utils._sympy.symbol import symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import HalideCodeCache
from ..ir import get_reduction_combine_fn
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..ops_handler import AddParenHandler
from ..runtime.hints import HalideInputSpec, HalideMeta
from ..utils import (
    get_bounds_index_expr,
    get_kernel_metadata,
    parallel_num_threads,
    sympy_index_symbol,
    sympy_subs,
)
from ..virtualized import _ops as ops, V
from .common import (
    BackendFeature,
    CSEVariable,
    DeferredLine,
    IndentedBuffer,
    KernelArgType,
    OpOverrides,
    PythonPrinter,
    SizeArg,
    TensorArg,
)
from .cpp import DTYPE_TO_CPP
from .cpp_utils import cexpr
from .simd import constant_repr, SIMDKernel, SIMDScheduling


if TYPE_CHECKING:
    from collections.abc import Sequence

    from ..ops_handler import ReductionType, StoreMode
    from ..shape_propagation import BlockShapeType


log = logging.getLogger(__name__)
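# Render a Python scalar constant as Halide front-end source.  Ints outside
# the int32 range are emitted as explicit 64-bit literals (or hl.Int(64)
# min()/max() for the int64 endpoints) so Halide does not truncate them.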
def halide_constant(val):
    if isinstance(val, int) and not (-2147483648 <= val <= 2147483647):
        info = torch.iinfo(torch.int64)
        if val == info.min:
            return "hl.Int(64).min()"
        if val == info.max:
            return "hl.Int(64).max()"
        return f"hl.i64({val!r})"
    if isinstance(val, float):
        return f"hl.f64({constant_repr(val)})"
    return repr(val)


class Unsupported(RuntimeError):
    def __init__(self, thing) -> None:
        super().__init__(f"halide backend does not support: {thing}")
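# HalidePrinter renders sympy index expressions as strings of Halide
# front-end Python code (e.g. sympy.Min -> hl.min), casting results back to
# the kernel's index dtype where Halide would otherwise change the type.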
class HalidePrinter(PythonPrinter):
    @staticmethod
    def cast_index(expr):
        return f"hl.cast({V.kernel.index_dtype}, {expr})"

    @staticmethod
    def cast_float(expr):
        return f"hl.cast(hl.Float(32), {expr})"

    def _print_Float(self, expr):
        return f"hl.f32({expr})"

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"hl.f32({self._print(expr.args[0])})"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.floor({self._print(expr.args[0])})")

    _print_FloorToInt = _print_floor

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")

    _print_TruncToInt = _print_Trunc

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.ceil({self._print(expr.args[0])})")

    def _helper_sqrt(self, expr):
        return f"hl.sqrt({self.cast_float(self._print(expr))})"

    def _print_Where(self, expr):
        c = self.doprint(expr.args[0])
        p = self.doprint(expr.args[1])
        q = self.doprint(expr.args[2])
        return f"hl.select({c}, {p}, {q})"
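    # hl.min/hl.max are binary, so n-ary sympy Min/Max nodes are lowered
    # recursively as a balanced tree of pairwise calls.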
    def _print_Min(self, expr):
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Min(*expr.args[:mid]))
        b = self._print(sympy.Min(*expr.args[mid:]))
        return f"hl.min({a}, {b})"

    def _print_Max(self, expr):
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Max(*expr.args[:mid]))
        b = self._print(sympy.Max(*expr.args[mid:]))

        return f"hl.max({a}, {b})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.abs({self._print(expr.args[0])})")

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"hl.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"hl.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"hl.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"hl.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"hl.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"hl.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"hl.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"hl.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"hl.atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_log2(self, expr):
        raise NotImplementedError("log2")

    def _print_FloorDiv(self, expr):
        if expr.is_integer:
            return super()._print_FloorDiv(expr)

        x, div = expr.args
        x = self.cast_float(self.doprint(x))
        div = self.cast_float(self.doprint(div))
        return self.cast_index(f"hl.floor({x} / {div})")

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.round({self._print(expr.args[0])})")

    _print_RoundToInt = _print_Round

    def _print_IntTrueDiv(self, expr):
        a, b = expr.args
        # force a cast to float
        return f"({a}) / ({b}+hl.f32(0))"

    def _print_RoundDecimal(self, expr):
        val, n = expr.args
        val = self._print(val)
        n = int(n)
        return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"
texpr = HalidePrinter().doprint
pexpr = PythonPrinter().doprint


_halide_type = {
    torch.bool: "hl.Bool()",
    torch.bfloat16: "hl.BFloat(16)",
    torch.float16: "hl.Float(16)",
    torch.float32: "hl.Float(32)",
    torch.float64: "hl.Float(64)",
    torch.int8: "hl.Int(8)",
    torch.int16: "hl.Int(16)",
    torch.int32: "hl.Int(32)",
    torch.int64: "hl.Int(64)",
    torch.uint8: "hl.UInt(8)",
    torch.uint16: "hl.UInt(16)",
    torch.uint32: "hl.UInt(32)",
    torch.uint64: "hl.UInt(64)",
}


def halide_type(dtype):
    return _halide_type[dtype]
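# Pick the accumulator type for a reduction: small signed ints widen to
# int32 and fp16/bf16 widen to fp32; other dtypes accumulate as themselves.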
def halide_acc_type(dtype):
    if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64:
        dtype = torch.int32
    if dtype in (torch.float16, torch.bfloat16):
        dtype = torch.float32
    return halide_type(dtype)
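# HalideOverrides lowers inductor's pointwise ops to Halide expression
# strings; the resulting values are CSE'd into hl.Func stages by HalideKernel.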
class HalideOverrides(OpOverrides):
    @staticmethod
    def to_dtype(
        x,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ):
        if dtype == torch.bool:
            return f"({x} != 0)"
        return f"hl.cast({halide_type(dtype)}, {x})"

    @staticmethod
    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
        if src_dtype in (torch.float16, torch.bfloat16):
            x = f"hl.cast({halide_type(src_dtype)}, {x})"  # body compute is upcast to fp32
        line = f"hl.reinterpret({halide_type(dtype)}, {x})"
        if dtype in (torch.float16, torch.bfloat16):
            line = f"hl.cast(hl.Float(32), {line})"
        return line

    @classmethod
    def constant(cls, value, dtype):
        return cls.to_dtype(halide_constant(value), dtype)

    @staticmethod
    def abs(x):
        return f"hl.abs({x})"

    @staticmethod
    def exp(x):
        if not hasattr(x, "name"):
            return f"hl.exp({x})"
        return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})"

    @staticmethod
    def sqrt(x):
        return f"hl.sqrt({x})"

    @staticmethod
    def minimum(a, b):
        # return f"hl.min({a}, {b})" <== handles nan wrong
        if not hasattr(a, "name"):
            return f"hl.min({a}, {b})"
        b = f"hl.cast({a.name}.type(), {b})"
        return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})"

    @staticmethod
    def maximum(a, b):
        # return f"hl.max({a}, {b})" <== handles nan wrong
        if not hasattr(a, "name"):
            return f"hl.max({a}, {b})"
        b = f"hl.cast({a.name}.type(), {b})"
        return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})"

    @staticmethod
    def where(a, b, c):
        if hasattr(b, "name"):
            c = f"hl.cast({b.name}.type(), {c})"
        return f"hl.select({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"hl.cos({x})"

    @staticmethod
    def sin(x):
        return f"hl.sin({x})"

    @staticmethod
    def lgamma(x):
        raise Unsupported("lgamma")

    @staticmethod
    def erf(x):
        return f"hl.erf({x})"

    @staticmethod
    def cosh(x):
        return f"hl.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"hl.sinh({x})"

    @staticmethod
    def acos(x):
        return f"hl.acos({x})"

    @staticmethod
    def acosh(x):
        return f"hl.acosh({x})"

    @staticmethod
    def asin(x):
        return f"hl.asin({x})"

    @staticmethod
    def asinh(x):
        return f"hl.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"hl.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"hl.atan({x})"

    @staticmethod
    def atanh(x):
        return f"hl.atanh({x})"

    @staticmethod
    def copysign(x, y):
        raise Unsupported("copysign")

    @staticmethod
    def erfinv(x):
        raise Unsupported("erfinv")

    @staticmethod
    def hypot(x, y):
        return f"hl.hypot({x}, {y})"

    @staticmethod
    def nextafter(x, y):
        raise Unsupported("nextafter")

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_not(a):
        return f"{a} == 0"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"({a} ^ {b})"

    @staticmethod
    def bitwise_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def bitwise_not(a):
        return f"~{a}"

    @staticmethod
    def bitwise_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def bitwise_xor(a, b):
        return f"{a} ^ {b}"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"{a} << {b}"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"{a} >> {b}"

    @staticmethod
    def rand(seed, offset):
        return f"halide_helpers.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset):
        return f"halide_helpers.randn({seed}, {offset})"

    @staticmethod
    def randint64(seed, offset, low, high):
        return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})"

    @staticmethod
    def load_seed(name, offset):
        return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}"

    @staticmethod
    def rsqrt(x):
        # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues
        return f"1./hl.sqrt({x})"

    @staticmethod
    def tan(x):
        return f"hl.tan({x})"

    @staticmethod
    def tanh(x):
        return f"hl.tanh({x})"

    @staticmethod
    def signbit(x):
        return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0"

    @staticmethod
    def fmod(a, b):
        # TODO(jansel): find a better way to do this, builtin % has wrong sign
        return f"{a} - hl.trunc({a}/{b})*{b}"

    @staticmethod
    def pow(a, b):
        return f"hl.pow({a}, {b})"  # hl.fast_pow fails accuracy

    @staticmethod
    def log(x):
        return f"hl.log({x})"  # hl.fast_log fails accuracy

    @staticmethod
    def log2(x):
        raise NotImplementedError("log2")

    @staticmethod
    def isinf(x):
        # workaround https://github.com/halide/Halide/issues/8309
        return f"hl.is_inf(hl.cast(hl.Float(32), {x}))"

    @staticmethod
    def isnan(x):
        # workaround https://github.com/halide/Halide/issues/8309
        return f"hl.is_nan(hl.cast(hl.Float(32), {x}))"

    @staticmethod
    def round(x):
        return f"hl.round({x})"

    @staticmethod
    def floor(x):
        return f"hl.floor({x})"

    @staticmethod
    def int_truediv(a, b):
        return f"({a}) / ({b} + hl.f32(0))"

    @staticmethod
    def floordiv(a, b):
        # TODO(jansel): find a better way to do this, the select-based trick from triton.py didn't work
        return (
            f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )
    @classmethod
    def sign(cls, x):
        left = ops.to_dtype(ops.lt("0", x), torch.int8)
        right = ops.to_dtype(ops.lt(x, "0"), torch.int8)
        sub = ops.sub(left, right)
        return f"hl.cast({x.name}.type(), {sub})"

    @staticmethod
    def trunc(x):
        return f"hl.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # this causes crashes with floating point exception, see test_div_zero_dim_cpu
        # return f"hl.div_round_to_zero({a}, {b})"
        return (
            f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

    @staticmethod
    def ceil(x):
        return f"hl.ceil({x})"

    @staticmethod
    def relu(x):
        return f"hl.max({x}, 0)"

    @classmethod
    def index_expr(cls, expr, dtype):
        index = V.kernel.prepare_indexing(expr)
        var = V.kernel.genfunc(
            V.kernel.index_to_str(index),
            V.kernel.used_dims_from_index(index),
            bounds=get_bounds_index_expr(expr),
        )
        if dtype not in (torch.int32, torch.int64):
            return ops.to_dtype(var, dtype)
        return var
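    # Indirect (tensor-valued) indices are narrowed to int32 and clamped
    # into [0, size - 1] before being used as a Halide index symbol.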
    @classmethod
    def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True):
        # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow
        index_var = ops.to_dtype(index_var, torch.int32)
        index_var = ops.halide_clamp(index_var, size, check)
        index_var.indirect_indexing_size = size
        return sympy_index_symbol(str(index_var))

    @classmethod
    def halide_clamp(cls, value, size, check):
        end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1)
        if not isinstance(size, (int, sympy.Integer)):
            end = f"hl.cast({value.name}.type(), {end})"
        # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692
        # return f"hl.unsafe_promise_clamped({value}, 0, {end})"
        return f"hl.clamp({value}, 0, {end})"
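    # masked() evaluates `body` with loads masked, then selects `other`
    # wherever the mask is false.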
    @staticmethod
    def masked(mask, body, other):
        with V.kernel.mask_loads(mask, other) as new_mask:
            result = body()

        if result.bounds.is_bool:
            other = bool(other)

        # Take dtype from result to prevent accidental promotion
        other = V.kernel.genfunc(
            f"hl.cast({result.name}.type(), {halide_constant(other)})",
            [],
            bounds=ValueRanges.wrap(other),
            shape=result.shape,
        )
        # TODO(jansel): look into removing the where in the same places triton does
        return ops.where(new_mask, result, other)

    @staticmethod
    def frexp(x):
        raise NotImplementedError("frexp")

    @staticmethod
    def device_assert_async(cond, msg):
        raise NotImplementedError("device_assert_async")

    @staticmethod
    # pyrefly: ignore [bad-override]
    def partial_accumulate(
        name: str,
        reduction_type: str,
        value: CSEVariable,
    ) -> None:
        raise NotImplementedError


HalideOverrides._initialize_pointwise_overrides("halide")
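# A CSE temporary backed by an hl.Func.  used_dims tracks which Halide loop
# variables the value depends on; until it is known the variable prints as
# "tmp123[?]", which codegen_kernel() later rewrites via undefined_re.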
class HalideCSEVariable(CSEVariable):
    undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")

    def __init__(
        self,
        name,
        bounds: ValueRanges[Any],
        dtype: Optional[torch.dtype] = None,
        shape: BlockShapeType = None,
    ) -> None:
        super().__init__(name, bounds, dtype, shape=shape)
        self.used_dims: Optional[list[sympy.Symbol]] = None

    def update_on_args(self, name, args, kwargs):
        used = OrderedSet(self.used_dims or ())
        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, HalideCSEVariable):
                assert arg.used_dims is not None, (name, arg, args)
                used.update(arg.used_dims)
        self.used_dims = V.kernel.sort_used_dims(used)

    def index_str(self, dims):
        if len(dims) == 0:
            return f"{self.name}[()]"
        # Reversed since Halide is column major
        return f"{self.name}[{', '.join(map(str, dims))}]"

    def __str__(self) -> str:
        if self.used_dims is None:
            # This will get recomputed and replaced in codegen_kernel()
            return f"{self.name}[?]"
        return self.index_str(self.used_dims)

    def subs_str(self, replacements):
        assert self.used_dims is not None and all(
            isinstance(x, sympy.Expr) for x in self.used_dims
        )
        return self.index_str([replacements.get(n, n) for n in self.used_dims])
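# One dimension of a buffer as presented to Halide: the index expression
# used to address it, plus its size and stride (normalized to be positive).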
@dataclasses.dataclass
class DimensionInfo:
    expr: Optional[sympy.Expr]
    size: sympy.Expr
    stride: sympy.Expr

    def __init__(self, expr, size, stride) -> None:
        super().__init__()
        if V.graph.sizevars.statically_known_lt(stride, 0):
            stride = -stride
            expr = -expr
        self.expr = expr
        self.size = size
        self.stride = stride

    def index_str(self, replacements=None, zero_vars=False):
        assert self.expr is not None
        expr = self.expr
        if zero_vars and expr == 0:
            return "hl.Var()"
        if replacements:
            replacements = {**replacements}
            # pyrefly: ignore [missing-attribute]
            for sym in expr.free_symbols:
                if symbol_is_type(sym, SymT.TMP):
                    assert isinstance(sym, sympy.Symbol)
                    var = V.kernel.lookup_cse_var(sym.name)
                    assert isinstance(var, HalideCSEVariable)
                    replacements[sym] = sympy_index_symbol(var.subs_str(replacements))
            expr = sympy_subs(expr, replacements)
        return V.kernel.index_to_str(expr)
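# eq()/lt() compare symbolic sizes: they prefer statically known facts and
# otherwise fall back to size hints, guarding the comparison so the answer
# stays valid at runtime.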
def eq(left, right):
    if V.graph.sizevars.statically_known_equals(left, right):
        return True
    try:
        a = V.graph.sizevars.size_hint_or_throw(left)
        b = V.graph.sizevars.size_hint_or_throw(right)
    except TypeError:  # unbacked symints
        return False
    if a == b:
        V.graph.sizevars.check_equals(left, right)
    return a == b


def lt(left, right):
    if V.graph.sizevars.statically_known_lt(left, right):
        return True
    try:
        a = V.graph.sizevars.size_hint_or_throw(left)
        b = V.graph.sizevars.size_hint_or_throw(right)
    except TypeError:  # unbacked symints
        gcd = sympy.gcd(left, right)
        if gcd == left:
            return left != right
        return False
    if a < b:
        V.graph.sizevars.check_lt(left, right)
    return a < b
class HalideKernel(SIMDKernel):
    overrides = HalideOverrides  # type: ignore[assignment]
    kexpr: Callable[[sympy.Expr], str] = texpr

    def __init__(
        self,
        tiling: dict[str, sympy.Expr],
        **kwargs,
    ) -> None:
        super().__init__(tiling, **kwargs)
        # For halide, we just write directly to the body
        self.compute = self.body
        self.loads = self.body
        self.stores = self.body
        self.indexing_code_dom = IndentedBuffer()
        self.needs_dom_indexing = self.inside_reduction
        self.has_reduction = self.inside_reduction
        self.buffer_dimensions: dict[str, list[DimensionInfo]] = {}
        self.buffer_offsets: dict[str, sympy.Expr] = {}
        # {h0: size1, h1: size2, ...}
        self.halide_vars: dict[sympy.Symbol, sympy.Expr] = {}
        # {x0: h0, x1: h1+10*h2, ...}
        self.index_replacements: dict[sympy.Expr, sympy.Expr] = {}
        # {h1: hr1, ...}
        self.reduction_renames: dict[sympy.Symbol, sympy.Symbol] = {}
        # {"i": {h0: hi0}, "o": ...}
        self.dom_renames: dict[str, dict[sympy.Symbol, sympy.Symbol]] = {}
        # {"in_ptr0": ["in_ptr0_view0"], ...}
        self.buffer_aliases: dict[str, list[str]] = defaultdict(list)
        self.has_indirect_indexing = False

    def dtype_to_str(self, dtype: torch.dtype) -> str:
        return halide_type(dtype)

    # pyrefly: ignore [bad-override]
    def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
        self.body.writeline(f"{name} = hl.Func({name!r})")
        # pyrefly: ignore [bad-argument-type]
        return HalideCSEVariable(name, bounds, dtype, shape)
    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
        """
        Hook called right before codegen with every index that will be
        used in the fused kernel.

        This populates self.halide_vars, self.index_replacements, and
        self.reduction_renames with an alternate indexing scheme that avoids
        divide and modulus.  Instead of xindex/yindex/rindex we base indexing
        on a larger number of vars whose product combines to those, e.g. an
        xindex over 42 elements might become h0 (size 7) + 7*h1 (size 6).
        """
        assert not (
            self.index_replacements or self.halide_vars or self.reduction_renames
        )
        size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)  # type: ignore[arg-type]
        # pyrefly: ignore [bad-assignment]
        indices = dict.fromkeys(map(super().prepare_indexing, indices))
        all_used_symbols = OrderedSet[Any]()
        sym_to_node = {
            n.symbol(): n
            for n in itertools.chain.from_iterable(
                [tree.nodes.values() for tree in self.range_trees]
            )
        }

        def simplify(expr):
            return sympy.simplify(
                V.graph.sizevars.remove_precomputed_replacements(expr)
            )

        def visit_modular_indexing(base, divisor, modulus):
            if base in sym_to_node:
                node = sym_to_node[base]
                all_used_symbols.add(
                    node.root.lookup(
                        node.divisor * divisor,
                        V.graph.sizevars.evaluate_min(
                            modulus, FloorDiv(node.length, divisor)
                        ),
                    ).symbol()
                )

        def visit_floor_div(base, divisor):
            if base in sym_to_node:
                node = sym_to_node[base]
                all_used_symbols.add(
                    node.root.lookup(
                        node.divisor * divisor,
                        FloorDiv(node.length, divisor),
                    ).symbol()
                )

        # first figure out all_used_symbols to do dead symbol elimination
        for index in indices:
            if index.has(ModularIndexing):
                index.replace(
                    ModularIndexing(
                        sympy.Wild("base"),
                        sympy.Wild("divisor"),
                        sympy.Wild("modulus"),
                    ),
                    visit_modular_indexing,
                )
            if index.has(FloorDiv):
                index.replace(
                    FloorDiv(
                        sympy.Wild("base"),
                        sympy.Wild("divisor"),
                    ),
                    visit_floor_div,
                )
            all_used_symbols.update(super().prepare_indexing(index).free_symbols)

        self.has_indirect_indexing = any(
            symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols
        )

        had_fallback = False
        for tree in reversed(self.range_trees):
            nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols]
            nodes.sort(key=lambda n: size_hint(n.divisor))
            if not nodes:
                nodes.append(tree.lookup(1, tree.numel))
            handled_count = 0
            divisor = sympy.S.One
            added_sym_size = []
            # decide on a minimal set of symbols and put them in self.halide_vars
            while handled_count < len(nodes) and not eq(tree.numel, divisor):
                sizes_to_add = [
                    simplify(n.length) for n in nodes if eq(n.divisor, divisor)
                ]
                handled_count += len(sizes_to_add)
                assert sizes_to_add, nodes
                end = divisor * functools.reduce(
                    V.graph.sizevars.evaluate_max, sizes_to_add
                )
                sizes_to_add.extend(
                    [
                        simplify(n.divisor / divisor)
                        for n in nodes
                        if lt(divisor, n.divisor) and lt(n.divisor, end)
                    ]
                )
                while sizes_to_add:
                    next_size = functools.reduce(sympy.gcd, sizes_to_add)
                    if eq(next_size, 1):
                        # sizes share no common factors, e.g [2, 21, 42, 441, 889056]
                        # TODO(jansel): we should just prevent fusion in cases that hit this
                        next_size = simplify(tree.numel / divisor)
                        assert not eq(next_size, 1)
                        sizes_to_add = []
                        handled_count = len(nodes)
                        had_fallback = True
                    sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
                    # pyrefly: ignore [missing-argument]
                    if tree.is_reduction:
                        self.reduction_renames[sym] = sympy_index_symbol(
                            f"hr{len(self.halide_vars)}"
                        )
                    self.halide_vars[sym] = next_size
                    added_sym_size.append((sym, next_size))
                    divisor *= next_size
                    new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)]
                    handled_count += len(new_sizes)
                    prior_len = len(sizes_to_add)
                    sizes_to_add = [
                        sympy.simplify(s / next_size)
                        for s in sizes_to_add
                        if not eq(s, next_size)
                    ]
                    assert len(sizes_to_add) < prior_len or prior_len == 0
                    sizes_to_add.extend(new_sizes)

            # create a mapping to the new set of symbols in self.index_replacements
            for node in nodes:
                try:
                    idx = 0
                    divisor = 1
                    while not eq(node.divisor, divisor):
                        sym, size = added_sym_size[idx]
                        idx += 1
                        divisor *= size
                    length = 1
                    expr = sympy.S.Zero
                    while not eq(node.length, length):
                        sym, size = added_sym_size[idx]
                        idx += 1
                        expr += length * sym
                        length *= size
                    self.index_replacements[node.symbol()] = expr
                except IndexError:
                    assert had_fallback
                    full_index = sympy.S.Zero
                    stride = sympy.S.One
                    for sym, size in added_sym_size:
                        full_index += stride * sym
                        stride *= size
                    self.index_replacements[node.symbol()] = (
                        V.graph.sizevars.simplify_with_ranges(
                            ModularIndexing(full_index, node.divisor, node.length),
                            self.halide_vars,  # type: ignore[arg-type]
                        )
                    )

        # codegen the variable definitions
        for sym in self.halide_vars:
            self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})")
        if self.reduction_renames:
            self.codegen_rdom(
                "rdom",
                {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()},
            )
    def setup_dom_indexing(self):
        """RDom based indexing uses explicit iteration ranges for Func updates"""
        prefix = "i" if self.inside_reduction else "o"
        if prefix in self.dom_renames:
            return self.dom_renames[prefix]

        renames = {}
        for var in self.halide_vars.keys():
            if not self.inside_reduction and var in self.reduction_renames:
                continue
            m = re.match(r"^h(\d+)$", var.name)
            assert m
            renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}")

        self.codegen_rdom(
            f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()}
        )

        self.dom_renames[prefix] = renames
        return renames

    def codegen_rdom(self, name, vars):
        rsizes = [
            f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})"
            for size in vars.values()
        ]
        self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])")
        for i, rsym in enumerate(vars.keys()):
            self.indexing_code.writeline(f"{rsym} = {name}[{i}]")

    def prepare_indexing(
        self,
        index: sympy.Expr,
    ):
        index = super().prepare_indexing(index)
        index = sympy_subs(index, self.index_replacements)
        return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars)  # type: ignore[arg-type]

    def sym_size(self, sym):
        """The size of an index symbol"""
        if symbol_is_type(sym, SymT.TMP):
            return self.lookup_cse_var(sym.name).indirect_indexing_size
        return self.halide_vars[sym]
    def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool):
        """Convert address-based indexing into dimensions using self.halide_vars"""
        symbols = []
        for sym in sorted(index.free_symbols, key=lambda x: x.name):  # type: ignore[attr-defined]
            if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)):
                symbols.append(sym)
            else:
                assert symbol_is_type(
                    sym,
                    (
                        SymT.UNBACKED_INT,
                        SymT.SIZE,
                        SymT.PRECOMPUTED_SIZE,
                    ),
                ), sym

        # group the expression by variables used
        offset = sympy.S.Zero
        split_expr = dict.fromkeys(symbols, sympy.S.Zero)
        split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = []
        index = sympy.expand(self.rename_indexing(index))
        for part in index.args if isinstance(index, sympy.Add) else [index]:
            part_vars = [v for v in part.free_symbols if v in split_expr]
            if len(part_vars) == 0:
                offset += part
            elif len(part_vars) == 1:
                split_expr[part_vars[0]] += part
            else:
                new_split_failed = []
                for i in range(len(split_failed)):
                    assert split_failed[i] is not None
                    other_vars, other_part = split_failed[i]
                    if OrderedSet(other_vars) & OrderedSet(part_vars):
                        part_vars.extend([v for v in other_vars if v not in part_vars])
                        part += other_part
                    else:
                        new_split_failed.append((other_vars, other_part))
                split_failed = [*new_split_failed, (part_vars, part)]

        def expr_to_dimension(expr, syms):
            expr = sympy.factor(expr)
            if len(syms) == 1:
                stride_wild = sympy.Wild("wild", exclude=symbols)
                m = expr.match(stride_wild * syms[0])
                if m:
                    return DimensionInfo(
                        syms[0], self.sym_size(syms[0]), m[stride_wild]
                    )
            assert not is_store, expr
            length = sympy.simplify(
                sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1
            )
            stride = sympy.S.One
            if isinstance(expr, sympy.Mul):
                for term in expr.args:
                    if isinstance(term, sympy.Integer):
                        stride *= term
                        expr = sympy.simplify(expr / term)
                        length = sympy.simplify(sympy.ceiling(length / term))
            return DimensionInfo(expr, length, stride)

        # try to turn each group into a strided access
        dims = []
        for syms, expr in split_failed:
            for v in syms:
                expr += split_expr.pop(v)
            dims.append(expr_to_dimension(expr, syms))
        for sym, expr in split_expr.items():
            dims.append(expr_to_dimension(expr, [sym]))
        dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf))  # type: ignore[arg-type]

        if not dims:  # scalar load/store
            if self.has_indirect_indexing:
                # workaround https://github.com/halide/Halide/issues/8338
                dims.append(DimensionInfo(sympy.S.Zero, 1, 1))
        elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1):
            # Halide assumes dimension 0 is stride == 1, so add a dummy dimension
            dims.insert(
                0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1)
            )

        if dims and not is_store:
            if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq(
                offset, self.buffer_offsets[var]
            ):
                # reuse the existing offset to avoid needing an input alias
                self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var])
                offset = self.buffer_offsets[var]
            elif V.graph.sizevars.statically_known_gt(
                offset, 0
            ):  # TODO(jansel): negative offsets
                # roll the offset into the dimensions for cleaner indexing
                self.apply_offset_to_dimension(dims, offset)
                offset = 0

        orig_var = var
        for i in itertools.count():
            if self.install_dims(var, dims, offset, is_store):
                return var, dims
            assert not is_store
            var = f"{orig_var}_view{i}"
            if var not in self.buffer_aliases[orig_var]:
                self.buffer_aliases[orig_var].append(var)

    def install_dims(self, var, dims, offset, is_store):
        """Try to set self.buffer_dimensions[var], return True on success"""
        if var not in self.buffer_dimensions:
            self.buffer_dimensions[var] = dims
            self.buffer_offsets[var] = offset
            return True
        if self.buffer_offsets[var] != offset or len(
            self.buffer_dimensions[var]
        ) != len(dims):
            return False
        if is_store:
            return self.buffer_dimensions[var] == dims
        for old, new in zip(self.buffer_dimensions[var], dims):
            if old.stride != new.stride:
                return False
            if old.size != new.size or old.expr != new.expr:
                old.size = V.graph.sizevars.evaluate_max(old.size, new.size)
                old.expr = None
        return True
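    # Fold a constant element offset into the per-dimension index
    # expressions, consuming it largest-stride-first; the loop must absorb
    # the entire offset (asserted at the end).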
    def apply_offset_to_dimension(self, dims, offset):
        if offset == 0:
            return
        for i in reversed(range(len(dims))):
            if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq(
                offset, dims[i].stride
            ):
                part = FloorDiv(offset, dims[i].stride)
                offset -= part * dims[i].stride
                dims[i].expr += part
        assert offset == 0

    def used_dims_from_index(self, index: sympy.Expr):
        """Detect which range trees are used to populate HalideCSEVariable.used_dims"""
        used_dims = OrderedSet[sympy.Symbol]()
        for sym in index.free_symbols:
            assert isinstance(sym, sympy.Symbol)
            if symbol_is_type(sym, SymT.TMP):
                # indirect indexing
                cse_var = self.lookup_cse_var(sym.name)
                assert (
                    isinstance(cse_var, HalideCSEVariable)
                    and cse_var.used_dims is not None
                )
                used_dims.update(cse_var.used_dims)
            elif symbol_is_type(sym, SymT.HALIDE):
                used_dims.add(sym)
            elif symbol_is_type(
                sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX)
            ):
                pass
            else:
                raise NotImplementedError(f"unhandled symbol {sym}")
        return self.sort_used_dims(used_dims)

    def sort_used_dims(self, used_dims):
        assert all(isinstance(x, sympy.Expr) for x in used_dims)
        ordered = [
            sym
            for sym in itertools.chain(
                self.halide_vars, self.reduction_renames.values()
            )
            if sym in used_dims
        ]
        assert len(ordered) == len(used_dims)
        return ordered

    def make_index_str(self, dims, replacements=None, zero_vars=False):
        index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims)
        if len(dims) == 0:
            index_str = "()"
        elif len(dims) == 1:
            # workaround for https://github.com/halide/Halide/issues/8299
            index_str = f"{index_str},"
        return index_str
    def load(self, name: str, index: sympy.Expr):
        """Codegen a load from an InputBuffer"""
        var = self.args.input(name)
        index = self.prepare_indexing(index)
        var, dims = self.indexing_to_dimensions(var, index, False)
        line = f"{var}[{self.make_index_str(dims)}]"
        dtype = V.graph.get_dtype(name)
        if dtype in (torch.float16, torch.bfloat16):
            dtype = torch.float32
            line = f"hl.cast(hl.Float(32), {line})"

        if self._load_mask:
            assert (
                isinstance(self._load_mask, HalideCSEVariable)
                and self._load_mask.used_dims is not None
            )
            used_dims = OrderedSet(
                (*self.used_dims_from_index(index), *self._load_mask.used_dims)
            )
            result = self.newfunc(self.sort_used_dims(used_dims))
            if result.used_dims:
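                # Trick: a size-1 RDom guarded by where(mask) makes the
                # assignment below an update stage that only runs where the
                # mask holds; adding the RDom var (always 0) ties the load
                # to that guarded update, leaving `other` in place elsewhere.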
self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])")
|
|
self.body.writeline(f"{result.name}_mask.where({self._load_mask})")
|
|
other = self.kexpr(self._load_other or 0) # type: ignore[arg-type]
|
|
self.body.writeline(
|
|
f"{result} = hl.cast({halide_type(dtype)}, {other})"
|
|
)
|
|
self.body.writeline(
|
|
f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)"
|
|
)
|
|
else:
|
|
# scalar case
|
|
self.body.writeline(
|
|
f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))"
|
|
)
|
|
return result
|
|
else:
|
|
return self.genfunc(line, self.used_dims_from_index(index))
|
|
|
|
def lookup_cse_var(self, name: str):
|
|
return self.cse.varname_map[re.sub(r"\[.*", "", name)]
|
|
|
|
    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Codegen a store to an OutputBuffer"""
        assert isinstance(value, HalideCSEVariable)
        var = self.args.output(name)
        index = self.prepare_indexing(index)
        var, dims = self.indexing_to_dimensions(var, index, True)
        if self.is_indirect_indexing(index) or mode is not None:
            replacements = self.setup_dom_indexing()
            index_str = self.make_index_str(dims, replacements)
            value_str = value.subs_str(replacements)
            undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()"
            self.body.writeline(
                DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())")
            )
        else:
            index_str = self.make_index_str(dims, zero_vars=True)
            value_str = str(value)

        dtype = V.graph.get_dtype(name)
        if mode is None:
            line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})"
        elif mode == "atomic_add":
            line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})"
        else:
            raise NotImplementedError(f"store mode={mode}")
        self.body.writeline(DeferredLine(name, line))
    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
        """Codegen a reduction operation"""
        assert self.inside_reduction
        assert not self._load_mask
        cache_key = (src_dtype, reduction_type, value)
        if cache_key in self.cse.reduction_cache:
            return self.cse.reduction_cache[cache_key]

        if isinstance(value, tuple):
            assert reduction_type == "welford_combine"
            self.cse.reduction_cache[cache_key] = result_tuple = (
                self.welford_combine_impl(*value)
            )
            return result_tuple

        assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
        reduction_vars = OrderedSet(self.reduction_renames)
        result_var = self.newfunc(
            [v for v in value.used_dims if v not in reduction_vars],
        )
        if reduction_vars - OrderedSet(value.used_dims):
            value = self.genfunc(
                f"{value}",
                self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))),
                shape=value.shape,
            )
        value_str = value.subs_str(self.reduction_renames)
        default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
        acc_type = halide_acc_type(dtype)

        if reduction_type in ("argmax", "argmin"):
            index = f"{result_var.name}_{reduction_type}"
            self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})")
            # turn the N-D argmax index into a 1-D one
            parts = []
            stride = 1
            for i, sym in enumerate(self.reduction_renames):
                # pyrefly: ignore [bad-argument-type]
                parts.append(f"{index}[{i}]")
                if stride != 1:
                    # pyrefly: ignore [unsupported-operation]
                    parts[-1] += f"*{stride}"
                stride *= self.halide_vars[sym]
            self.body.writeline(f"{result_var} = {' + '.join(parts)}")
        elif reduction_type == "welford_reduce":
            # TODO(jansel): implement welford_reduce without fallback
            result_var = self.welford_reduce_fallback(dtype, value)
        else:
            combine_fn = get_reduction_combine_fn(reduction_type, acc_type)
            with V.set_ops_handler(AddParenHandler(HalideOverrides())):
                combine_str = combine_fn(result_var, value_str)  # type: ignore[arg-type]
            default_str = f"hl.cast({acc_type}, {halide_constant(default)})"
            self.body.writeline(f"{result_var} = {default_str}")
            self.body.writeline(f"{result_var} = {combine_str}")

        self.cse.reduction_cache[cache_key] = result_var
        return result_var
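    # Numerically stable combine step for Welford's variance algorithm:
    # merges (mean, m2, weight) accumulator triples across the reduction
    # domain using the standard parallel-merge update.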
    def welford_combine_impl(self, mean, m2, weight):
        assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None
        assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None
        assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None
        used_dims = OrderedSet(
            (*mean.used_dims, *m2.used_dims, *weight.used_dims) or self.halide_vars
        )
        used_dims -= OrderedSet(self.reduction_renames)
        result_var = self.newfunc(self.sort_used_dims(used_dims))
        default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)]
        pfx = result_var.name
        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])")
        self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]")
        self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]")
        self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]")
        self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}")
        self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}")
        self.body.writeline(
            f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}"
        )
        self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1")
        self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2")
        self.body.writeline(
            f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)"
        )
        update = [
            f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w",
            f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w",
            f"{pfx}_new_weight",
        ]
        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])")

        unpacked = []
        for i in range(3):
            unpacked.append(self.newfunc(result_var.used_dims))
            self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
        return tuple(unpacked)
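    # Inclusive scan along the (single) reduction dim: initialize from the
    # input, then an RDom update starting at index 1 combines element r-1
    # into element r.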
    def scan(
        self,
        dtypes: tuple[torch.dtype, ...],
        combine_fn: Callable[
            [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
        ],
        values_orig: tuple[CSEVariable, ...],
    ) -> tuple[CSEVariable, ...]:
        assert self.inside_reduction
        assert len(dtypes) == len(values_orig)
        values: list[HalideCSEVariable] = []
        all_used_dims = OrderedSet[sympy.Symbol]()

        for value in values_orig:
            assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
            if OrderedSet(value.used_dims) & OrderedSet(self.reduction_renames):
                values.append(value)
            else:
                values.append(
                    self.genfunc(
                        f"{value}",
                        [*value.used_dims, [*self.reduction_renames][:1]],
                        shape=value.shape,
                    )
                )
            all_used_dims.update(value.used_dims)
        result_var = self.newfunc(self.sort_used_dims(all_used_dims))
        assert result_var.used_dims and OrderedSet(result_var.used_dims) & OrderedSet(
            self.reduction_renames
        )
        initial = [
            f"hl.cast({halide_acc_type(dtype)}, {value})"
            for dtype, value in zip(dtypes, values)
        ]

        length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel))
        scan_dom = f"{result_var.name}_rdom"
        scan = f"{scan_dom}.x"
        self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")

        assert len(self.reduction_renames) == 1, (
            "multi-dimensional scan not implemented"
        )
        (scan_var,) = [*self.reduction_renames]  # type: ignore[misc]
        scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
        scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}

        if len(values) == 1:

            def maybe_tuple(x):
                return x[0]

            read_left = [result_var.subs_str(scan_renames_pri)]
            read_right = [result_var.subs_str(scan_renames_cur)]
        else:

            def maybe_tuple(x):
                return f"hl.Tuple([{', '.join(x)}])"

            read_left = [
                result_var.subs_str(scan_renames_pri) + f"[{i}]"
                for i in range(len(values))
            ]
            read_right = [
                result_var.subs_str(scan_renames_cur) + f"[{i}]"
                for i in range(len(values))
            ]

        self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")

        # Disable CSE for update fn
        with V.set_ops_handler(AddParenHandler(HalideOverrides())):
            combine_str = combine_fn(read_left, read_right)  # type: ignore[arg-type]
        self.body.writeline(
            f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}"
        )

        if len(values) == 1:
            return (result_var,)

        unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values]
        for i, v in enumerate(unpack_vars):
            self.body.writeline(f"{v} = {result_var}[{i}]")
        return tuple(unpack_vars)
    def genfunc(
        self,
        line,
        used_dims,
        *,
        bounds=ValueRanges.unknown(),
        shape: BlockShapeType = None,
    ) -> HalideCSEVariable:
        var = self.cse.generate(self.body, line, bounds=bounds, shape=shape)
        assert isinstance(var, HalideCSEVariable)
        var.used_dims = used_dims
        return var

    def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable:
        var = self.cse.newvar(shape=shape)
        assert isinstance(var, HalideCSEVariable)
        var.used_dims = used_dims
        return var

    def halide_buffer_numel(self, name: str):
        """
        We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch
        supports. If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while
        PyTorch's numel excludes them.
        """
        return V.graph.get_buffer(name).get_layout().storage_size()
    def halide_argdefs(self):
        """
        Halide requires scalar inputs before outputs, so need to reorder args.
        """

        def arg_order(arg_tuple):
            _call_str, arg = arg_tuple
            if isinstance(arg, SizeArg):
                return 1  # this would normally be at the end, move it to middle
            elif "out_ptr" in arg.name:
                return 2
            else:
                assert "in_ptr" in arg.name
                return 0

        result: list[tuple[Optional[str], KernelArgType]] = []
        _, a, b, _ = self.args.python_argdefs()
        for call_str, arg in sorted(zip(a, b), key=arg_order):
            result.append((call_str, arg))
            if isinstance(arg, TensorArg):
                assert arg.offset == 0 and arg.alias_of is None
                result.extend(
                    (
                        None,
                        TensorArg(
                            alias,
                            arg.buffer,
                            arg.dtype,
                            arg.offset,
                            alias_of=arg.name,
                        ),
                    )
                    for alias in self.buffer_aliases.get(arg.name, ())
                )
        return result
    def halide_kernel_meta(self) -> HalideMeta:
        """Compute metadata required by codecache.py"""
        argtypes = []
        for _, arg in self.halide_argdefs():
            if isinstance(arg, SizeArg):
                shape = None
                stride = None
                offset = None
                dtype = "long"
            else:
                shape = [
                    cexpr(self.rename_indexing(x.size))
                    for x in self.buffer_dimensions[arg.name]
                ]
                stride = [
                    cexpr(self.rename_indexing(x.stride))
                    for x in self.buffer_dimensions[arg.name]
                ]
                assert len(shape) == len(stride)
                offset = cexpr(self.buffer_offsets[arg.name])
                dtype = f"{DTYPE_TO_CPP[arg.dtype]}*"
            argtypes.append(
                HalideInputSpec(
                    dtype,
                    arg.name,
                    shape=shape,
                    stride=stride,
                    offset=offset,
                    alias_of=arg.alias_of,
                )
            )

        current_device = V.graph.get_current_device_or_throw()
        if current_device.type == "cpu":
            target = [config.halide.cpu_target]
            scheduler = config.halide.scheduler_cpu
            scheduler_flags = {
                "parallelism": parallel_num_threads(),
            }
            cuda_device = None
        else:
            assert current_device.type == "cuda", "only cpu/cuda supported"
            assert current_device.index <= 0, "only default device supported"
            target = [config.halide.gpu_target]
            scheduler = config.halide.scheduler_cuda
            capability = torch.cuda.get_device_properties(current_device)
            if "cuda_capability" not in target[0]:
                for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]:
                    if capability.major >= major and capability.minor >= minor:
                        target.append(f"cuda_capability_{major}{minor}")
                        break
            target.append("user_context")
            scheduler_flags = {
                "parallelism": capability.multi_processor_count,
                # TODO(jansel): explore other flags, see:
                # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp
            }
            cuda_device = max(0, current_device.index)
        # strict_float is required for correctness
        target.append("strict_float")
        # without this we will initialize cuda once per kernel and hit errors
        target.append("no_runtime")

        if not config.halide.asserts:
            target.append("no_asserts")

        if config.halide.debug:
            target.append("debug")

        if "64" in self.index_dtype:
            # TODO(jansel): it is unclear if this does anything, since input sizes are still int32
            target.append("large_buffers")

        return HalideMeta(
            argtypes,
            target="-".join(target),
            scheduler=scheduler,
            scheduler_flags=scheduler_flags,  # type: ignore[arg-type]
            cuda_device=cuda_device,
        )
    def codegen_kernel(self, name=None):
        """Called at the end to generate a final kernel string"""
        if self.args.inplace_buffers:
            raise Unsupported("inplace_buffers")
        meta = self.halide_kernel_meta()  # ensure needed args are added early
        code = IndentedBuffer()
        code.splice(
            """
            import halide as hl
            from torch._inductor.runtime import halide_helpers
            from math import inf, nan

            @hl.generator(name="kernel")
            class Kernel:
            """,
            strip=True,
        )
        code.do_indent()
        for _, arg in self.halide_argdefs():
            if isinstance(arg, SizeArg):
                code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})")
            else:
                assert arg.buffer, arg
                argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer"
                argtype = halide_type(arg.dtype)
                ndim = len(self.buffer_dimensions[arg.name])
                code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})")
        code.splice(
            """
            def generate(g):
            """
        )
        code.do_indent()
        for _, arg in self.halide_argdefs():
            code.writeline(f"{arg.name} = g.{arg.name}")
        for old, new in self.args.aliases():
            code.writeline(f"{old} = {new}")
        code.splice(self.indexing_code)

        def update_index(m):
            var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)])
            assert var.used_dims is not None, var
            return str(var)

        for line in self.body._lines:
            if isinstance(line, str):
                # fill in missing indices
                line = HalideCSEVariable.undefined_re.sub(update_index, line)
            code.writeline(line)
        code.writeline("")
        code.writeline("assert g.using_autoscheduler()")

        for _, arg in self.halide_argdefs():
            # fallback=1 below because halide requires buffers to be at least as large as the estimates
            # This causes crashes if our estimate is greater than the vector length
            # https://github.com/halide/Halide/issues/3103
            if isinstance(arg, SizeArg):
                hint = V.graph.sizevars.size_hint(arg.expr, fallback=1)
                code.writeline(f"{arg.name}.set_estimate({hint})")
            else:
                dims = self.buffer_dimensions[arg.name]
                range_hints = []
                for i, dim in enumerate(dims):
                    hint = self._autoscheduler_workarounds(
                        V.graph.sizevars.size_hint(dim.size, fallback=1), dims
                    )
                    # pyrefly: ignore [bad-argument-type]
                    range_hints.append(f"hl.Range(0, {hint})")
                    if "out" not in arg.name:
                        code.writeline(f"{arg.name}.dim({i}).set_min(0)")
                        try:
                            code.writeline(
                                f"{arg.name}.dim({i}).set_stride({int(dim.stride)})"
                            )
                        except TypeError:
                            pass  # not integer
                        try:
                            code.writeline(
                                f"{arg.name}.dim({i}).set_extent({int(dim.size)})"
                            )
                        except TypeError:
                            pass  # not integer
                code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])")

        code.do_unindent(2)
        code.splice(
            """
            if __name__ == "__main__":
                hl.main()
            """.rstrip(),
        )
        if meta.scheduler:
            code.splice(
                f"""
                else:
                    hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
                    target = hl.Target({meta.target!r})
                    autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
                    with hl.GeneratorContext(target, autoscheduler):
                        gen = Kernel()
                        pipeline = gen._build_pipeline()
                        # gen.compile_to_callable() does not run the autoscheduler
                        pipeline.apply_autoscheduler(target, autoscheduler)
                        kernel = pipeline.compile_to_callable([
                            gen._get_input_parameter(a.name)._to_argument()
                            for a in gen._get_arginfos()
                            if a.dir == hl.ArgInfoDirection.Input
                        ], target)
                """,
                strip=True,
            )
        else:
            code.splice(
                f"""
                else:
                    with hl.GeneratorContext(hl.Target({meta.target!r})):
                        kernel = Kernel().compile_to_callable()
                """,
                strip=True,
            )
        return code.getvalue()

    @staticmethod
    def _autoscheduler_workarounds(n, dims):
        if (
            len(dims) == 1
            and config.halide.scheduler_cuda == "Anderson2021"
            and V.graph.get_current_device_or_throw().type == "cuda"
        ):
            # workaround https://github.com/halide/Halide/issues/8246
            n = max(2, n)
        return n
    def call_kernel(self, name: str, node=None, deallocate_ws: bool = True):
        """Codegen a call to this kernel"""
        wrapper = V.graph.wrapper_code
        call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None]
        current_device = V.graph.get_current_device_or_throw()
        if current_device.type == "cuda":
            stream_name = wrapper.write_get_raw_stream(
                current_device.index, V.graph.name
            )
            call_args.append(stream_name)
        wrapper.generate_kernel_call(
            name,
            call_args,
            device=current_device,
            triton=False,
        )

    def generate_assert(self, check):
        return False  # TODO(jansel): support asserts

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        pass  # TODO(jansel): support asserts
class HalideScheduling(SIMDScheduling):
    kernel_type = HalideKernel  # type: ignore[arg-type,assignment]

    @classmethod
    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
        result = OrderedSet(
            [
                BackendFeature.TUPLE_REDUCTION,
                BackendFeature.PREFER_STORE_LOOP_ORDER,
                BackendFeature.REDUCE_TO_SINGLE_ELEMENT,
            ]
        )
        if config.halide.scan_kernels:
            result.add(BackendFeature.SCAN)
        return result

    def define_kernel(self, src_code, node_schedule, kernel):
        """Codegen kernel definition to go in output wrapper code"""
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}"
            wrapper.src_to_kernel[src_code] = kernel_name
            wrapper.add_import_once(
                "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec"
            )

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline(
                f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''"
            )
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''')")

            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment = f"{origins}\n{detailed_origins}"
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
            if is_metric_table_enabled("kernel_metadata"):
                log_kernel_metadata(kernel_name, "", src_code)

        return kernel_name