mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417
Approved by: https://github.com/shunting314, https://github.com/eellison
1147 lines · 39 KiB · Python
# mypy: allow-untyped-defs
from __future__ import annotations

import itertools
import logging
import re
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import sympy

import torch
import torch._logging
from ..._prims_common import is_integer_dtype
from ...utils._sympy.symbol import symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import HalideCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata

from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
from ..utils import (
    get_bounds_index_expr,
    get_kernel_metadata,
    parallel_num_threads,
    sympy_dot,
    sympy_index_symbol,
    sympy_subs,
)
from ..virtualized import _ops as ops, OpsHandler, V
from .common import (
    BackendFeature,
    CSEVariable,
    DeferredLine,
    IndentedBuffer,
    OpOverrides,
    PythonPrinter,
    SizeArg,
)
from .cpp import DTYPE_TO_CPP
from .cpp_utils import cexpr
from .simd import constant_repr, IterationRangesEntry, SIMDKernel, SIMDScheduling

if TYPE_CHECKING:
    from ..ops_handler import ReductionType, StoreMode

log = logging.getLogger(__name__)


def halide_constant(val):
    if isinstance(val, int) and not (-2147483648 <= val <= 2147483647):
        info = torch.iinfo(torch.int64)
        if val == info.min:
            return "hl.Int(64).min()"
        if val == info.max:
            return "hl.Int(64).max()"
        return f"hl.i64({val!r})"
    if isinstance(val, float):
        return f"hl.f64({constant_repr(val)})"
    return repr(val)


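# Illustrative only (not part of the upstream file): how halide_constant maps
# Python scalars to Halide literal expressions. Ints that fit in 32 bits print
# bare so Halide can infer a narrow type; wider ints are pinned to 64 bits.
#
#   halide_constant(7)       -> "7"
#   halide_constant(2**40)   -> "hl.i64(1099511627776)"
#   halide_constant(torch.iinfo(torch.int64).max) -> "hl.Int(64).max()"
#   halide_constant(1.5)     -> "hl.f64(1.5)"

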
class Unsupported(RuntimeError):
    def __init__(self, thing):
        super().__init__(f"halide backend does not support: {thing}")


class HalidePrinter(PythonPrinter):
    @staticmethod
    def cast_index(expr):
        return f"hl.cast({V.kernel.index_dtype}, {expr})"

    @staticmethod
    def cast_float(expr):
        return f"hl.cast(hl.Float(32), {expr})"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.floor({self._print(expr.args[0])})")

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.ceil({self._print(expr.args[0])})")

    def _helper_sqrt(self, expr):
        return f"hl.sqrt({self.cast_float(self._print(expr))})"

    def _print_Where(self, expr):
        c = self.doprint(expr.args[0])
        p = self.doprint(expr.args[1])
        q = self.doprint(expr.args[2])
        return f"hl.select({c}, {p}, {q})"

    def _print_Min(self, expr):
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Min(*expr.args[:mid]))
        b = self._print(sympy.Min(*expr.args[mid:]))
        return f"hl.min({a}, {b})"

    def _print_Max(self, expr):
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Max(*expr.args[:mid]))
        b = self._print(sympy.Max(*expr.args[mid:]))

        return f"hl.max({a}, {b})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.abs({self._print(expr.args[0])})")

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"hl.cos(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"hl.cosh(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"hl.acos(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"hl.sin(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"hl.sinh(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"hl.asin(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"hl.tan(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"hl.tanh(({self._print(expr.args[0])}))"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"hl.atan(({self._print(expr.args[0])}))"

    def _print_FloorDiv(self, expr):
        if expr.is_integer:
            return super()._print_FloorDiv(expr)

        x, div = expr.args
        x = self.cast_float(self.paren(self.doprint(x)))
        div = self.cast_float(self.paren(self.doprint(div)))
        return self.cast_index(f"hl.floor({x} / {div})")

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.round({self._print(expr.args[0])})")

    _print_RoundToInt = _print_Round

    def _print_IntTrueDiv(self, expr):
        a, b = expr.args
        # force a cast to float
        return f"({self._print(a)}) / ({self._print(b)} + hl.f32(0))"

    def _print_RoundDecimal(self, expr):
        val, n = expr.args
        val = self._print(val)
        n = int(n)
        return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"


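# Usage sketch (illustrative, standalone): the printer lowers sympy index
# arithmetic to strings of Halide front-end code.
#
#   import sympy
#   x0, x1 = sympy.symbols("x0 x1")
#   HalidePrinter().doprint(sympy.Min(x0, x1, 3))
#   # -> e.g. "hl.min(3, hl.min(x0, x1))" (nesting follows sympy's arg order)

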
texpr = HalidePrinter().doprint
pexpr = PythonPrinter().doprint


_halide_type = {
    torch.bool: "hl.Bool()",
    torch.float16: "hl.Float(16)",
    torch.float32: "hl.Float(32)",
    torch.float64: "hl.Float(64)",
    torch.int8: "hl.Int(8)",
    torch.int16: "hl.Int(16)",
    torch.int32: "hl.Int(32)",
    torch.int64: "hl.Int(64)",
    torch.uint8: "hl.UInt(8)",
    torch.uint16: "hl.UInt(16)",
    torch.uint32: "hl.UInt(32)",
    torch.uint64: "hl.UInt(64)",
}


def halide_type(dtype):
    if dtype == torch.bfloat16:
        raise Unsupported("torch.bfloat16")
    return _halide_type[dtype]


def halide_acc_type(dtype):
    if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64:
        dtype = torch.int32
    if dtype in (torch.float16, torch.bfloat16):
        dtype = torch.float32
    return halide_type(dtype)


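# Illustrative accumulator mappings (derived from the rules above):
#   halide_acc_type(torch.int16)   -> "hl.Int(32)"    # small signed ints widen
#   halide_acc_type(torch.float16) -> "hl.Float(32)"  # half accumulates in fp32
#   halide_acc_type(torch.int64)   -> "hl.Int(64)"    # already wide enough
#   halide_acc_type(torch.uint8)   -> "hl.UInt(8)"    # unsigned left untouched

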
class HalideOverrides(OpOverrides):
    @staticmethod
    def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
        if dtype == torch.bool:
            return f"({x} != 0)"
        return f"hl.cast({halide_type(dtype)}, {x})"

    @staticmethod
    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
        if src_dtype in (torch.float16, torch.bfloat16):
            x = f"hl.cast({halide_type(src_dtype)}, {x})"  # body compute is upcast to fp32
        line = f"hl.reinterpret({halide_type(dtype)}, {x})"
        if dtype in (torch.float16, torch.bfloat16):
            line = f"hl.cast(hl.Float(32), {line})"
        return line

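    # Illustrative (not upstream code): bitcasting a half value must first
    # narrow it back to fp16, since kernel bodies compute half in fp32. E.g.
    # to_dtype_bitcast("tmp0", torch.int16, torch.float16) emits
    #   hl.reinterpret(hl.Int(16), hl.cast(hl.Float(16), tmp0))
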
    @classmethod
    def constant(cls, value, dtype):
        return cls.to_dtype(halide_constant(value), dtype)

    @staticmethod
    def abs(x):
        return f"hl.abs({x})"

    @staticmethod
    def exp(x):
        return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})"

    @staticmethod
    def libdevice_exp(x):
        return f"hl.exp({x})"  # higher precision than ops.exp

    @staticmethod
    def sqrt(x):
        return f"hl.sqrt({x})"

    @staticmethod
    def minimum(a, b):
        # return f"hl.min({a}, {b})" <== handles nan wrong
        return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})"

    @staticmethod
    def maximum(a, b):
        # return f"hl.max({a}, {b})" <== handles nan wrong
        return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})"

    @staticmethod
    def where(a, b, c):
        return f"hl.select({a}, {b}, hl.cast({b.name}.type(), {c}))"

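    # Why the select (explanatory note, not upstream): IEEE comparisons with
    # NaN are false, so a bare hl.min(a, b) would return b whenever a is NaN.
    # The select above propagates NaN from the first operand, matching
    # torch.minimum/torch.maximum, e.g. minimum(float("nan"), 1.0) -> nan.
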
    @staticmethod
    def cos(x):
        return f"hl.cos({x})"

    @staticmethod
    def sin(x):
        return f"hl.sin({x})"

    @staticmethod
    def lgamma(x):
        raise Unsupported("lgamma")

    @staticmethod
    def erf(x):
        return f"hl.erf({x})"

    @staticmethod
    def cosh(x):
        return f"hl.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"hl.sinh({x})"

    @staticmethod
    def acos(x):
        return f"hl.acos({x})"

    @staticmethod
    def acosh(x):
        return f"hl.acosh({x})"

    @staticmethod
    def asin(x):
        return f"hl.asin({x})"

    @staticmethod
    def asinh(x):
        return f"hl.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"hl.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"hl.atan({x})"

    @staticmethod
    def atanh(x):
        return f"hl.atanh({x})"

    @staticmethod
    def copysign(x, y):
        raise Unsupported("copysign")

    @staticmethod
    def erfinv(x):
        raise Unsupported("erfinv")

    @staticmethod
    def hypot(x, y):
        return f"hl.hypot({x}, {y})"

    @staticmethod
    def nextafter(x, y):
        raise Unsupported("nextafter")

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_not(a):
        return f"{a} == 0"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"({a} ^ {b})"

    @staticmethod
    def bitwise_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def bitwise_not(a):
        return f"~{a}"

    @staticmethod
    def bitwise_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def bitwise_xor(a, b):
        return f"{a} ^ {b}"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"{a} << {b}"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"{a} >> {b}"

    @staticmethod
    def rand(seed, offset):
        raise Unsupported("rand")

    @staticmethod
    def randn(seed, offset):
        raise Unsupported("randn")

    @staticmethod
    def randint64(seed, offset, low, high):
        raise Unsupported("randint64")

    @staticmethod
    def load_seed(name, offset):
        raise Unsupported("load_seed")

    @staticmethod
    def rsqrt(x):
        # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues
        return f"1./hl.sqrt({x})"

    @staticmethod
    def tan(x):
        return f"hl.tan({x})"

    @staticmethod
    def tanh(x):
        return f"hl.tanh({x})"

    @staticmethod
    def signbit(x):
        return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0"

    @staticmethod
    def fmod(a, b):
        # TODO(jansel): find a better way to do this, builtin % has the wrong sign
        return f"{a} - hl.trunc({a}/{b})*{b}"

    @staticmethod
    def pow(a, b):
        return f"hl.pow({a}, {b})"  # hl.fast_pow fails accuracy

    @staticmethod
    def log(x):
        return f"hl.log({x})"  # hl.fast_log fails accuracy

    @staticmethod
    def isinf(x):
        return f"hl.is_inf({x})"

    @staticmethod
    def isnan(x):
        return f"hl.is_nan({x})"

    @staticmethod
    def round(x):
        return f"hl.round({x})"

    @staticmethod
    def floor(x):
        return f"hl.floor({x})"

    @staticmethod
    def int_truediv(a, b):
        return f"({a}) / ({b} + hl.f32(0))"

    @staticmethod
    def floordiv(a, b):
        # TODO(jansel): find a better way to do this, the select-based trick from triton.py didn't work
        return (
            f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

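    # Semantics note with a worked example (not upstream): this computes
    # Python-style floor division by rounding toward negative infinity, e.g.
    # a=-7, b=2 gives hl.floor(-7.0 / 2) == -4, matching
    # torch.div(a, b, rounding_mode="floor"); C-style truncation would give -3.
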
    @classmethod
    def sign(cls, x):
        left = ops.to_dtype(ops.lt("0", x), torch.int8)
        right = ops.to_dtype(ops.lt(x, "0"), torch.int8)
        sub = ops.sub(left, right)
        return f"hl.cast({x.name}.type(), {sub})"

    @staticmethod
    def trunc(x):
        return f"hl.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # this causes crashes with floating point exception, see test_div_zero_dim_cpu
        # return f"hl.div_round_to_zero({a}, {b})"
        return (
            f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

    @staticmethod
    def ceil(x):
        return f"hl.ceil({x})"

    @staticmethod
    def relu(x):
        return f"hl.max({x}, 0)"

    @classmethod
    def index_expr(cls, expr, dtype):
        index = V.kernel.prepare_indexing(expr)
        var = V.kernel.genfunc(
            V.kernel.index_to_str(index),
            V.kernel.used_dims_from_index(index),
            bounds=get_bounds_index_expr(expr),
        )
        if dtype not in {torch.int32, torch.int64}:
            return ops.to_dtype(var, dtype)
        return var

    @classmethod
    def indirect_indexing(cls, index_var, size, check=True):
        # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow
        index_var = ops.to_dtype(index_var, torch.int32)
        index_var = ops.halide_clamp(index_var, size, check)
        return sympy_index_symbol(str(index_var))

    @classmethod
    def halide_clamp(cls, value, size, check):
        end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1)
        if not isinstance(size, (int, sympy.Integer)):
            end = f"hl.cast({value.name}.type(), {end})"
        # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692
        # return f"hl.unsafe_promise_clamped({value}, 0, {end})"
        return f"hl.clamp({value}, 0, {end})"

    @staticmethod
    def masked(mask, body, other):
        with V.kernel.mask_loads(mask, other) as new_mask:
            result = body()

        if result.bounds.is_bool:
            other = bool(other)

        # Take dtype from result to prevent accidental promotion
        other = V.kernel.genfunc(
            f"hl.cast({result.name}.type(), {halide_constant(other)})",
            [],
            bounds=ValueRanges.wrap(other),
        )
        # TODO(jansel): look into removing the where in the same places triton does
        return ops.where(new_mask, result, other)


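# Illustrative expansion (not upstream): for masked(mask, body, other=0.0)
# the handler evaluates body under the load mask, then emits roughly
#   hl.select(mask_var, result_var, hl.cast(result_var.type(), 0.0))
# so masked-off lanes take the neutral value instead of whatever was loaded.

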
# Use mypy to check that the protocol is implemented correctly
def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]:
    return h


class HalideCSEVariable(CSEVariable):
    undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")

    def __init__(self, name, bounds: ValueRanges[Any]):
        super().__init__(name, bounds)
        self.used_dims: Optional[List[str]] = None

    def update_on_args(self, name, args, kwargs):
        used = set(self.used_dims or ())
        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, HalideCSEVariable):
                assert arg.used_dims is not None, (name, arg, args)
                used.update(arg.used_dims)
        self.used_dims = [t.name for t in V.kernel.range_trees if t.name in used]
        assert len(self.used_dims) == len(used)

    def index_str(self, dims):
        if len(dims) == 0:
            return self.name
        # Reversed since Halide is column major
        return f"{self.name}[{', '.join(map(str, reversed(dims)))}]"

    def __str__(self):
        if self.used_dims is None:
            # This will get recomputed and replaced in codegen_kernel()
            return f"{self.name}[?]"
        return self.index_str(self.used_dims)

    def with_dom(self, suffix):
        assert self.used_dims is not None
        return self.index_str([f"{d}_{suffix}" for d in self.used_dims])

    def reduction_str(self):
        assert self.used_dims is not None
        dims = [*self.used_dims]
        assert dims[-1] == "rindex"
        dims[-1] = "rdom"
        return self.index_str(dims)


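# Illustrative (hypothetical names): a variable tmp3 with
# used_dims == ["xindex", "rindex"] prints as "tmp3[rindex, xindex]"
# (reversed: Halide is column major), with_dom("idom") gives
# "tmp3[rindex_idom, xindex_idom]", and reduction_str() swaps the trailing
# rindex for the rdom reduction variable: "tmp3[rdom, xindex]".

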
class HalideKernel(SIMDKernel):
    overrides = HalideOverrides  # type: ignore[assignment]
    kexpr: Callable[[sympy.Expr], str] = texpr

    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[Set[str]] = None,
        pid_cache=None,
        reduction_hint=ReductionHint.DEFAULT,
        disable_persistent_reduction=False,
    ):
        super().__init__(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            reduction_hint=reduction_hint,
            pid_cache=pid_cache,
            disable_persistent_reduction=disable_persistent_reduction,
        )
        # For halide, we just write directly to the body
        self.compute = self.body
        self.loads = self.body
        self.stores = self.body
        self.indexing_code_dom = IndentedBuffer()
        self.needs_dom_indexing = self.inside_reduction
        self.has_reduction = self.inside_reduction
        self.store_buffer_dimensions: Dict[str, List[sympy.Expr]] = {}

    def create_cse_var(self, name, bounds=None):
        self.body.writeline(f"{name} = hl.Func({name!r})")
        return HalideCSEVariable(name, bounds)

    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
        expr = self.rename_indexing(entry.expr)
        self.indexing_code.writeline(f"{entry.name} = {self.kexpr(expr)}")

        if self.has_reduction:
            # idom includes iteration ranges of the numel of inputs
            expr_idom = sympy_subs(
                expr,
                {
                    tree.symbol(): sympy_index_symbol(f"{tree.name}_idom")
                    for tree in self.range_trees
                },
            )
            self.indexing_code_dom.writeline(
                f"{entry.name}_idom = {self.kexpr(expr_idom)}"
            )

        if entry.prefix != "r":
            # odom includes iteration ranges of the numel of outputs (which is different for reductions)
            expr_odom = sympy_subs(
                expr,
                {
                    tree.symbol(): sympy_index_symbol(f"{tree.name}_odom")
                    for tree in self.range_trees
                },
            )
            self.indexing_code_dom.writeline(
                f"{entry.name}_odom = {self.kexpr(expr_odom)}"
            )

    def used_dims_from_index(self, index: sympy.Expr):
        """Detect which range trees are used to populate HalideCSEVariable.used_dims"""
        used_dims = set()
        for sym in index.free_symbols:
            assert isinstance(sym, sympy.Symbol)
            if symbol_is_type(sym, SymT.TMP):
                # indirect indexing
                cse_var = self.lookup_cse_var(sym.name)
                assert (
                    isinstance(cse_var, HalideCSEVariable)
                    and cse_var.used_dims is not None
                )
                used_dims.update(cse_var.used_dims)
            elif symbol_is_type(
                sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX)
            ):
                pass
            else:
                # sym is one of xN, yN or rN
                assert symbol_is_type(
                    sym, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
                ), sym.name
                used_dims.add(f"{sym.name[0]}index")

        ordered = [tree.name for tree in self.range_trees if tree.name in used_dims]
        assert len(ordered) == len(used_dims)
        return ordered

    def load(self, name: str, index: sympy.Expr):
        """Codegen a load from an InputBuffer"""
        var = self.args.input(name)
        index = self.prepare_indexing(index)
        index_str = self.index_to_str(index)
        if self.is_indirect_indexing(index) or self._load_mask:
            # Halide doesn't have a great way to do masked loads
            var = f"hl.BoundaryConditions.constant_exterior({var}, 0)"
        line = f"{var}[{index_str}]"
        dtype = V.graph.get_dtype(name)
        if dtype in (torch.float16, torch.bfloat16):
            line = f"hl.cast(hl.Float(32), {line})"
        return self.genfunc(line, self.used_dims_from_index(index))

    def index_to_dom(self, index: sympy.Expr, suffix: str):
        """Replace xindex => xindex_dom, x0 => x0_dom, etc for update-style indexing"""
        replacements: Dict[sympy.Expr, Any] = {}
        for sym in index.free_symbols:
            assert isinstance(sym, sympy.Symbol)
            if symbol_is_type(sym, SymT.TMP):
                # indirect indexing
                cse_var = self.lookup_cse_var(sym.name)
                assert isinstance(cse_var, HalideCSEVariable)
                replacements[sym] = sympy.Symbol(cse_var.with_dom(suffix))
            elif symbol_is_type(
                sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX)
            ):
                pass
            else:
                # sym is one of xN, yN or rN
                assert symbol_is_type(
                    sym, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
                ), sym.name
                replacements[sym] = sympy.Symbol(f"{sym.name}_{suffix}")
        return sympy_subs(index, replacements)

    def lookup_cse_var(self, name: str):
        return self.cse.varname_map[re.sub(r"\[.*", "", name)]

    def determine_store_indexing(
        self, name: str, index: sympy.Expr, value: HalideCSEVariable, var: str, mode
    ):
        """
        Halide requires the initial definition of an output to be done with a plain Var(),
        while subsequent updates can use Expr().  For us, index may be an Expr.  This function
        tries to make the output index a Var, and if that fails switches to the more flexible
        hl.RDom()+update codegen.
        """
        assert value.used_dims is not None
        assert var not in self.store_buffer_dimensions
        eq = V.graph.sizevars.statically_known_equals

        if index == 0 and eq(self.halide_buffer_numel(name), 1) and mode is None:
            # 1-element case
            index_str = "hl.Var()"  # halide requires storage dst to be a Var
            value_str = value.index_str([0 for _ in value.used_dims])
            return index_str, value_str

        var_ranges = self.var_ranges()
        range_trees = self.active_range_trees()
        numel = self.halide_buffer_numel(name)

        if (
            isinstance(index, sympy.Symbol)
            and index in var_ranges
            and eq(var_ranges[index], numel)
            and mode is None
        ):
            value_str = str(value)
            index_str = self.index_to_str(index)
            return index_str, value_str

        try:
            value_index, dim_sizes, index_vars = self.match_strides_to_dimensions(
                index, var_ranges, range_trees, f"{var}_i", mode
            )
        except NotImplementedError:
            pass
        else:
            self.store_buffer_dimensions[var] = dim_sizes
            for v in index_vars:
                self.body.writeline(
                    DeferredLine(name, f"{v.name} = hl.Var({v.name!r})")
                )
            index_str = ", ".join(v.name for v in index_vars)
            value_str = value.index_str([value_index[d[0]] for d in value.used_dims])
            return index_str, value_str

        self.needs_dom_indexing = True
        # Fall back to using RDom-style store
        self.body.writeline(
            DeferredLine(name, f"{var}[hl.Var()] = hl.undef({var}.type())")
        )
        suffix = "idom" if self.inside_reduction else "odom"
        value_str = value.with_dom(suffix)
        index_str = self.index_to_str(self.index_to_dom(index, suffix))
        return index_str, value_str

    def match_strides_to_dimensions(
        self, index, var_ranges, range_trees, varname, mode
    ):
        """Best effort conversion of 1D indexing into N-D indexing"""
        if mode is not None:
            raise NotImplementedError  # atomic_add
        eq = V.graph.sizevars.statically_known_equals
        used_vars = set(index.free_symbols)
        var_ranges = {s: v for s, v in var_ranges.items() if s in used_vars}
        strides = V.graph.sizevars.stride_vars(index, var_ranges)
        if not strides or not eq(sympy_dot(var_ranges, strides), index):
            raise NotImplementedError  # complex or indirect indexing

        tree_numels = {t.prefix: sympy.Integer(1) for t in range_trees}
        prefix_to_tree = {t.prefix: t for t in range_trees}
        expected_stride = sympy.Integer(1)
        new_lengths = []
        new_index = {t.prefix: sympy.Integer(0) for t in range_trees}
        new_vars: List[sympy.Symbol] = []
        for stride, (v, length) in sorted(
            zip(strides, var_ranges.items()),
            key=lambda x: V.graph.sizevars.size_hint(x[0], fallback=float("inf")),  # type: ignore[arg-type]
        ):
            if not eq(expected_stride, stride):
                raise NotImplementedError  # gaps in indexing or unbacked symints
            prefix = v.name[0]
            if prefix_to_tree[prefix].lookup(tree_numels[prefix], length) != v:
                raise NotImplementedError  # output reordering
            new_var = sympy.Symbol(f"{varname}{len(new_vars)}")
            new_vars.append(new_var)
            new_lengths.append(length)
            new_index[prefix] += tree_numels[prefix] * new_var
            tree_numels[prefix] *= length
            expected_stride *= length
        return new_index, new_lengths, new_vars

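    # Worked example (hypothetical buffer/variable names): a contiguous store
    # to a [4, 8] output indexed as 8*x1 + x0 with var_ranges {x0: 8, x1: 4}
    # yields new_lengths [8, 4] and vars out_ptr0_i0, out_ptr0_i1, so the
    # store can be written as out_ptr0[out_ptr0_i0, out_ptr0_i1] = ...
    # instead of a flat 1D index.
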
    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Codegen a store to an OutputBuffer"""
        var = self.args.output(name)
        index = self.prepare_indexing(index)
        assert isinstance(value, HalideCSEVariable)
        index_str, value_str = self.determine_store_indexing(
            name, index, value, var, mode
        )

        if self.is_indirect_indexing(index):
            # Workaround "Buffer out_ptr0 may be accessed in an unbounded way"
            # TODO(jansel): we should error here rather than writing to the first/last element
            index_str = f"hl.clamp({index_str}, 0, {self.kexpr(self.halide_buffer_numel(name) - 1)})"

        if mode is None:
            line = f"{var}[{index_str}] = hl.cast({var}.type(), {value_str})"
        elif mode == "atomic_add":
            line = f"{var}[{index_str}] += {value_str}"
        else:
            raise NotImplementedError(f"store mode={mode}")
        self.body.writeline(DeferredLine(name, line))

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        """Codegen a reduction operation"""
        assert self.inside_reduction
        assert not self._load_mask
        assert isinstance(value, HalideCSEVariable) and value.used_dims is not None

        cache_key = (src_dtype, reduction_type, value)
        if cache_key in self.cse.reduction_cache:
            return self.cse.reduction_cache[cache_key]

        acc_type = halide_acc_type(dtype)
        result_var = self.newfunc(
            [
                tree.name
                for tree in self.range_trees[:-1]
                if tree.name in value.used_dims
            ]
        )

        default = ir.Reduction.default_accumulator(reduction_type, src_dtype)

        if value.used_dims[-1] != "rindex":
            value = self.genfunc(f"{value}", [*value.used_dims, "rindex"])
        value_str = value.reduction_str()

        if reduction_type in ("argmax", "argmin"):
            self.body.writeline(
                f"{result_var} = hl.{reduction_type}(rdom, {value_str})[0]"
            )
        elif reduction_type in ("sum", "prod", "min", "max", "any"):
            fn = {
                "sum": "sum",
                "prod": "product",
                "min": "minimum",
                "max": "maximum",
                "any": "maximum",
            }[reduction_type]
            self.body.writeline(f"{result_var} = hl.{fn}(rdom, {value_str})")
        elif reduction_type == "xor_sum":
            result_var_init = result_var
            if not result_var.used_dims:  # need a fake dim
                result_var_init = result_var.index_str([self.range_trees[0].name])
                result_var.used_dims = ["0"]
            self.body.writeline(
                f"{result_var_init} = hl.cast({acc_type}, {halide_constant(default)})"
            )
            self.body.writeline(f"{result_var} = {result_var} ^ {value_str}")
        elif reduction_type == "welford_reduce":
            # TODO(jansel): implement welford_reduce without fallback
            result_var = self.welford_reduce_fallback(dtype, value)
        else:
            raise Unsupported(reduction_type)

        self.cse.reduction_cache[cache_key] = result_var
        return result_var

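    # Illustrative (hypothetical names): a sum over value.used_dims
    # ["xindex", "rindex"] emits roughly
    #   tmp2[xindex] = hl.sum(rdom, tmp1[rdom, xindex])
    # where rdom is the hl.RDom over the reduction extent that
    # codegen_kernel() sets up.
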
    def genfunc(
        self, line, used_dims, *, bounds=ValueRanges.unknown()
    ) -> HalideCSEVariable:
        var = self.cse.generate(self.body, line, bounds=bounds)
        assert isinstance(var, HalideCSEVariable)
        var.used_dims = used_dims
        return var

    def newfunc(self, used_dims) -> HalideCSEVariable:
        var = self.cse.newvar()
        assert isinstance(var, HalideCSEVariable)
        var.used_dims = used_dims
        return var

    def halide_buffer_numel(self, name: str):
        """
        We map all tensors to 1D buffers in Halide since Halide has trouble representing
        some strides that PyTorch supports.  If there are gaps in the underlying layout,
        the numel we pass to Halide includes the gaps while PyTorch's numel excludes them.
        """
        return V.graph.get_buffer(name).get_layout().storage_size()

    def halide_argdefs(self):
        """
        Halide requires scalar inputs before outputs, so we need to reorder the args.
        """

        def arg_order(arg_tuple):
            call_str, arg = arg_tuple
            if isinstance(arg, SizeArg):
                return 1  # this would normally be at the end, move it to the middle
            elif "out_ptr" in arg.name:
                return 2
            else:
                assert "in_ptr" in arg.name
                return 0

        _, a, b, _ = self.args.python_argdefs()
        return sorted(zip(a, b), key=arg_order)

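    # Resulting order (illustrative, hypothetical names): in_ptr0, in_ptr1
    # (key 0), then size scalars such as ks0 (key 1), then out_ptr0 (key 2),
    # matching Halide's requirement that scalar params precede outputs.
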
    def halide_kernel_meta(self) -> HalideMeta:
        """Compute metadata required by codecache.py"""
        argtypes = []
        for _, arg in self.halide_argdefs():
            if isinstance(arg, SizeArg):
                shape = None
                dtype = "long"
            else:
                if arg.name in self.store_buffer_dimensions and "out" in arg.name:
                    shape = [
                        cexpr(self.rename_indexing(x))
                        for x in self.store_buffer_dimensions[arg.name]
                    ]
                    assert shape
                else:
                    shape = [
                        cexpr(
                            self.rename_indexing(self.halide_buffer_numel(arg.buffer))
                        )
                    ] or ["1"]
                dtype = f"{DTYPE_TO_CPP[arg.dtype]}*"
            argtypes.append(
                HalideInputSpec(
                    dtype,
                    arg.name,
                    shape,
                )
            )
        target = ["host", "strict_float"]
        # TODO(jansel): for cuda we want target="host-cuda-cuda_capability_86-user_context"
        if not config.halide.asserts:
            target.append("no_asserts")
        if "64" in self.index_dtype:
            # TODO(jansel): it is unclear if this does anything, since input sizes are still int32
            target.append("large_buffers")

        return HalideMeta(
            argtypes,
            target="-".join(target),
            scheduler="Mullapudi2016",
            scheduler_flags={
                "parallelism": parallel_num_threads(),
            },
        )

    def codegen_kernel(self, name=None):
        """Called at the end to generate a final kernel string"""
        if self.args.inplace_buffers:
            raise Unsupported("inplace_buffers")
        meta = self.halide_kernel_meta()  # ensure needed args are added early
        code = IndentedBuffer()
        code.splice(
            """
            import halide as hl

            @hl.generator(name="kernel")
            class Kernel:
            """,
            strip=True,
        )
        code.do_indent()
        for _, arg in self.halide_argdefs():
            if isinstance(arg, SizeArg):
                code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})")
            else:
                assert arg.buffer, arg
                argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer"
                argtype = halide_type(arg.dtype)
                ndim = len(self.store_buffer_dimensions.get(arg.name, (0,)))
                code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})")
        code.splice(
            """
            def generate(g):
            """
        )
        code.do_indent()
        for _, arg in self.halide_argdefs():
            code.writeline(f"{arg.name} = g.{arg.name}")
        for old, new in self.args.aliases():
            code.writeline(f"{old} = {new}")

        dom_size = {}
        for tree in self.active_range_trees(reorder=True):
            code.writeline(f"{tree.name} = hl.Var({tree.name!r})")
            length = self.kexpr(self.rename_indexing(tree.numel))
            dom_size[tree.name] = f"hl.Range(0, {length})"
        assert len(dom_size) <= 3
        code.splice(self.indexing_code)

        if self.inside_reduction:
            sizes = [*dom_size.values()]
            code.writeline(f"idom = hl.RDom([{', '.join(sizes)}])")
            code.writeline(f"odom = hl.RDom([{', '.join(sizes[:-1])}])")
            code.writeline(f"rdom = hl.RDom([{sizes[-1]}])")
            for name, xyz in zip(dom_size.keys(), "xyz"):
                code.writeline(f"{name}_idom = idom.{xyz}")
                if name[0] != "r":
                    code.writeline(f"{name}_odom = odom.{xyz}")
        elif self.needs_dom_indexing:
            code.writeline(f"odom = hl.RDom([{', '.join(dom_size.values())}])")
            for name, xyz in zip(dom_size.keys(), "xyz"):
                code.writeline(f"{name}_odom = odom.{xyz}")

        if self.needs_dom_indexing:
            code.splice(self.indexing_code_dom)

        def update_index(m):
            var = self.cse.varname_map[m.group(1)]
            assert var.used_dims is not None, var
            if var.used_dims:
                return str(var)
            else:
                return var.name  # a constant doesn't need to be wrapped in func

        for line in self.body._lines:
            if isinstance(line, str):
                # fill in missing indices
                line = HalideCSEVariable.undefined_re.sub(update_index, line)
            code.writeline(line)
        code.writeline("")
        code.writeline("assert g.using_autoscheduler()")

        for _, arg in self.halide_argdefs():
            # fallback=1 below because halide requires buffers to be at least as large as the estimates
            # This causes crashes if our estimate is greater than the vector length
            # https://github.com/halide/Halide/issues/3103
            if isinstance(arg, SizeArg):
                hint = V.graph.sizevars.size_hint(arg.expr, fallback=1)
                code.writeline(f"{arg.name}.set_estimate({hint})")
            else:
                if arg.name in self.store_buffer_dimensions and "out" in arg.name:
                    hints = V.graph.sizevars.size_hints(
                        self.store_buffer_dimensions[arg.name], fallback=1
                    )
                else:
                    hints = V.graph.sizevars.size_hints(
                        [V.graph.get_numel(arg.buffer)], fallback=1
                    )
                range_hints = [f"hl.Range(0, {hint})" for hint in hints]
                code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])")

        code.do_unindent(2)
        code.splice(
            f"""
            if __name__ == "__main__":
                hl.main()
            else:
                hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
                target = hl.Target({meta.target!r})
                autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
                with hl.GeneratorContext(target, autoscheduler):
                    gen = Kernel()
                    pipeline = gen._build_pipeline()
                    # gen.compile_to_callable() does not run the autoscheduler
                    pipeline.apply_autoscheduler(target, autoscheduler)
                    kernel = pipeline.compile_to_callable([
                        gen._get_input_parameter(a.name)._to_argument()
                        for a in gen._get_arginfos()
                        if a.dir == hl.ArgInfoDirection.Input
                    ], target)
            """
        )
        return code.getvalue()

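    # Shape of the generated module (illustrative, hypothetical arg names):
    #
    #   import halide as hl
    #
    #   @hl.generator(name="kernel")
    #   class Kernel:
    #       in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
    #       out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)
    #       def generate(g):
    #           ...
    #
    # followed by the __main__/load_plugin footer spliced above, which either
    # runs hl.main() or autoschedules and compiles the pipeline to a callable.
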
    def call_kernel(self, name: str, node=None):
        """Codegen a call to this kernel"""
        wrapper = V.graph.wrapper_code
        call_args = [f"{n}" for n, _ in self.halide_argdefs()]
        assert V.graph.scheduler.current_device is not None
        current_device = V.graph.scheduler.current_device
        assert current_device.type == "cpu", "TODO"
        wrapper.generate_kernel_call(
            name,
            call_args,
            cuda=False,  # grid/stream is handled internally in halide
        )

    def generate_assert(self, check):
        return False  # TODO(jansel): support asserts

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        pass  # TODO(jansel): support asserts


class HalideScheduling(SIMDScheduling):
    int32_type = "hl.Int(32)"
    # TODO(jansel): Halide doesn't actually support 64 bit indexing...
    int64_type = "hl.Int(64)"
    kernel_type = HalideKernel

    @classmethod
    def get_backend_features(cls, device: torch.device):
        result = dict.fromkeys(
            [
                BackendFeature.TUPLE_REDUCTION,
            ]
        )
        return result

    def define_kernel(self, src_code, node_schedule, kernel):
        """Codegen kernel definition to go in output wrapper code"""
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}"
            wrapper.src_to_kernel[src_code] = kernel_name
            wrapper.add_import_once(
                "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec"
            )

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline(
                f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''"
            )
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''')")

            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment = f"{origins}\n{detailed_origins}"
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
            if is_metric_table_enabled("kernel_metadata"):
                log_kernel_metadata(kernel_name, "", src_code)

        return kernel_name