pytorch/torch/_inductor/codegen/cpp.py
Wang, Eikan fce6d6b3dc Redefine the simdlen semantic: (#88482)
This PR aims to automatically enable vectorization optimization for TorchInductor. It refines the semantics of `config.cpp.simdlen`.

Originally, `None` meant vectorization was disabled, while a specific value meant the number of elements to be vectorized at a time. But that count depends on the data type. For a 256-bit SVE/SIMD ISA on ARM or X86, `simdlen` should be 8 for Float but 16 for BFloat16. Hence, this PR redefines `simdlen` as the bit width. The detailed semantics are as follows.

- **_simdlen = None_**: Automatically determine the SIMD bit width: detect the HW capabilities and pick the proper vectorization ISA. Specifically for X86, AVX512 takes priority over AVX2.
- **_simdlen <= 1_**: Explicitly disable SIMD.
- **_simdlen > 1_**: Explicitly specify the SIMD bit width. Vectorization is disabled if the bit width does not match a supported ISA width (see the sketch below).
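
A minimal sketch of the three modes via `torch._inductor.config` (the exact ISA picked under `None` depends on the host CPU):

```python
import torch._inductor.config as config

# Auto-detect the SIMD bit width from the hardware
# (on X86, AVX512 is preferred over AVX2).
config.cpp.simdlen = None

# Explicitly disable SIMD.
config.cpp.simdlen = 1

# Request a 256-bit ISA (e.g. AVX2); if no supported ISA
# matches this width, vectorization is disabled.
config.cpp.simdlen = 256
```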

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88482
Approved by: https://github.com/jgong5, https://github.com/jansel
2022-11-17 03:27:54 +00:00


import contextlib
import dataclasses
import functools
from copy import deepcopy
from pathlib import Path
from typing import Dict, List
import sympy
import torch
from torch._prims_common import is_float_dtype
from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
from ..utils import sympy_product, sympy_subs, sympy_symbol
from ..virtualized import ops, V
from .common import (
BracesBuffer,
DeferredIndentedBuffer,
ExprPrinter,
IndentedBuffer,
Kernel,
KernelArgs,
OpOverrides,
)
DTYPE_TO_CPP = {
torch.float32: "float",
torch.float64: "double",
torch.float16: "half",
torch.int64: "long",
torch.int32: "int",
torch.int16: "short",
torch.int8: "signed char",
torch.uint8: "unsigned char",
torch.bool: "bool",
torch.bfloat16: "bfloat16",
}
INDEX_TYPE = "long"
RTYPE_TO_CPP = {
"sum": "+",
"min": "min",
"max": "max",
"argmin": "argmin",
"argmax": "argmax",
"any": "||",
}
def reduction_init(reduction_type, dtype):
if reduction_type in ("sum", "any"):
return 0
if reduction_type in {"max", "argmax"}:
return (
f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
if is_float_dtype(dtype)
else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::min()"
)
if reduction_type in {"min", "argmin"}:
return (
f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
if is_float_dtype(dtype)
else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()"
)
raise AssertionError(reduction_type)
def reduction_combine(reduction_type, var, next_value):
if reduction_type == "sum":
return f"{var} += {next_value}"
if reduction_type == "any":
return f"{var} = {var} || {next_value}"
return f"{var} = std::{reduction_type}({var}, {next_value})"
def reduction_combine_vec(reduction_type, var, next_value):
if reduction_type == "max":
return f"{var} = at::vec::maximum({var}, {next_value})"
elif reduction_type == "min":
return f"{var} = at::vec::minimum({var}, {next_value})"
elif reduction_type == "sum":
return f"{var} += {next_value}"
else:
        raise NotImplementedError(reduction_type)
index_value_name_counter = 1
def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
global index_value_name_counter
struct_name = f"IndexValue_{index_value_name_counter}"
index_value_name_counter += 1
# A small annoyance, due to it being a little cumbersome to just throw {} into strings
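    # For example, for a float argmax this emits (sketch):
    #   struct IndexValue_1 {size_t index; float value;};
    #   IndexValue_1 tmp{0, -std::numeric_limits<float>::infinity()};
    #   #pragma omp declare reduction(argmax : struct IndexValue_1 : ...) \
    #       initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})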
prefix = [
f"struct {struct_name} {{size_t index; {DTYPE_TO_CPP[src_dtype]} value;}};",
f"{struct_name} {tmpvar}{{0, {reduction_init(reduction_type, src_dtype)}}};",
]
if reduction_type == "argmax":
prefix.extend(
[
f"#pragma omp declare reduction(argmax : struct {struct_name} :\\",
" omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\\",
" omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\\",
f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})",
]
)
elif reduction_type == "argmin":
prefix.extend(
[
f"#pragma omp declare reduction(argmin : struct {struct_name} :\\",
" omp_out.value = omp_in.value > omp_out.value ? omp_out.value : omp_in.value,\\",
" omp_out.index = omp_in.value > omp_out.value ? omp_out.index : omp_in.index)\\",
f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})",
]
)
return prefix
def float16_reduction_prefix(rtype):
# TODO: This user-defined reduction uses float16 accumulation for sum. To reduce numerical
# errors, float32 accumulation should be used instead.
assert rtype in (
"sum",
"any",
), f"float16 user-defined reduction only supports 'sum' and 'any' but got {rtype}"
prefix = [
f"#pragma omp declare reduction({RTYPE_TO_CPP[rtype]}:{DTYPE_TO_CPP[torch.float16]}:"
+ f"omp_out = omp_out {RTYPE_TO_CPP[rtype]} omp_in)"
]
return prefix
def parallel_num_threads():
threads = config.cpp.threads
if threads < 1:
threads = torch.get_num_threads()
return threads
@functools.lru_cache()
def cpp_prefix():
path = Path(__file__).parent / "cpp_prefix.h"
with path.open() as f:
_, filename = codecache.write(
f.read(),
"h",
)
return f'#include "{filename}"'
class CppPrinter(ExprPrinter):
def _print_ModularIndexing(self, expr):
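        # ModularIndexing(x, div, mod) lowers to ((x / div) % mod);
        # the division is elided when div == 1.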
x, div, mod = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
mod = self.paren(self.doprint(mod))
if div != "1":
x = f"({x} / {div})"
return f"{x} % {mod}"
def _print_IndexingDiv(self, expr):
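        # IndexingDiv(x, div) lowers to plain C++ division, which truncates
        # for the integer index types used here.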
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} / {div})"
cexpr = CppPrinter().doprint
class CppVecOverrides(OpOverrides):
"""Map element-wise ops to aten vectorization C++"""
@staticmethod
def add(a, b):
return f"{a} + {b}"
@staticmethod
def sub(a, b):
return f"{a} - {b}"
@staticmethod
def mul(a, b):
return f"{a} * {b}"
@staticmethod
def div(a, b):
return f"{a} / {b}"
@staticmethod
def abs(x):
return f"{x}.abs()"
@staticmethod
def sin(x):
return f"{x}.sin()"
@staticmethod
def cos(x):
return f"{x}.cos()"
@staticmethod
def exp(x):
return f"{x}.exp()"
@staticmethod
def sqrt(x):
return f"{x}.sqrt()"
@staticmethod
def rsqrt(x):
return f"{x}.rsqrt()"
@staticmethod
def pow(a, b):
return f"{a}.pow({b})"
@staticmethod
def log(x):
return f"{x}.log()"
@staticmethod
def round(x):
return f"{x}.round()"
@staticmethod
def floor(x):
return f"{x}.floor()"
@staticmethod
def ceil(x):
return f"{x}.ceil()"
@staticmethod
def trunc(x):
return f"{x}.trunc()"
@staticmethod
def fmod(a, b):
return f"{a}.fmod({b})"
@staticmethod
def lgamma(x):
return f"{x}.lgamma()"
@staticmethod
def logical_and(a, b):
return f"{a} && {b}"
@staticmethod
def logical_or(a, b):
return f"{a} || {b}"
@staticmethod
def tanh(a):
return f"{a}.tanh()"
@staticmethod
def reciprocal(a):
return f"{a}.reciprocal()"
@staticmethod
def constant(val, dtype):
if val == float("inf"):
quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
elif val == float("-inf"):
quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
elif val is True or val is False:
quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})"
else:
quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})"
return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})"
@staticmethod
def relu(x):
return f"at::vec::clamp_min({x}, decltype({x})(0))"
@staticmethod
def sigmoid(x):
return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())"
@staticmethod
def neg(x):
return f"{x}.neg()"
@staticmethod
def floordiv(a, b):
# a and b are integer type
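        # C++ '/' truncates toward zero; when the operands' signs differ and
        # the remainder is nonzero, subtract 1 to get floor semantics.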
_t = f"decltype({a})"
quot = f"{a} / {b}"
rem = f"{a} % {b}"
return f"(({a} < {_t}(0)) != ({b} < {_t}(0)) ? ({rem} != {_t}(0) ? {quot} - {_t}(1) : {quot}) : {quot})"
@staticmethod
def truncdiv(a, b):
# a and b are integer type
return f"{a} / {b}"
@staticmethod
def minimum(a, b):
return f"at::vec::minimum({a}, {b})"
@staticmethod
def maximum(a, b):
return f"at::vec::maximum({a}, {b})"
@staticmethod
def square(a):
return f"{a}.pow(2)"
@staticmethod
def sign(x):
code = BracesBuffer()
# auto tmp5 = tmp4 < 0 ? -1 : 1;
vec_zero = f"decltype({x})(0)"
vec_one = f"decltype({x})(1)"
blendv = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})"
left = V.kernel.cse.newvar()
code.writeline(f"auto {left} = {blendv};")
# auto tmp6 = tmp4 == 0 ? 0 : tmp5;
blendv = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})"
right = V.kernel.cse.newvar()
code.writeline(f"auto {right} = {blendv};")
result = V.kernel.cse.newvar()
code.writeline(f"auto {result} = {left} - {right};")
V.kernel.compute.splice(code)
return result
class CppOverrides(OpOverrides):
"""Map element-wise ops to C++"""
@staticmethod
def to_dtype(x, dtype):
assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
return f"static_cast<{DTYPE_TO_CPP[dtype]}>({x})"
@staticmethod
def abs(x):
return f"std::abs({x})"
@staticmethod
def sin(x):
return f"std::sin({x})"
@staticmethod
def cos(x):
return f"std::cos({x})"
@staticmethod
def exp(x):
# return f"Sleef_expf_u10({x})"
return f"std::exp({x})"
@staticmethod
def sqrt(x):
return f"std::sqrt({x})"
@staticmethod
def rsqrt(x):
return f"1 / std::sqrt({x})"
@staticmethod
def signbit(x):
return f"std::signbit({x})"
@staticmethod
def pow(a, b):
return f"std::pow({a}, {b})"
@staticmethod
def log(x):
return f"std::log({x})"
@staticmethod
def round(x):
return f"std::nearbyint({x})"
@staticmethod
def floor(x):
return f"std::floor({x})"
@staticmethod
def floordiv(a, b):
# a and b are integer type
quot = f"{a} / {b}"
rem = f"{a} % {b}"
return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})"
@staticmethod
def ceil(x):
return f"std::ceil({x})"
@staticmethod
def trunc(x):
return f"std::trunc({x})"
@staticmethod
def truncdiv(a, b):
# a and b are integer type
return f"{a} / {b}"
@staticmethod
def fmod(a, b):
return f"std::fmod({a}, {b})"
@staticmethod
def isinf(x):
return f"std::isinf({x})"
@staticmethod
def isnan(x):
return f"std::isnan({x})"
@staticmethod
def lgamma(x):
return f"std::lgamma({x})"
@staticmethod
def relu(x):
return f"{x} * ({x}>0)"
@staticmethod
def minimum(a, b):
return f"std::min({a}, {b})"
@staticmethod
def maximum(a, b):
return f"std::max({a}, {b})"
@staticmethod
def where(a, b, c):
return f"{a} ? {b} : {c}"
@staticmethod
def mod(a, b):
return f"mod({a}, {b})"
@staticmethod
def constant(val, dtype):
if val == float("inf"):
return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
elif val == float("-inf"):
return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
elif val is True or val is False:
return ops.to_dtype(str(val).lower(), dtype)
return ops.to_dtype(repr(val), dtype)
@staticmethod
def index_expr(expr, dtype):
return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
@staticmethod
def masked(mask, body, other):
code = BracesBuffer()
var = V.kernel.cse.newvar()
if other == float("-inf"):
code.writeline(f"float {var} = -std::numeric_limits<float>::infinity();")
elif other == float("inf"):
code.writeline(f"float {var} = std::numeric_limits<float>::infinity();")
elif isinstance(other, float):
code.writeline(f"float {var} = {other};")
else:
code.writeline(f"auto {var} = {other!r};")
code.writeline(f"if({mask})")
with V.kernel.swap_buffers(code), code.indent():
result = body()
code.writeline(f"{var} = {result};")
V.kernel.compute.splice(code)
return var
@staticmethod
def logical_and(a, b):
return f"{a} && {b}"
@staticmethod
def logical_or(a, b):
return f"{a} || {b}"
@staticmethod
def rand(seed: sympy.Expr, offset: sympy.Expr, dtype):
return f"static_cast<{DTYPE_TO_CPP[dtype]}>(normalized_rand_cpu({seed}, {offset}));"
@staticmethod
def randn(seed: sympy.Expr, offset: sympy.Expr, dtype):
return f"static_cast<{DTYPE_TO_CPP[dtype]}>(randn_cpu({seed}, {offset}));"
@staticmethod
def sigmoid(x):
x = ops.exp(f"-{x}")
return f"1 / (1 + {x})"
@staticmethod
def sign(x):
code = BracesBuffer()
# auto tmp5 = tmp4 < 0 ? -1 : 1;
left = V.kernel.cse.newvar()
right = V.kernel.cse.newvar()
result = V.kernel.cse.newvar()
code.writeline(f"auto {left} = {x} > 0 ? 1 : 0;")
code.writeline(f"auto {right} = {x} < 0 ? 1 : 0;")
code.writeline(f"auto {result} = {left} - {right};")
V.kernel.compute.splice(code)
return result
class CppKernel(Kernel):
overrides = CppOverrides
sexpr = cexpr
newvar_prefix = "auto "
suffix = ";"
def __init__(self, args, num_threads):
        super().__init__(args)
self.call_ranges = None
self.ranges = None
self.itervars = None
self.reduction_depth = None
self.reduction_prefix = IndentedBuffer()
self.reduction_suffix = DeferredIndentedBuffer()
self.reduction_vars = {}
self.num_threads = num_threads # num_threads the kernel specialized for
def load(self, name: str, index: sympy.Expr):
var = self.args.input(name)
index = self.rename_indexing(index)
line = f"{var}[{cexpr(index)}]"
if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16):
line = f"static_cast<float>({line})"
return self.cse.generate(self.loads, line)
def store(self, name, index, value, mode=None):
assert "buf" in name
var = self.args.output(name)
index = self.rename_indexing(index)
if mode is None:
line = f"{var}[{cexpr(index)}] = {value};"
elif mode == "atomic_add":
if not config.cpp.dynamic_threads and self.num_threads == 1:
line = f"{var}[{cexpr(index)}] += {value};"
else:
line = f"atomic_add(&{var}[{cexpr(index)}], {value});"
else:
raise NotImplementedError(f"store mode={mode}")
self.stores.writeline(name, line)
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
argmax_or_argmin = reduction_type in {"argmax", "argmin"}
tmpvar = self.cse.generate(
self.loads, f"reduction {name} {cexpr(index)}", write=False
)
index = self.rename_indexing(index)
self.reduction_vars[tmpvar] = reduction_type
if argmax_or_argmin:
self.reduction_prefix.writelines(
argmax_argmin_prefix(reduction_type, src_dtype, tmpvar)
)
compare_op = "<" if reduction_type == "argmax" else ">"
self.stores.writelines(
None,
[
f"if ({tmpvar}.value {compare_op} {value}) {{",
f" {tmpvar}.index = {self.itervars[-1]}; {tmpvar}.value = {value};",
"}",
],
)
else:
if dtype == torch.float16:
self.reduction_prefix.writelines(
float16_reduction_prefix(reduction_type)
)
self.reduction_prefix.writeline(
f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};"
)
self.stores.writeline(
None, f"{reduction_combine(reduction_type, tmpvar, value)};"
)
if name not in V.graph.removed_buffers:
var = self.args.output(name)
member_name = ".index" if argmax_or_argmin else ""
self.reduction_suffix.writeline(
name, f"{var}[{cexpr(index)}] = {tmpvar}{member_name};"
)
self.cse.store_cache[name] = tmpvar
def set_ranges(self, lengths, reduction_lengths):
if self.call_ranges:
assert self.call_ranges == tuple(lengths) + tuple(
reduction_lengths
), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
assert self.reduction_depth == len(lengths)
else:
self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
self.itervars = [sympy_symbol(f"i{n}") for n in range(len(self.ranges))]
self.reduction_depth = len(lengths)
return (
self.itervars[: self.reduction_depth],
self.itervars[self.reduction_depth :],
)
def size_hint(self):
return V.graph.sizevars.size_hint(sympy_product(self.call_ranges))
def codegen_loops(self, code, worksharing):
threads = parallel_num_threads()
loops = [LoopLevel(var, size) for var, size in zip(self.itervars, self.ranges)]
loops, reductions = LoopNest(loops[: self.reduction_depth]), LoopNest(
loops[self.reduction_depth :]
)
reductions.mark_reduction(self.reduction_vars)
if codecache.pick_vec_isa():
# TODO(jansel): detect stride-1 dimension and vectorize that
if reductions:
reductions.loops[-1].simd = True
elif loops:
loops.loops[-1].simd = True
par_depth = 0
reduction_par_depth = 0
if loops:
par_depth = self.decide_parallel_depth(
self.call_ranges[: self.reduction_depth], threads
)
else:
reduction_par_depth = self.decide_parallel_depth(
self.call_ranges[self.reduction_depth :], threads
)
with contextlib.ExitStack() as stack:
if par_depth:
worksharing.parallel(threads)
loops.mark_parallel(par_depth)
elif reduction_par_depth:
# need to close the worksharing scope to define reduction vars outside it
worksharing.close()
reductions.mark_parallel(reduction_par_depth)
elif threads > 1:
if worksharing.single():
stack.enter_context(code.indent())
loops.codegen(code, stack)
with contextlib.ExitStack() as stack_outer:
if self.reduction_prefix:
stack_outer.enter_context(code.indent())
code.splice(self.reduction_prefix)
if reduction_par_depth:
worksharing.parallel(threads)
with contextlib.ExitStack() as stack:
reductions.codegen(code, stack)
code.splice(self.loads)
code.splice(self.compute)
code.splice(self.stores)
if reduction_par_depth:
worksharing.close()
code.splice(self.reduction_suffix)
def decide_parallel_depth(self, ranges, threads):
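        # Greedily fold outer loop dims into the parallel region until we have
        # roughly `threads`-way parallelism, stopping early once the remaining
        # sequential work per thread would fall below min_chunk_size.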
seq = self.size_hint()
par = 1
depth = 0
for expr in ranges:
hint = V.graph.sizevars.size_hint(expr)
if par >= 2 * threads or par == threads:
break
if seq // threads < config.cpp.min_chunk_size:
# not enough work
break
depth += 1
par *= hint
seq /= hint
        # if we assume the thread number is dynamic, make sure we
        # have at least one parallel scope and let the OMP runtime
        # manage serial vs. parallel execution.
if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0:
depth = 1
return depth
@contextlib.contextmanager
def write_to_suffix(self):
prior = (self.loads, self.compute, self.stores, self.cse)
self.loads = IndentedBuffer()
self.compute = IndentedBuffer()
self.stores = DeferredIndentedBuffer()
self.cse = self.cse.clone()
yield
self.reduction_suffix.splice(self.loads)
self.reduction_suffix.splice(self.compute)
self.reduction_suffix.splice(self.stores)
(self.loads, self.compute, self.stores, self.cse) = prior
class CppVecKernel(CppKernel):
overrides = CppVecOverrides
def __init__(self, args, num_threads):
        super().__init__(args, num_threads)
assert codecache.pick_vec_isa()
self.simd_nelements = codecache.pick_vec_isa().nelements()
self.reduction_omp_dec: Dict[str, str] = {}
metrics.generated_cpp_vec_kernel_count += 1
def is_single_step_var(self, var: sympy.Symbol, index: sympy.Expr):
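        # True if the index advances by exactly 1 when var increases by 1,
        # i.e. var indexes a stride-1 (contiguous) dimension.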
replacement = {var: var + 1}
new_index = sympy_subs(index, replacement)
delta = sympy.simplify(new_index - index)
return delta == 1
    def is_var_irrelevant(self, var: sympy.Symbol, index: sympy.Expr):
expanded_index = sympy.expand(index)
return not expanded_index.has(var)
def transform_index(self, index: sympy.Expr):
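        # Rescale the most inner itervar by the vector width so that each
        # iteration of the generated loop covers one full SIMD vector.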
expanded_index = sympy.expand(index)
assert self.simd_nelements
assert self.simd_nelements >= 1
most_inner_var = self.itervars[-1]
replacement = {most_inner_var: most_inner_var * self.simd_nelements}
new_index = sympy_subs(expanded_index, replacement)
return new_index
def load(self, name: str, index: sympy.Expr):
var = self.args.input(name)
index = self.rename_indexing(index)
expanded_index = sympy.expand(index)
new_index = self.transform_index(index)
if expanded_index == new_index:
line = f"at::vec::Vectorized<float>({var}[{cexpr(index)}])"
else:
line = f"at::vec::Vectorized<float>::loadu({var} + {cexpr(new_index)})"
return self.cse.generate(self.loads, line)
def store(self, name, index, value, mode=None):
assert "buf" in name
var = self.args.output(name)
index = self.rename_indexing(index)
assert mode is None
expanded_index = sympy.expand(index)
new_index = self.transform_index(index)
assert new_index != expanded_index
line = f"{value}.store({var} + {cexpr(new_index)});"
self.stores.writeline(name, line)
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
assert reduction_type in {"max", "min", "sum"}
assert dtype == torch.float
assert src_dtype == torch.float
reduce_map = {"max": "maximum", "min": "minimum"}
vec_ns = "at::vec"
vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
if reduction_type not in self.reduction_omp_dec:
vec_reduc_prefix = "#pragma omp declare reduction("
vec_reduc_prefix += f"{RTYPE_TO_CPP[reduction_type]}:{vec}:"
if reduction_type == "sum":
vec_reduc_prefix += "omp_out += omp_in"
else:
vec_reduc_prefix += (
f"omp_out = {vec_ns}::{reduce_map[reduction_type]}(omp_out, omp_in)"
)
vec_reduc_prefix += ")"
vec_reduc_prefix += " initializer("
vec_reduc_prefix += "omp_priv={{"
vec_reduc_prefix += f"{reduction_init(reduction_type, dtype)}"
vec_reduc_prefix += "}})"
self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type]
self.reduction_prefix.writeline(vec_reduc_prefix)
tmpvar = self.cse.generate(
self.loads, f"reduction {name} {cexpr(index)}", write=False
)
tmpvar_vec = f"{tmpvar}_vec"
index = self.rename_indexing(index)
self.reduction_vars[tmpvar] = reduction_type
self.reduction_prefix.writeline(
f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};"
)
self.reduction_prefix.writeline(
f"auto {tmpvar_vec} = at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({tmpvar});"
)
self.stores.writeline(
None, f"{reduction_combine_vec(reduction_type, tmpvar_vec, value)};"
)
reduce_all_body = "{"
if reduction_type == "sum":
reduce_all_body += "return x + y;"
else:
reduce_all_body += f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);"
reduce_all_body += "}"
vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
self.reduction_suffix.writeline(
name,
f"{tmpvar} = {vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec});",
)
self.cse.store_cache[name] = tmpvar
class CppVecKernelChecker(CppVecKernel):
def __init__(self, args, num_threads):
        super().__init__(args, num_threads)
        # Since this kernel only performs checking and does not generate any
        # code, we need to decrease the kernel count.
metrics.generated_kernel_count -= 1
metrics.generated_cpp_vec_kernel_count -= 1
        # Used to record the graph wrapper code, as the wrapper_code status
        # could be changed during the graph run.
self._orig_wrapper_code = None
self.simd_vec = True
self.fast_vec_list = []
for k, v in CppVecOverrides.__dict__.items():
if isinstance(v, staticmethod):
self.fast_vec_list.append(k)
self.exit_stack = contextlib.ExitStack()
def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
        return self.is_var_irrelevant(var, index) or self.is_single_step_var(var, index)
def could_vec(self, name: str, index: sympy.Expr):
if V.graph.get_dtype(name) is not torch.float:
return False
assert self.itervars is not None
# Not a loop
if len(self.itervars) == 0:
return False
most_inner_var = self.itervars[-1]
return self.is_legal_data_access(most_inner_var, index)
def load(self, name: str, index: sympy.Expr):
index = self.rename_indexing(index)
self.simd_vec = self.simd_vec and self.could_vec(name, index)
return self.simd_vec
def store(self, name, index, value, mode=None):
assert "buf" in name
index = self.rename_indexing(index)
if mode:
self.simd_vec = False
return False
self.simd_vec = self.simd_vec and self.could_vec(name, index)
return self.simd_vec
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
if (
dtype == torch.float
and src_dtype == torch.float
and reduction_type in ["max", "min", "sum"]
):
pass
else:
self.simd_vec = False
return self.simd_vec
def __exit__(self, exc_type, exc_val, exc_tb):
assert self._orig_wrapper_code is not None
# Restore the wrapper_code
V.graph.wrapper_code = self._orig_wrapper_code
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
def __enter__(self):
        # Record the graph wrapper code. The wrapper_code status could be
        # changed during the graph run. This checker also needs to run the
        # graph, but we don't expect it to change any status that would
        # impact code generation. Hence, we record the graph wrapper code,
        # replace it with a dummy wrapper_code, and restore the original
        # one once the checker has finished.
self._orig_wrapper_code = V.graph.wrapper_code
V.graph.wrapper_code = WrapperCodeGen()
class VecCheckerProxy:
@staticmethod
def __getattr__(name):
def inner(*args, **kwargs):
                    if name not in self.fast_vec_list:
self.simd_vec = False
return self.simd_vec
return inner
@staticmethod
def load(name: str, index: sympy.Expr):
return self.load(name, index)
@staticmethod
def store(name, index, value, mode=None):
return self.store(name, index, value, mode=mode)
@staticmethod
def reduction(name, dtype, src_dtype, reduction_type, index, value):
return self.reduction(
name, dtype, src_dtype, reduction_type, index, value
)
@staticmethod
def constant(val, dtype):
supported_dtype = (torch.float32, torch.int32)
                is_supported_dtype = dtype in supported_dtype
if not is_supported_dtype:
self.simd_vec = False
return is_supported_dtype
@staticmethod
def index_expr(expr, dtype):
self.simd_vec = False
return self.cse.newvar()
@staticmethod
def indirect_indexing(index_var):
return sympy.Symbol(str(index_var))
@staticmethod
def masked(mask, body, other):
return V.kernel.cse.newvar()
self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy()))
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self
class CppKernelProxy(CppKernel):
def __init__(self, args=None, num_threads=None):
        super().__init__(args, num_threads)
self.simd_vec_kernel: CppVecKernel = None
self.simd_omp_kernel: CppKernel = None
self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float):
assert self.picked_vec_isa
nelements = self.picked_vec_isa.nelements(dtype)
loop_nest.split_most_inner_loop(nelements)
loop_with_tail = loop_nest.loops[-1]
assert isinstance(loop_with_tail, LoopLevelWithTail)
loop_with_tail.main_loop.simd_vec = True
loop_with_tail.tail_loop.simd_omp = True
        # We chop the loop into two parts by nelements - a main loop and a
        # tail loop. The main loop is straightforward: it is vectorized by
        # nelements. But the tail loop can still be vectorized at half width.
        # For example, if nelements is 8 (256 bits), the tail loop can still
        # be vectorized as 4 (128 bits).
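        # A sketch of the resulting structure for nelements = 8 and size N
        # (variable names illustrative):
        #   for(long i=0; i<N/8; i+=1)     { /* 8-wide vectorized body */ }
        #   for(long i=8*(N/8); i<N; i+=1) { /* scalar/omp-simd tail body */ }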
        loop_with_tail.tail_loop.simd_nelements = nelements // 2
loop_with_tail.tail_loop.simd_vec = False
loop_with_tail.main_loop_body = self.simd_vec_kernel
loop_with_tail.tail_loop_body = self.simd_omp_kernel
return loop_nest
def codegen_loops(self, code, worksharing):
threads = parallel_num_threads()
if self.simd_vec_kernel is None or not self.picked_vec_isa:
assert self.simd_omp_kernel
return self.simd_omp_kernel.codegen_loops(code, worksharing)
assert self.simd_vec_kernel.itervars == self.simd_omp_kernel.itervars
assert self.simd_vec_kernel.ranges == self.simd_omp_kernel.ranges
assert (
self.simd_vec_kernel.reduction_vars == self.simd_omp_kernel.reduction_vars
)
itervars = self.simd_vec_kernel.itervars
        ranges = self.simd_vec_kernel.ranges
        loops = [LoopLevel(var, size) for var, size in zip(itervars, ranges)]
assert (
self.simd_vec_kernel.reduction_depth == self.simd_omp_kernel.reduction_depth
)
reduction_depth = self.simd_vec_kernel.reduction_depth
loops_nest_non_reduce, loops_nest_reduce = LoopNest(
loops[:reduction_depth]
), LoopNest(loops[reduction_depth:])
loops_nest_reduce.mark_reduction(self.simd_vec_kernel.reduction_vars)
assert self.picked_vec_isa
        # Do not apply vectorization when the range of the most inner loop is
        # too small. If that range is less than
        # codecache.pick_vec_isa().nelements(), the generated code for some
        # reductions would look as follows, which leads to an incorrect result.
#
# LINE01: float tmp1 = 0;
# LINE02: auto tmp1_vec = at::vec::Vectorized<float>(tmp1);
# LINE03: for(long i1=0; i1<2; i1+=1)
# LINE04: {
# LINE05: for(long i2=0; i2<0; i2+=1)
# LINE06: {
# LINE07: auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + (8*i0) + (16*i2) + (32*i1));
# LINE08: tmp1_vec += tmp0;
# LINE09: }
# LINE10: tmp1 = vec_reduce_all<float>([](Vectorized<float>& x, Vectorized<float>&y) {return x + y;}, tmp1_vec);
# LINE11: #pragma omp simd simdlen(8) reduction(+:tmp1)
# LINE12: for(long i2=0; i2<8; i2+=1)
# LINE13: {
# LINE14: auto tmp0 = in_ptr0[i2 + (8*i0) + (32*i1)];
# LINE15: tmp1 += tmp0;
# LINE16: }
# LINE17: }
# LINE18: out_ptr3[i0] = tmp1;
#
        # tmp1_vec (LINE02) will always be zero, as it is initialized with the tmp1 value
        # and the range (LINE05) is 0. Hence, LINE10 will always reset tmp1 to 0. But tmp1
        # (LINE01) is a global value, so the result would be incorrect. We skip this case.
most_inner_loop = (
loops_nest_reduce.loops[-1]
if loops_nest_reduce
else loops_nest_non_reduce.loops[-1]
)
main_loop_range = ir.IndexingDiv(
most_inner_loop.size, self.picked_vec_isa.nelements()
)
loop_interval = sympy.simplify(main_loop_range)
        # TODO(Eikan): Support dynamic shapes.
if not loop_interval.is_integer or loop_interval <= 0:
metrics.generated_cpp_vec_kernel_count -= 1
return self.simd_omp_kernel.codegen_loops(code, worksharing)
# TODO(jansel): detect stride-1 dimension and vectorize that
if loops_nest_reduce:
loops_nest_reduce.loops[-1].simd = True
elif loops_nest_non_reduce:
loops_nest_non_reduce.loops[-1].simd = True
par_depth = 0
reduction_par_depth = 0
if loops_nest_non_reduce:
par_depth = self.simd_vec_kernel.decide_parallel_depth(
self.simd_vec_kernel.call_ranges[:reduction_depth], threads
)
else:
reduction_par_depth = self.simd_vec_kernel.decide_parallel_depth(
self.simd_vec_kernel.call_ranges[reduction_depth:], threads
)
# If the most inner loop of the reduction will be vectorized, the vectorization
# will add a vec variable for reduction. Take the code snippet as an example:
# float tmp1 = 0;
# for(long i1=0; i1<8; i1+=1) {
# auto tmp0 = in_ptr0[i1];
# tmp1 += tmp0;
# }
# The vectorization will add tmp1_vec for reduction and then the loop will be transformed
# as follows.
# float tmp1 = 0;
# auto tmp1_vec = at::vec::Vectorized<float>(tmp1);
# for(long i1=0; i1<1; i1+=1) {
# auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + (8*i1));
# tmp1_vec += tmp0;
# }
# tmp1 = at::vec::vec_reduce_all<float>([]
# (at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return x + y;},
# tmp1_vec);
# for(long i1=8; i1<8; i1+=1) {
# auto tmp0 = in_ptr0[i1];
# tmp1 += tmp0;
# }
        # This means that vectorization introduces another reduction variable (tmp1_vec).
        # If the most inner loop of the reduction is not parallelized but its parent
        # reduction loop is, the newly added reduction variable (tmp1_vec) cannot be
        # added to the parallelized loop's reduction clause. So we skip this case and
        # do not vectorize it.
if reduction_par_depth > 0 and reduction_par_depth != len(
loops_nest_reduce.loops
):
return self.simd_omp_kernel.codegen_loops(code, worksharing)
with contextlib.ExitStack() as stack:
if par_depth:
worksharing.parallel(threads)
loops_nest_non_reduce.mark_parallel(par_depth)
elif reduction_par_depth:
# need to close the worksharing scope to define reduction vars outside it
worksharing.close()
loops_nest_reduce.mark_parallel(reduction_par_depth)
elif threads > 1:
if worksharing.single():
stack.enter_context(code.indent())
non_reduce_loops = loops_nest_non_reduce.loops
reduce_loops = loops_nest_reduce.loops
loop_with_tail: LoopLevelWithTail = None
if loops_nest_reduce:
self.vectorize_most_inner_loop(loops_nest_reduce)
loop_with_tail = loops_nest_reduce.loops[-1]
# The most inner loop will be vectorized
reduce_loops = reduce_loops[0:-1]
else:
self.vectorize_most_inner_loop(loops_nest_non_reduce)
loop_with_tail = loops_nest_non_reduce.loops[-1]
# The most inner loop will be vectorized
non_reduce_loops = non_reduce_loops[0:-1]
# The reductions loops are always the loop body of non-reduction loops
for loop in non_reduce_loops:
code.writelines(loop.lines())
stack.enter_context(code.indent())
with contextlib.ExitStack() as stack_outer:
if self.simd_vec_kernel.reduction_prefix:
stack_outer.enter_context(code.indent())
code.splice(self.simd_vec_kernel.reduction_prefix)
if reduction_par_depth:
worksharing.parallel(threads)
with contextlib.ExitStack() as stack:
for loop in reduce_loops:
code.writelines(loop.lines())
stack.enter_context(code.indent())
def gen_vectorized_loop(loop, kernel, write_reduction_suffix=False):
code.writelines(loop.lines())
with contextlib.ExitStack() as stack:
stack.enter_context(code.indent())
code.splice(kernel.loads)
code.splice(kernel.compute)
code.splice(kernel.stores)
if write_reduction_suffix:
code.splice(kernel.reduction_suffix)
                    # For the vectorized reduction loop, we need to call vec_reduce_all
                    # to reduce the vector to a single scalar. Hence, we set
                    # write_reduction_suffix to True to generate that code.
gen_vectorized_loop(
loop_with_tail.main_loop, loop_with_tail.main_loop_body, True
)
gen_vectorized_loop(
loop_with_tail.tail_loop, loop_with_tail.tail_loop_body, False
)
if reduction_par_depth:
worksharing.close()
code.splice(loop_with_tail.tail_loop_body.reduction_suffix)
class CppScheduling:
def __init__(self, scheduler):
self.scheduler = scheduler
self.kernel_group = KernelGroup()
def group_fn(self, sizes):
return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
@staticmethod
def can_fuse_horizontal(node1, node2):
_, (vars1, reduce1) = node1.group
_, (vars2, reduce2) = node2.group
if vars1 == vars2 and reduce1 == reduce2:
return True
if reduce1 == () and vars1 == vars2 + reduce2:
return True
# TODO(jansel): allow fusion pointwise (vars1, ()) suffix?
return False
@classmethod
def can_fuse_vertical(cls, node1, node2):
return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction()
def can_vec(self, nodes):
if not codecache.pick_vec_isa():
return False
_, (group, reduction_group) = max(
nodes, key=lambda x: int(x.is_reduction())
).group
with CppVecKernelChecker(
deepcopy(self.kernel_group.args), parallel_num_threads()
) as kernel_checker:
vars, reduction_vars = kernel_checker.set_ranges(group, reduction_group)
for node in nodes:
if node.group[1] in [
(group, reduction_group),
(group + reduction_group, ()),
]:
node.run(vars, reduction_vars)
else:
assert node.group[1] == (
group,
(),
), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
node.run(vars, ())
return kernel_checker.simd_vec
def _codegen_nodes_impl(self, nodes, is_simd_vec=False):
"""
        Turn a set of pre-fused nodes into a C++ kernel.
"""
kernel_group = self.kernel_group
_, (group, reduction_group) = max(
nodes, key=lambda x: int(x.is_reduction())
).group
def create_kernel(_is_simd_vec):
in_suffix = False
with kernel_group.new_kernel(_is_simd_vec) as kernel:
vars, reduction_vars = kernel.set_ranges(group, reduction_group)
for node in nodes:
if node.group[1] in [
(group, reduction_group),
(group + reduction_group, ()),
]:
assert not in_suffix
node.run(vars, reduction_vars)
else:
in_suffix = True
assert node.group[1] == (
group,
(),
), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
                        # we can fuse some extra pointwise ops into the suffix
with kernel.write_to_suffix():
node.run(vars, ())
return kernel
org_inplace_buffers_flag = config.inplace_buffers
if is_simd_vec:
# Create vectorization kernel
cpp_vec_kernel = create_kernel(True)
            # A kernel is divided into two parts: vectorized and non-vectorized.
            # The two parts share global contexts like V.graph.wrapper_code and
            # V.kernel.args, and the vectorized kernel generation has already
            # updated those contexts. The non-vectorized kernel must therefore
            # not update them again, to avoid context conflicts. For now we only
            # control config.inplace_buffers; in the future we could maintain
            # more contexts.
config.inplace_buffers = False
# Create non-vectorization kernel
cpp_kernel = create_kernel(False)
# Restore the inplace_buffers flag
config.inplace_buffers = org_inplace_buffers_flag
return (cpp_vec_kernel, cpp_kernel)
else:
return (None, create_kernel(False))
def codegen_nodes(self, nodes):
"""
        Turn a set of pre-fused nodes into a C++ kernel.
"""
kernel_group = self.kernel_group
can_be_simd_vec = self.can_vec(nodes)
simd_vec_kernel, simd_omp_kernel = self._codegen_nodes_impl(
nodes, can_be_simd_vec
)
assert simd_omp_kernel
metrics.generated_kernel_count -= 1
        # Maintain the metrics kernel count
if simd_vec_kernel:
metrics.generated_kernel_count -= 1
cpp_kernel_proxy = CppKernelProxy(
kernel_group.args, kernel_group.ws.num_threads
)
cpp_kernel_proxy.simd_vec_kernel = simd_vec_kernel
cpp_kernel_proxy.simd_omp_kernel = simd_omp_kernel
kernel_group.finalize_kernel(cpp_kernel_proxy, None)
def flush(self):
self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
self.kernel_group = KernelGroup()
class KernelGroup:
def __init__(self):
super().__init__()
self.args = KernelArgs()
self.loops_code = BracesBuffer()
self.ws = WorkSharing(self.loops_code)
self.stack = contextlib.ExitStack()
self.stack.enter_context(self.ws)
self.count = 0
def new_kernel(self, simd_vec=False):
if simd_vec:
return CppVecKernel(self.args, parallel_num_threads())
else:
return CppKernel(self.args, parallel_num_threads())
def finalize_kernel(self, new_kernel, scheduler):
self.count += 1
code = self.loops_code
ws = self.ws
new_kernel.codegen_loops(code, ws)
def codegen_define_and_call(self, wrapper):
self.stack.close()
if self.count == 0:
return
arg_defs, call_args = self.args.cpp_argdefs()
arg_defs = ",\n".ljust(25).join(arg_defs)
code = BracesBuffer()
        code.writelines([cpp_prefix(), f'extern "C" void kernel({arg_defs})'])
with code.indent():
for old, new in self.args.aliases():
code.writeline(f"auto {old} = {new};")
code.splice(self.loops_code)
codecache_def = IndentedBuffer()
codecache_def.writeline("async_compile.cpp('''")
codecache_def.splice(code)
codecache_def.writeline("''')")
kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix()
codecache_str = codecache_def.getvalue()
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
# not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
codecache_str = codecache_str.replace("#pragma CMT", "//")
wrapper.define_kernel(kernel_name, codecache_str)
# generate the code to call this
wrapper.writeline(
"{}({})".format(kernel_name, ", ".join(call_args)),
)
class WorkSharing:
def __init__(self, code):
self.code = code
self.in_parallel = False
self.num_threads = None
self.stack = contextlib.ExitStack()
def parallel(self, threads):
if self.in_parallel and threads != self.num_threads:
# wrong number of threads
self.close()
if not self.in_parallel:
self.num_threads = threads
self.in_parallel = True
if config.cpp.dynamic_threads:
self.code.writeline("#pragma omp parallel")
else:
self.code.writeline(f"#pragma omp parallel num_threads({threads})")
self.stack.enter_context(self.code.indent())
def single(self):
if self.in_parallel:
self.code.writeline("#pragma omp single")
return self.in_parallel
def close(self):
self.stack.close()
self.in_parallel = False
def __enter__(self):
self.stack.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stack.__exit__(exc_type, exc_val, exc_tb)
@dataclasses.dataclass
class LoopLevel:
var: sympy.Expr = None
size: sympy.Expr = None
offset: sympy.Expr = sympy.Integer(0)
steps: sympy.Expr = sympy.Integer(1)
parallel: int = 0
simd_omp: bool = False
picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0
simd_vec: bool = False
collapsed: bool = False
reduction_vars: Dict[str, str] = None
def lines(self):
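        # e.g. a parallel sum-reduction loop emits a pair of lines (sketch):
        #   #pragma omp for reduction(+:tmp0)
        #   for(long i0=0; i0<N; i0+=1)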
if self.reduction_vars:
suffix = "_vec" if self.simd_vec else ""
reduction = " " + " ".join(
f"reduction({RTYPE_TO_CPP[rtype]}:{var}{suffix})"
for var, rtype in self.reduction_vars.items()
)
else:
reduction = ""
simd = (
f"simd simdlen({self.simd_nelements}) "
if self.simd_omp and self.simd_nelements > 1
else ""
)
if self.parallel:
# TODO(jansel): look into chunk size and other schedules
line1 = f"#pragma omp for{reduction} "
if self.parallel > 1:
line1 += f" collapse({self.parallel})"
if self.simd_omp:
line1 = line1.replace(" for ", f" for {simd}")
elif self.simd_vec:
line1 = ""
elif self.simd_omp:
line1 = f"#pragma omp {simd}{reduction}"
elif not self.reduction_vars and codecache.is_gcc():
line1 = "#pragma GCC ivdep"
else:
line1 = ""
line2 = f"for({INDEX_TYPE} {self.var}={cexpr(self.offset)}; {self.var}<{cexpr(self.size)}; {self.var}+={cexpr(self.steps)})"
if self.collapsed or not line1:
return [line2]
return [line1, line2]
class LoopLevelWithTail(LoopLevel):
def __init__(self, main_loop: LoopLevel, tail_loop: LoopLevel):
super().__init__()
self.main_loop = main_loop
self.tail_loop = tail_loop
self.main_loop_body = None
self.tail_loop_body = None
def lines(self):
raise AssertionError("Not Implemented")
@dataclasses.dataclass
class LoopNest:
loops: List[LoopLevel]
def __bool__(self):
return bool(self.loops)
def mark_reduction(self, reduction_vars):
for loop in self.loops:
loop.reduction_vars = reduction_vars
def mark_parallel(self, par_depth):
loops = self.loops
loops[0].parallel = par_depth
for i in range(1, par_depth):
loops[i].collapsed = True
def split_most_inner_loop(self, factor):
sympy_factor = sympy.Integer(factor)
most_inner_loop = self.loops[-1]
        # If the most inner loop is marked as collapsed, we need to exclude it
        # since we are about to split it into two loops. Meanwhile, we still
        # mark the remaining loops as parallelized.
if most_inner_loop.collapsed:
assert self.loops[0].parallel == len(self.loops)
self.loops[0].parallel -= 1
main_loop_range = ir.IndexingDiv(most_inner_loop.size, sympy_factor)
main_loop = LoopLevel(most_inner_loop.var, main_loop_range)
main_loop.parallel = most_inner_loop.parallel
main_loop.collapsed = False
main_loop.reduction_vars = most_inner_loop.reduction_vars
offset = main_loop_range * sympy_factor
tail_loop = LoopLevel(most_inner_loop.var, most_inner_loop.size)
tail_loop.offset = offset
tail_loop.parallel = most_inner_loop.parallel
tail_loop.collapsed = False
tail_loop.reduction_vars = most_inner_loop.reduction_vars
loop_with_tail = LoopLevelWithTail(main_loop, tail_loop)
loop_with_tail.parallel = 0
loop_with_tail.collapsed = False
self.loops[-1] = loop_with_tail
def codegen(self, code, stack):
for loop in self.loops:
code.writelines(loop.lines())
stack.enter_context(code.indent())
else:
stack.enter_context(code.indent())