import contextlib import dataclasses import functools import math import sys from copy import deepcopy from pathlib import Path from typing import Dict, List import sympy import torch from torch._prims_common import is_float_dtype from .. import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen from ..utils import sympy_product, sympy_subs, sympy_symbol from ..virtualized import ops, V from .common import ( BracesBuffer, CppWrapperKernelArgs, DeferredIndentedBuffer, ExprPrinter, IndentedBuffer, Kernel, KernelArgs, OpOverrides, ) DTYPE_TO_CPP = { torch.float32: "float", torch.float64: "double", torch.float16: "half", torch.int64: "long", torch.int32: "int", torch.int16: "short", torch.int8: "signed char", torch.uint8: "unsigned char", torch.bool: "bool", torch.bfloat16: "bfloat16", } DTYPE_TO_ATEN = { torch.float32: "at::ScalarType::Float", torch.float64: "at::ScalarType::Double", torch.float16: "at::ScalarType::Half", torch.int64: "at::ScalarType::Long", torch.int32: "at::ScalarType::Int", torch.int16: "at::ScalarType::Short", torch.int8: "at::ScalarType::Char", torch.uint8: "at::ScalarType::Byte", torch.bool: "at::ScalarType::Bool", torch.bfloat16: "at::ScalarType::BFloat16", } INDEX_TYPE = "long" RTYPE_TO_CPP = { "sum": "+", "min": "min", "max": "max", "argmin": "argmin", "argmax": "argmax", "any": "||", } def reduction_init(reduction_type, dtype): if reduction_type in ("sum", "any"): return 0 if reduction_type in {"max", "argmax"}: return ( f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" if is_float_dtype(dtype) else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::min()" ) if reduction_type in {"min", "argmin"}: return ( f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" if is_float_dtype(dtype) else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()" ) raise AssertionError(reduction_type) def reduction_combine(reduction_type, var, next_value): if reduction_type == "sum": return f"{var} += {next_value}" if reduction_type == "any": return f"{var} = {var} || {next_value}" return f"{var} = std::{reduction_type}({var}, {next_value})" def reduction_combine_vec(reduction_type, var, next_value): if reduction_type == "max": return f"{var} = at::vec::maximum({var}, {next_value})" elif reduction_type == "min": return f"{var} = at::vec::minimum({var}, {next_value})" elif reduction_type == "sum": return f"{var} += {next_value}" else: raise NotImplementedError() index_value_name_counter = 1 def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar): global index_value_name_counter struct_name = f"IndexValue_{index_value_name_counter}" index_value_name_counter += 1 # A small annoyance, due to it being a little cumbersome to just throw {} into strings prefix = [ f"struct {struct_name} {{size_t index; {DTYPE_TO_CPP[src_dtype]} value;}};", f"{struct_name} {tmpvar}{{0, {reduction_init(reduction_type, src_dtype)}}};", ] if reduction_type == "argmax": prefix.extend( [ f"#pragma omp declare reduction(argmax : struct {struct_name} :\\", " omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\\", " omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\\", f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})", ] ) elif reduction_type == "argmin": prefix.extend( [ f"#pragma omp declare reduction(argmin : struct {struct_name} :\\", " omp_out.value = omp_in.value > omp_out.value ? omp_out.value : omp_in.value,\\", " omp_out.index = omp_in.value > omp_out.value ? omp_out.index : omp_in.index)\\", f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})", ] ) return prefix def float16_reduction_prefix(rtype): # TODO: This user-defined reduction uses float16 accumulation for sum. To reduce numerical # errors, float32 accumulation should be used instead. assert rtype in ( "sum", "any", ), f"float16 user-defined reduction only supports 'sum' and 'any' but got {rtype}" prefix = [ f"#pragma omp declare reduction({RTYPE_TO_CPP[rtype]}:{DTYPE_TO_CPP[torch.float16]}:" + f"omp_out = omp_out {RTYPE_TO_CPP[rtype]} omp_in)" ] return prefix def parallel_num_threads(): threads = config.cpp.threads if threads < 1: threads = torch.get_num_threads() return threads @functools.lru_cache() def cpp_prefix(): path = Path(__file__).parent / "cpp_prefix.h" with path.open() as f: _, filename = codecache.write( f.read(), "h", ) return f'#include "{filename}"' class CppPrinter(ExprPrinter): def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) mod = self.paren(self.doprint(mod)) if div != "1": x = f"({x} / {div})" return f"{x} % {mod}" def _print_IndexingDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} / {div})" cexpr = CppPrinter().doprint class CppVecOverrides(OpOverrides): """Map element-wise ops to aten vectorization C++""" @staticmethod def add(a, b): return f"{a} + {b}" @staticmethod def sub(a, b): return f"{a} - {b}" @staticmethod def mul(a, b): return f"{a} * {b}" @staticmethod def div(a, b): return f"{a} / {b}" @staticmethod def abs(x): return f"{x}.abs()" @staticmethod def sin(x): return f"{x}.sin()" @staticmethod def cos(x): return f"{x}.cos()" @staticmethod def exp(x): return f"{x}.exp()" @staticmethod def erf(x): return f"{x}.erf()" @staticmethod def sqrt(x): return f"{x}.sqrt()" @staticmethod def rsqrt(x): return f"{x}.rsqrt()" @staticmethod def pow(a, b): return f"{a}.pow({b})" @staticmethod def log(x): return f"{x}.log()" @staticmethod def round(x): return f"{x}.round()" @staticmethod def floor(x): return f"{x}.floor()" @staticmethod def ceil(x): return f"{x}.ceil()" @staticmethod def trunc(x): return f"{x}.trunc()" @staticmethod def fmod(a, b): return f"{a}.fmod({b})" @staticmethod def lgamma(x): return f"{x}.lgamma()" @staticmethod def logical_and(a, b): return f"{a} && {b}" @staticmethod def logical_or(a, b): return f"{a} || {b}" @staticmethod def tanh(a): return f"{a}.tanh()" @staticmethod def reciprocal(a): return f"{a}.reciprocal()" @staticmethod def constant(val, dtype): if val == float("inf"): quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif val == float("-inf"): quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif math.isnan(val): quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" elif val is True or val is False: quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})" else: quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})" return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})" @staticmethod def relu(x): return f"at::vec::clamp_min({x}, decltype({x})(0))" @staticmethod def sigmoid(x): return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" @staticmethod def neg(x): return f"{x}.neg()" @staticmethod def floordiv(a, b): # a and b are integer type _t = f"decltype({a})" quot = f"{a} / {b}" rem = f"{a} % {b}" return f"(({a} < {_t}(0)) != ({b} < {_t}(0)) ? ({rem} != {_t}(0) ? {quot} - {_t}(1) : {quot}) : {quot})" @staticmethod def truncdiv(a, b): # a and b are integer type return f"{a} / {b}" @staticmethod def minimum(a, b): return f"at::vec::minimum({a}, {b})" @staticmethod def maximum(a, b): return f"at::vec::maximum({a}, {b})" @staticmethod def square(a): return f"{a}.pow(2)" @staticmethod def where(a, b, c): return f"decltype({b})::blendv({c}, {b}, {a})" @staticmethod def sign(x): code = BracesBuffer() # auto tmp5 = tmp4 < 0 ? -1 : 1; vec_zero = f"decltype({x})(0)" vec_one = f"decltype({x})(1)" blendv = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" left = V.kernel.cse.newvar() code.writeline(f"auto {left} = {blendv};") # auto tmp6 = tmp4 == 0 ? 0 : tmp5; blendv = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" right = V.kernel.cse.newvar() code.writeline(f"auto {right} = {blendv};") result = V.kernel.cse.newvar() code.writeline(f"auto {result} = {left} - {right};") V.kernel.compute.splice(code) return result @staticmethod def to_dtype(x, dtype): assert dtype in [torch.bool], f"{__name__} does not support {dtype}" return f"({x})" @staticmethod def expm1(x): return f"{x}.expm1()" @staticmethod def log1p(x): return f"{x}.log1p()" class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @staticmethod def to_dtype(x, dtype): assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" return f"static_cast<{DTYPE_TO_CPP[dtype]}>({x})" @staticmethod def abs(x): return f"std::abs({x})" @staticmethod def sin(x): return f"std::sin({x})" @staticmethod def cos(x): return f"std::cos({x})" @staticmethod def exp(x): # return f"Sleef_expf_u10({x})" return f"std::exp({x})" @staticmethod def erf(x): return f"std::erf({x})" @staticmethod def sqrt(x): return f"std::sqrt({x})" @staticmethod def rsqrt(x): return f"1 / std::sqrt({x})" @staticmethod def log1p(x): return f"std::log1p({x})" @staticmethod def expm1(x): return f"std::expm1({x})" @staticmethod def tanh(x): return f"std::tanh({x})" @staticmethod def signbit(x): return f"std::signbit({x})" @staticmethod def pow(a, b): return f"std::pow({a}, {b})" @staticmethod def log(x): return f"std::log({x})" @staticmethod def round(x): return f"std::nearbyint({x})" @staticmethod def floor(x): return f"std::floor({x})" @staticmethod def floordiv(a, b): # a and b are integer type quot = f"{a} / {b}" rem = f"{a} % {b}" return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})" @staticmethod def ceil(x): return f"std::ceil({x})" @staticmethod def trunc(x): return f"std::trunc({x})" @staticmethod def truncdiv(a, b): # a and b are integer type return f"{a} / {b}" @staticmethod def fmod(a, b): return f"std::fmod({a}, {b})" @staticmethod def isinf(x): return f"std::isinf({x})" @staticmethod def isnan(x): return f"std::isnan({x})" @staticmethod def lgamma(x): return f"std::lgamma({x})" @staticmethod def relu(x): return f"{x} * ({x}>0)" @staticmethod def minimum(a, b): return f"({b} != {b}) ? {b} : std::min({a}, {b})" @staticmethod def maximum(a, b): return f"({b} != {b}) ? {b} : std::max({a}, {b})" @staticmethod def where(a, b, c): return f"{a} ? {b} : {c}" @staticmethod def mod(a, b): return f"mod({a}, {b})" @staticmethod def constant(val, dtype): if val == float("inf"): return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif val == float("-inf"): return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif math.isnan(val): return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" elif val is True or val is False: return ops.to_dtype(str(val).lower(), dtype) return ops.to_dtype(repr(val), dtype) @staticmethod def index_expr(expr, dtype): return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype) @staticmethod def masked(mask, body, other): code = BracesBuffer() var = V.kernel.cse.newvar() if other == float("-inf"): code.writeline(f"float {var} = -std::numeric_limits::infinity();") elif other == float("inf"): code.writeline(f"float {var} = std::numeric_limits::infinity();") elif isinstance(other, float): code.writeline(f"float {var} = {other};") else: code.writeline(f"auto {var} = {other!r};") code.writeline(f"if({mask})") with V.kernel.swap_buffers(code), code.indent(): result = body() code.writeline(f"{var} = {result};") V.kernel.compute.splice(code) return var @staticmethod def logical_and(a, b): return f"{a} && {b}" @staticmethod def logical_or(a, b): return f"{a} || {b}" @staticmethod def rand(seed: sympy.Expr, offset: sympy.Expr, dtype): return f"static_cast<{DTYPE_TO_CPP[dtype]}>(normalized_rand_cpu({seed}, {offset}));" @staticmethod def randn(seed: sympy.Expr, offset: sympy.Expr, dtype): return f"static_cast<{DTYPE_TO_CPP[dtype]}>(randn_cpu({seed}, {offset}));" @staticmethod def sigmoid(x): x = ops.exp(f"-{x}") return f"1 / (1 + {x})" @staticmethod def sign(x): code = BracesBuffer() # auto tmp5 = tmp4 < 0 ? -1 : 1; left = V.kernel.cse.newvar() right = V.kernel.cse.newvar() result = V.kernel.cse.newvar() code.writeline(f"auto {left} = {x} > 0 ? 1 : 0;") code.writeline(f"auto {right} = {x} < 0 ? 1 : 0;") code.writeline(f"auto {result} = {left} - {right};") V.kernel.compute.splice(code) return result class CppKernel(Kernel): overrides = CppOverrides sexpr = cexpr newvar_prefix = "auto " suffix = ";" def __init__(self, args, num_threads): super(CppKernel, self).__init__(args) self.call_ranges = None self.ranges = None self.itervars = None self.reduction_depth = None self.reduction_prefix = IndentedBuffer() self.reduction_suffix = DeferredIndentedBuffer() self.reduction_vars = {} self.num_threads = num_threads # num_threads the kernel specialized for def load(self, name: str, index: sympy.Expr): var = self.args.input(name) index = self.rename_indexing(index) line = f"{var}[{cexpr(index)}]" if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16): line = f"static_cast({line})" return self.cse.generate(self.loads, line) def store(self, name, index, value, mode=None): assert "buf" in name var = self.args.output(name) index = self.rename_indexing(index) if mode is None: line = f"{var}[{cexpr(index)}] = {value};" elif mode == "atomic_add": if not config.cpp.dynamic_threads and self.num_threads == 1: line = f"{var}[{cexpr(index)}] += {value};" else: line = f"atomic_add(&{var}[{cexpr(index)}], {value});" else: raise NotImplementedError(f"store mode={mode}") self.stores.writeline(name, line) def reduction(self, name, dtype, src_dtype, reduction_type, index, value): argmax_or_argmin = reduction_type in {"argmax", "argmin"} tmpvar = self.cse.generate( self.loads, f"reduction {name} {cexpr(index)}", write=False ) index = self.rename_indexing(index) self.reduction_vars[tmpvar] = reduction_type if argmax_or_argmin: self.reduction_prefix.writelines( argmax_argmin_prefix(reduction_type, src_dtype, tmpvar) ) compare_op = "<" if reduction_type == "argmax" else ">" self.stores.writelines( None, [ f"if ({tmpvar}.value {compare_op} {value}) {{", f" {tmpvar}.index = {self.itervars[-1]}; {tmpvar}.value = {value};", "}", ], ) else: if dtype == torch.float16: self.reduction_prefix.writelines( float16_reduction_prefix(reduction_type) ) self.reduction_prefix.writeline( f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};" ) self.stores.writeline( None, f"{reduction_combine(reduction_type, tmpvar, value)};" ) if name not in V.graph.removed_buffers: var = self.args.output(name) member_name = ".index" if argmax_or_argmin else "" self.reduction_suffix.writeline( name, f"{var}[{cexpr(index)}] = {tmpvar}{member_name};" ) self.cse.store_cache[name] = tmpvar def set_ranges(self, lengths, reduction_lengths): if self.call_ranges: assert self.call_ranges == tuple(lengths) + tuple( reduction_lengths ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" assert self.reduction_depth == len(lengths) else: self.call_ranges = tuple(lengths) + tuple(reduction_lengths) self.ranges = [self.rename_indexing(x) for x in self.call_ranges] self.itervars = [sympy_symbol(f"i{n}") for n in range(len(self.ranges))] self.reduction_depth = len(lengths) return ( self.itervars[: self.reduction_depth], self.itervars[self.reduction_depth :], ) def size_hint(self): return V.graph.sizevars.size_hint(sympy_product(self.call_ranges)) def codegen_loops(self, code, worksharing): threads = parallel_num_threads() loops = [LoopLevel(var, size) for var, size in zip(self.itervars, self.ranges)] loops, reductions = LoopNest(loops[: self.reduction_depth]), LoopNest( loops[self.reduction_depth :] ) reductions.mark_reduction(self.reduction_vars) if codecache.pick_vec_isa(): # TODO(jansel): detect stride-1 dimension and vectorize that if reductions: reductions.loops[-1].simd = True elif loops: loops.loops[-1].simd = True par_depth = 0 reduction_par_depth = 0 if loops: par_depth = self.decide_parallel_depth( self.call_ranges[: self.reduction_depth], threads ) else: reduction_par_depth = self.decide_parallel_depth( self.call_ranges[self.reduction_depth :], threads ) with contextlib.ExitStack() as stack: if par_depth: worksharing.parallel(threads) loops.mark_parallel(par_depth) elif reduction_par_depth: # need to close the worksharing scope to define reduction vars outside it worksharing.close() reductions.mark_parallel(reduction_par_depth) elif threads > 1: if worksharing.single(): stack.enter_context(code.indent()) loops.codegen(code, stack) with contextlib.ExitStack() as stack_outer: if self.reduction_prefix: stack_outer.enter_context(code.indent()) code.splice(self.reduction_prefix) if reduction_par_depth: worksharing.parallel(threads) with contextlib.ExitStack() as stack: reductions.codegen(code, stack) code.splice(self.loads) code.splice(self.compute) code.splice(self.stores) if reduction_par_depth: worksharing.close() code.splice(self.reduction_suffix) def decide_parallel_depth(self, ranges, threads): seq = self.size_hint() par = 1 depth = 0 for expr in ranges: hint = V.graph.sizevars.size_hint(expr) if par >= 2 * threads or par == threads: break if seq // threads < config.cpp.min_chunk_size: # not enough work break depth += 1 par *= hint seq /= hint # if we assume thread number is dynamic, make sure we # have at least one parallel scope and let OMP runtime # to manage the serial vs. parallel. if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: depth = 1 return depth @contextlib.contextmanager def write_to_suffix(self): prior = (self.loads, self.compute, self.stores, self.cse) self.loads = IndentedBuffer() self.compute = IndentedBuffer() self.stores = DeferredIndentedBuffer() self.cse = self.cse.clone() yield self.reduction_suffix.splice(self.loads) self.reduction_suffix.splice(self.compute) self.reduction_suffix.splice(self.stores) (self.loads, self.compute, self.stores, self.cse) = prior class CppVecKernel(CppKernel): overrides = CppVecOverrides def __init__(self, args, num_threads): super(CppVecKernel, self).__init__(args, num_threads) assert codecache.pick_vec_isa() self.simd_nelements = codecache.pick_vec_isa().nelements() self.reduction_omp_dec: Dict[str, str] = {} self.var_vec_buf_map: Dict[str, str] = {} metrics.generated_cpp_vec_kernel_count += 1 def is_single_step_var(self, var: sympy.Symbol, index: sympy.Expr): replacement = {var: var + 1} new_index = sympy_subs(index, replacement) delta = sympy.simplify(new_index - index) return delta == 1 def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr): expanded_index = sympy.expand(index) return not expanded_index.has(var) def transform_index(self, index: sympy.Expr): expanded_index = sympy.expand(index) assert self.simd_nelements assert self.simd_nelements >= 1 most_inner_var = self.itervars[-1] replacement = {most_inner_var: most_inner_var * self.simd_nelements} new_index = sympy_subs(expanded_index, replacement) return new_index def load(self, name: str, index: sympy.Expr): var = self.args.input(name) index = self.rename_indexing(index) expanded_index = sympy.expand(index) new_index = self.transform_index(index) if expanded_index == new_index: line = f"at::vec::Vectorized({var}[{cexpr(index)}])" else: if V.graph.get_dtype(name) in [torch.bool, torch.uint8]: nelements = codecache.pick_vec_isa().nelements() if var not in self.var_vec_buf_map: self.var_vec_buf_map[var] = f"g_tmp_buffer_{var}" self.loads.writeline( f"float {self.var_vec_buf_map[var]}[{nelements}] = {{0}};" ) self.loads.writeline( f"flag_to_float({var} + {cexpr(new_index)}, {self.var_vec_buf_map[var]}, {nelements});" ) line = f"at::vec::Vectorized::loadu({self.var_vec_buf_map[var]})" else: line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" return self.cse.generate(self.loads, line) def store(self, name, index, value, mode=None): assert "buf" in name var = self.args.output(name) index = self.rename_indexing(index) assert mode is None expanded_index = sympy.expand(index) new_index = self.transform_index(index) assert new_index != expanded_index line = f"{value}.store({var} + {cexpr(new_index)});" self.stores.writeline(name, line) def reduction(self, name, dtype, src_dtype, reduction_type, index, value): assert reduction_type in {"max", "min", "sum"} assert dtype == torch.float assert src_dtype == torch.float reduce_map = {"max": "maximum", "min": "minimum"} vec_ns = "at::vec" vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" if reduction_type not in self.reduction_omp_dec: vec_reduc_prefix = "#pragma omp declare reduction(" vec_reduc_prefix += f"{RTYPE_TO_CPP[reduction_type]}:{vec}:" if reduction_type == "sum": vec_reduc_prefix += "omp_out += omp_in" else: vec_reduc_prefix += ( f"omp_out = {vec_ns}::{reduce_map[reduction_type]}(omp_out, omp_in)" ) vec_reduc_prefix += ")" vec_reduc_prefix += " initializer(" vec_reduc_prefix += "omp_priv={{" vec_reduc_prefix += f"{reduction_init(reduction_type, dtype)}" vec_reduc_prefix += "}})" self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type] self.reduction_prefix.writeline(vec_reduc_prefix) tmpvar = self.cse.generate( self.loads, f"reduction {name} {cexpr(index)}", write=False ) tmpvar_vec = f"{tmpvar}_vec" index = self.rename_indexing(index) self.reduction_vars[tmpvar] = reduction_type self.reduction_prefix.writeline( f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};" ) self.reduction_prefix.writeline( f"auto {tmpvar_vec} = at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({tmpvar});" ) self.stores.writeline( None, f"{reduction_combine_vec(reduction_type, tmpvar_vec, value)};" ) reduce_all_body = "{" if reduction_type == "sum": reduce_all_body += "return x + y;" else: reduce_all_body += f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);" reduce_all_body += "}" vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>" self.reduction_suffix.writeline( name, f"{tmpvar} = {vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec});", ) self.cse.store_cache[name] = tmpvar class CppVecKernelChecker(CppVecKernel): def __init__(self, args, num_threads): super(CppVecKernelChecker, self).__init__(args, num_threads) # Since this kernel is only for checker but does not genreate any # code, so we need to decrease the kernel count. metrics.generated_kernel_count -= 1 metrics.generated_cpp_vec_kernel_count -= 1 # Used to recorde the graph wrapper code as the wrapper_code status could be # changed during graph run. self._orig_wrapper_code = None self.simd_vec = True self.fast_vec_list = [] for k, v in CppVecOverrides.__dict__.items(): if isinstance(v, staticmethod): self.fast_vec_list.append(k) self.exit_stack = contextlib.ExitStack() def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr): return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index) def could_vec(self, name: str, index: sympy.Expr): assert self.itervars is not None # Not a loop if len(self.itervars) == 0: return False most_inner_var = self.itervars[-1] return self.is_legal_data_access(most_inner_var, index) def load(self, name: str, index: sympy.Expr): if not V.graph.get_dtype(name) in [ torch.float, torch.float32, torch.bool, torch.uint8, ]: self.simd_vec = False return self.simd_vec index = self.rename_indexing(index) self.simd_vec = self.simd_vec and self.could_vec(name, index) return self.simd_vec def store(self, name, index, value, mode=None): if not V.graph.get_dtype(name) in [torch.float, torch.float32]: self.simd_vec = False return self.simd_vec assert "buf" in name index = self.rename_indexing(index) if mode: self.simd_vec = False return False self.simd_vec = self.simd_vec and self.could_vec(name, index) return self.simd_vec def reduction(self, name, dtype, src_dtype, reduction_type, index, value): if ( dtype == torch.float and src_dtype == torch.float and reduction_type in ["max", "min", "sum"] ): pass else: self.simd_vec = False return self.simd_vec def __exit__(self, exc_type, exc_val, exc_tb): assert self._orig_wrapper_code is not None # Restore the wrapper_code V.graph.wrapper_code = self._orig_wrapper_code self.exit_stack.__exit__(exc_type, exc_val, exc_tb) def __enter__(self): # Recorde the graph wrapper code. The wrapper_code status could be # changed during graph run. Regarding this checker, we also need to # run the graph but we don't expect to change any status that would # impact the code generation. Hence, we record the graph wapper code # and replace it with a dummy warpper_code and then restore to the # original one as long as the checker is finished. self._orig_wrapper_code = V.graph.wrapper_code V.graph.wrapper_code = WrapperCodeGen() class VecCheckerProxy: @staticmethod def __getattr__(name): def inner(*args, **kwargs): if not (name in self.fast_vec_list): self.simd_vec = False return self.simd_vec return inner @staticmethod def load(name: str, index: sympy.Expr): return self.load(name, index) @staticmethod def store(name, index, value, mode=None): return self.store(name, index, value, mode=mode) @staticmethod def reduction(name, dtype, src_dtype, reduction_type, index, value): return self.reduction( name, dtype, src_dtype, reduction_type, index, value ) @staticmethod def constant(val, dtype): supported_dtype = (torch.float32, torch.int32) is_supported_dtype = dtype in (supported_dtype) if not is_supported_dtype: self.simd_vec = False return is_supported_dtype @staticmethod def index_expr(expr, dtype): self.simd_vec = False tmp_var = self.cse.newvar() return tmp_var @staticmethod def indirect_indexing(index_var): self.simd_vec = False return sympy.Symbol(str(index_var)) @staticmethod def masked(mask, body, other): tmp_var = self.cse.newvar() return tmp_var @staticmethod def to_dtype(x, dtype): if dtype != torch.bool: self.simd_vec = False return x self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) return self class CppKernelProxy(CppKernel): def __init__(self, args=None, num_threads=None): super(CppKernelProxy, self).__init__(args, num_threads) self.simd_vec_kernel: CppVecKernel = None self.simd_omp_kernel: CppKernel = None self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float): assert self.picked_vec_isa nelements = self.picked_vec_isa.nelements(dtype) loop_nest.split_most_inner_loop(nelements) loop_with_tail = loop_nest.loops[-1] assert isinstance(loop_with_tail, LoopLevelWithTail) loop_with_tail.main_loop.simd_vec = True loop_with_tail.tail_loop.simd_omp = True # We chope the loop into two cubes by the nelements - main loop and tail loop. # Regarding the main loop, it is straightforward that it could be vectorized with # nelements. But for the tail loop, it still could be vectorized. For example, # if the nelements is 8(256bits), then the tail loop still could be vectorized # as 4(128bits). loop_with_tail.tail_loop.simd_nelements = int(nelements / 2) loop_with_tail.tail_loop.simd_vec = False loop_with_tail.main_loop_body = self.simd_vec_kernel loop_with_tail.tail_loop_body = self.simd_omp_kernel return loop_nest def codegen_loops(self, code, worksharing): threads = parallel_num_threads() if self.simd_vec_kernel is None or not self.picked_vec_isa: assert self.simd_omp_kernel return self.simd_omp_kernel.codegen_loops(code, worksharing) assert self.simd_vec_kernel.itervars == self.simd_omp_kernel.itervars assert self.simd_vec_kernel.ranges == self.simd_omp_kernel.ranges assert ( self.simd_vec_kernel.reduction_vars == self.simd_omp_kernel.reduction_vars ) itervars = self.simd_vec_kernel.itervars rangs = self.simd_vec_kernel.ranges loops = [LoopLevel(var, size) for var, size in zip(itervars, rangs)] assert ( self.simd_vec_kernel.reduction_depth == self.simd_omp_kernel.reduction_depth ) reduction_depth = self.simd_vec_kernel.reduction_depth loops_nest_non_reduce, loops_nest_reduce = LoopNest( loops[:reduction_depth] ), LoopNest(loops[reduction_depth:]) loops_nest_reduce.mark_reduction(self.simd_vec_kernel.reduction_vars) assert self.picked_vec_isa # Do not apply vectorization since the range of most inner is too small. Meanwhile, # If the range of the most inner is less then the codecache.pick_vec_isa().nelements(), # the generated code for some reduction will be as follows that leads to incrrect result. # # LINE01: float tmp1 = 0; # LINE02: auto tmp1_vec = at::vec::Vectorized(tmp1); # LINE03: for(long i1=0; i1<2; i1+=1) # LINE04: { # LINE05: for(long i2=0; i2<0; i2+=1) # LINE06: { # LINE07: auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i0) + (16*i2) + (32*i1)); # LINE08: tmp1_vec += tmp0; # LINE09: } # LINE10: tmp1 = vec_reduce_all([](Vectorized& x, Vectorized&y) {return x + y;}, tmp1_vec); # LINE11: #pragma omp simd simdlen(8) reduction(+:tmp1) # LINE12: for(long i2=0; i2<8; i2+=1) # LINE13: { # LINE14: auto tmp0 = in_ptr0[i2 + (8*i0) + (32*i1)]; # LINE15: tmp1 += tmp0; # LINE16: } # LINE17: } # LINE18: out_ptr3[i0] = tmp1; # # tmp1_vec(LINE02) will always be zero as it is initialized with tmp1 value and the range(LINE05) # is 0. Hence, the LINE10 will always reset tmp1 to 0. But tmp1(LINE01) is global value. So the result # will be incorrect. We skip thie case. most_inner_loop = ( loops_nest_reduce.loops[-1] if loops_nest_reduce else loops_nest_non_reduce.loops[-1] ) main_loop_range = ir.IndexingDiv( most_inner_loop.size, self.picked_vec_isa.nelements() ) loop_interval = sympy.simplify(main_loop_range) # TODO(Eikan): To support dynamic shape. if not loop_interval.is_integer or loop_interval <= 0: metrics.generated_cpp_vec_kernel_count -= 1 return self.simd_omp_kernel.codegen_loops(code, worksharing) # TODO(jansel): detect stride-1 dimension and vectorize that if loops_nest_reduce: loops_nest_reduce.loops[-1].simd = True elif loops_nest_non_reduce: loops_nest_non_reduce.loops[-1].simd = True par_depth = 0 reduction_par_depth = 0 if loops_nest_non_reduce: par_depth = self.simd_vec_kernel.decide_parallel_depth( self.simd_vec_kernel.call_ranges[:reduction_depth], threads ) else: reduction_par_depth = self.simd_vec_kernel.decide_parallel_depth( self.simd_vec_kernel.call_ranges[reduction_depth:], threads ) # If the most inner loop of the reduction will be vectorized, the vectorization # will add a vec variable for reduction. Take the code snippet as an example: # float tmp1 = 0; # for(long i1=0; i1<8; i1+=1) { # auto tmp0 = in_ptr0[i1]; # tmp1 += tmp0; # } # The vectorization will add tmp1_vec for reduction and then the loop will be transformed # as follows. # float tmp1 = 0; # auto tmp1_vec = at::vec::Vectorized(tmp1); # for(long i1=0; i1<1; i1+=1) { # auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i1)); # tmp1_vec += tmp0; # } # tmp1 = at::vec::vec_reduce_all([] # (at::vec::Vectorized& x, at::vec::Vectorized&y) {return x + y;}, # tmp1_vec); # for(long i1=8; i1<8; i1+=1) { # auto tmp0 = in_ptr0[i1]; # tmp1 += tmp0; # } # It means that the vectorization introduce another reduction variable(tmp1_vec). # If the most inner loop of the reduction is not a parallelized but its parent reduction # loop is parallized, the new added reduction variable(tmp1_vec) could not be added # to the parallelized loop reduction. So we skip this case and does not vectorize it. if reduction_par_depth > 0 and reduction_par_depth != len( loops_nest_reduce.loops ): metrics.generated_cpp_vec_kernel_count -= 1 return self.simd_omp_kernel.codegen_loops(code, worksharing) with contextlib.ExitStack() as stack: if par_depth: worksharing.parallel(threads) loops_nest_non_reduce.mark_parallel(par_depth) elif reduction_par_depth: # need to close the worksharing scope to define reduction vars outside it worksharing.close() loops_nest_reduce.mark_parallel(reduction_par_depth) elif threads > 1: if worksharing.single(): stack.enter_context(code.indent()) non_reduce_loops = loops_nest_non_reduce.loops reduce_loops = loops_nest_reduce.loops loop_with_tail: LoopLevelWithTail = None if loops_nest_reduce: self.vectorize_most_inner_loop(loops_nest_reduce) loop_with_tail = loops_nest_reduce.loops[-1] # The most inner loop will be vectorized reduce_loops = reduce_loops[0:-1] else: self.vectorize_most_inner_loop(loops_nest_non_reduce) loop_with_tail = loops_nest_non_reduce.loops[-1] # The most inner loop will be vectorized non_reduce_loops = non_reduce_loops[0:-1] # The reductions loops are always the loop body of non-reduction loops for loop in non_reduce_loops: code.writelines(loop.lines()) stack.enter_context(code.indent()) with contextlib.ExitStack() as stack_outer: if self.simd_vec_kernel.reduction_prefix: stack_outer.enter_context(code.indent()) code.splice(self.simd_vec_kernel.reduction_prefix) if reduction_par_depth: worksharing.parallel(threads) with contextlib.ExitStack() as stack: for loop in reduce_loops: code.writelines(loop.lines()) stack.enter_context(code.indent()) def gen_vectorized_loop(loop, kernel, write_reduction_suffix=False): code.writelines(loop.lines()) with contextlib.ExitStack() as stack: stack.enter_context(code.indent()) code.splice(kernel.loads) code.splice(kernel.compute) code.splice(kernel.stores) if write_reduction_suffix: code.splice(kernel.reduction_suffix) # Regarding the vectorized reduction loop, we need to call reduce_all to to reduce # the vectorize as a single scalar. Hence, we set write_reduction_suffix to True to # gen the code. gen_vectorized_loop( loop_with_tail.main_loop, loop_with_tail.main_loop_body, True ) gen_vectorized_loop( loop_with_tail.tail_loop, loop_with_tail.tail_loop_body, False ) if reduction_par_depth: worksharing.close() code.splice(loop_with_tail.tail_loop_body.reduction_suffix) class CppScheduling: def __init__(self, scheduler): self.scheduler = scheduler self.get_kernel_group() def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) def get_kernel_group(self): from .wrapper import CppWrapperCodeGen if isinstance(V.graph.wrapper_code, CppWrapperCodeGen): self.kernel_group = CppWrapperKernelGroup() else: self.kernel_group = KernelGroup() @staticmethod def can_fuse_horizontal(node1, node2): _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group if vars1 == vars2 and reduce1 == reduce2: return True if reduce1 == () and vars1 == vars2 + reduce2: return True # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? return False @classmethod def can_fuse_vertical(cls, node1, node2): return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction() def can_vec(self, nodes): if not codecache.pick_vec_isa(): return False _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group with CppVecKernelChecker( deepcopy(self.kernel_group.args), parallel_num_threads() ) as kernel_checker: vars, reduction_vars = kernel_checker.set_ranges(group, reduction_group) for node in nodes: if node.group[1] in [ (group, reduction_group), (group + reduction_group, ()), ]: node.run(vars, reduction_vars) else: assert node.group[1] == ( group, (), ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" node.run(vars, ()) return kernel_checker.simd_vec def _codegen_nodes_impl(self, nodes, is_simd_vec=False): """ Turn an set of pre-fused nodes into a C++ kernel. """ kernel_group = self.kernel_group _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group def create_kernel(_is_simd_vec): in_suffix = False with kernel_group.new_kernel(_is_simd_vec) as kernel: vars, reduction_vars = kernel.set_ranges(group, reduction_group) for node in nodes: if node.group[1] in [ (group, reduction_group), (group + reduction_group, ()), ]: assert not in_suffix node.run(vars, reduction_vars) else: in_suffix = True assert node.group[1] == ( group, (), ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" # we can fuse in some extra pointwise into the suffix with kernel.write_to_suffix(): node.run(vars, ()) return kernel org_inplace_buffers_flag = config.inplace_buffers if is_simd_vec: # Create vectorization kernel cpp_vec_kernel = create_kernel(True) # Since a kernel is divided into two parts - vectorization and non-vectorization. # And the two parts share the same global contexts like V.graph.wrapper_code, # V.kernel.args. But the vectorization kernel generation has updated these global # contexts. Hence, the non-vectorization kernel should not do this again to avoid # conext conflict. By now, we only control the config.inplace_buffers. In the future, # we could maintain more contexts. config.inplace_buffers = False # Create non-vectorization kernel cpp_kernel = create_kernel(False) # Restore the inplace_buffers flag config.inplace_buffers = org_inplace_buffers_flag return (cpp_vec_kernel, cpp_kernel) else: return (None, create_kernel(False)) def codegen_nodes(self, nodes): """ Turn an set of pre-fused nodes into a C++ kernel. """ kernel_group = self.kernel_group can_be_simd_vec = self.can_vec(nodes) simd_vec_kernel, simd_omp_kernel = self._codegen_nodes_impl( nodes, can_be_simd_vec ) assert simd_omp_kernel metrics.generated_kernel_count -= 1 # Maitain the metrics kernel count if simd_vec_kernel: metrics.generated_kernel_count -= 1 cpp_kernel_proxy = CppKernelProxy( kernel_group.args, kernel_group.ws.num_threads ) cpp_kernel_proxy.simd_vec_kernel = simd_vec_kernel cpp_kernel_proxy.simd_omp_kernel = simd_omp_kernel kernel_group.finalize_kernel(cpp_kernel_proxy, None) def codegen_sync(self): pass def flush(self): self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) self.get_kernel_group() class KernelGroup: def __init__(self): super().__init__() self.args = KernelArgs() self.loops_code = BracesBuffer() self.ws = WorkSharing(self.loops_code) self.stack = contextlib.ExitStack() self.stack.enter_context(self.ws) self.count = 0 def new_kernel(self, simd_vec=False): if simd_vec: return CppVecKernel(self.args, parallel_num_threads()) else: return CppKernel(self.args, parallel_num_threads()) def finalize_kernel(self, new_kernel, scheduler): self.count += 1 code = self.loops_code ws = self.ws new_kernel.codegen_loops(code, ws) def codegen_define_and_call(self, wrapper): self.stack.close() if self.count == 0: return kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix() arg_defs, call_args, arg_types = self.args.cpp_argdefs() arg_defs = ",\n".ljust(25).join(arg_defs) arg_types = ",".join(arg_types) code = BracesBuffer() # TODO: support kernel profile on other platforms enable_kernel_profile = ( config.cpp.enable_kernel_profile and sys.platform == "linux" ) if enable_kernel_profile: code.writelines(["#include "]) code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})']) with code.indent(): if enable_kernel_profile: graph_id = V.graph.graph_id prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" code.writelines( [ f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' ] ) for old, new in self.args.aliases(): code.writeline(f"auto {old} = {new};") code.splice(self.loops_code) codecache_def = IndentedBuffer() codecache_def.writeline("async_compile.cpp('''") codecache_def.splice(code) codecache_def.writeline("''')") codecache_str = codecache_def.getvalue() # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. codecache_str = codecache_str.replace("#pragma CMT", "//") wrapper.define_kernel(kernel_name, codecache_str) wrapper.load_kernel(kernel_name, code, arg_types) # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args) class CppWrapperKernelGroup(KernelGroup): def __init__(self): super().__init__() self.args = CppWrapperKernelArgs() class WorkSharing: def __init__(self, code): self.code = code self.in_parallel = False self.num_threads = None self.stack = contextlib.ExitStack() def parallel(self, threads): if self.in_parallel and threads != self.num_threads: # wrong number of threads self.close() if not self.in_parallel: self.num_threads = threads self.in_parallel = True if config.cpp.dynamic_threads: self.code.writeline("#pragma omp parallel") else: self.code.writeline(f"#pragma omp parallel num_threads({threads})") self.stack.enter_context(self.code.indent()) def single(self): if self.in_parallel: self.code.writeline("#pragma omp single") return self.in_parallel def close(self): self.stack.close() self.in_parallel = False def __enter__(self): self.stack.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): self.stack.__exit__(exc_type, exc_val, exc_tb) @dataclasses.dataclass class LoopLevel: var: sympy.Expr = None size: sympy.Expr = None offset: sympy.Expr = sympy.Integer(0) steps: sympy.Expr = sympy.Integer(1) parallel: int = 0 simd_omp: bool = False picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 simd_vec: bool = False collapsed: bool = False reduction_vars: Dict[str, str] = None def lines(self): if self.reduction_vars: suffix = "_vec" if self.simd_vec else "" reduction = " " + " ".join( f"reduction({RTYPE_TO_CPP[rtype]}:{var}{suffix})" for var, rtype in self.reduction_vars.items() ) else: reduction = "" simd = ( f"simd simdlen({self.simd_nelements}) " if self.simd_omp and self.simd_nelements > 1 else "" ) if self.parallel: # TODO(jansel): look into chunk size and other schedules line1 = f"#pragma omp for{reduction} " if self.parallel > 1: line1 += f" collapse({self.parallel})" if self.simd_omp: line1 = line1.replace(" for ", f" for {simd}") elif self.simd_vec: line1 = "" elif self.simd_omp: line1 = f"#pragma omp {simd}{reduction}" elif not self.reduction_vars and codecache.is_gcc(): line1 = "#pragma GCC ivdep" else: line1 = "" line2 = f"for({INDEX_TYPE} {self.var}={cexpr(self.offset)}; {self.var}<{cexpr(self.size)}; {self.var}+={cexpr(self.steps)})" if self.collapsed or not line1: return [line2] return [line1, line2] class LoopLevelWithTail(LoopLevel): def __init__(self, main_loop: LoopLevel, tail_loop: LoopLevel): super().__init__() self.main_loop = main_loop self.tail_loop = tail_loop self.main_loop_body = None self.tail_loop_body = None def lines(self): raise AssertionError("Not Implemented") @dataclasses.dataclass class LoopNest: loops: List[LoopLevel] def __bool__(self): return bool(self.loops) def mark_reduction(self, reduction_vars): for loop in self.loops: loop.reduction_vars = reduction_vars def mark_parallel(self, par_depth): loops = self.loops loops[0].parallel = par_depth for i in range(1, par_depth): loops[i].collapsed = True def split_most_inner_loop(self, factor): sympy_factor = sympy.Integer(factor) most_inner_loop = self.loops[-1] # If the most inner loop needs to be collapsed, we need to # exclude it since we need to split it into two loops. Meanwhile, # we still mark it as parallized. if most_inner_loop.collapsed: assert self.loops[0].parallel == len(self.loops) self.loops[0].parallel -= 1 main_loop_range = ir.IndexingDiv(most_inner_loop.size, sympy_factor) main_loop = LoopLevel(most_inner_loop.var, main_loop_range) main_loop.parallel = most_inner_loop.parallel main_loop.collapsed = False main_loop.reduction_vars = most_inner_loop.reduction_vars offset = main_loop_range * sympy_factor tail_loop = LoopLevel(most_inner_loop.var, most_inner_loop.size) tail_loop.offset = offset tail_loop.parallel = most_inner_loop.parallel tail_loop.collapsed = False tail_loop.reduction_vars = most_inner_loop.reduction_vars loop_with_tail = LoopLevelWithTail(main_loop, tail_loop) loop_with_tail.parallel = 0 loop_with_tail.collapsed = False self.loops[-1] = loop_with_tail def codegen(self, code, stack): for loop in self.loops: code.writelines(loop.lines()) stack.enter_context(code.indent()) else: stack.enter_context(code.indent())