mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor cpp] vectorize embedding lookup (#114062)
For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed.
For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
}
}
}
}
}
```
After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
auto tmp6 = tmp4 + tmp5;
tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
}
}
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114062
Approved by: https://github.com/jansel
This commit is contained in:
parent
3e1abde46d
commit
a0e3321f0c
|
|
@ -1299,6 +1299,7 @@ class CPUReproTests(TestCase):
|
|||
cpp_op_list.append(k)
|
||||
|
||||
diff = [
|
||||
"constant",
|
||||
"index_expr",
|
||||
"signbit",
|
||||
"isinf",
|
||||
|
|
@ -2612,6 +2613,23 @@ class CPUReproTests(TestCase):
|
|||
x = torch.randn(1, 39, 1, 18, 17)
|
||||
self.common(m, (x,))
|
||||
|
||||
def test_embedding_vec(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.emb = torch.nn.Embedding(64, 128)
|
||||
|
||||
def forward(self, idx, x):
|
||||
return self.emb(idx) + x
|
||||
|
||||
idx = torch.randint(0, 64, (4, 32))
|
||||
x = torch.randn(4, 32, 128)
|
||||
m = M().eval()
|
||||
with torch.no_grad():
|
||||
metrics.reset()
|
||||
self.common(m, (idx, x))
|
||||
assert metrics.generated_cpp_vec_kernel_count == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import math
|
|||
import re
|
||||
import sys
|
||||
from copy import copy, deepcopy
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -133,6 +133,19 @@ DTYPE_LOWP_FP = [
|
|||
]
|
||||
|
||||
|
||||
def value_to_cpp(value, cpp_type):
|
||||
if value == float("-inf"):
|
||||
return f"-std::numeric_limits<{cpp_type}>::infinity()"
|
||||
elif value == float("inf"):
|
||||
return f"std::numeric_limits<{cpp_type}>::infinity()"
|
||||
elif isinstance(value, bool):
|
||||
return f"static_cast<{cpp_type}>({str(value).lower()})"
|
||||
elif math.isnan(value):
|
||||
return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
|
||||
else:
|
||||
return f"static_cast<{cpp_type}>({repr(value)})"
|
||||
|
||||
|
||||
def reduction_init(reduction_type, dtype):
|
||||
if dtype in DTYPE_LOWP_FP:
|
||||
# Since load promotes all half-precision inputs to float, the initial
|
||||
|
|
@ -436,6 +449,54 @@ def get_current_node_opt_ctx() -> OptimizationContext:
|
|||
return get_opt_ctx(V.interpreter.current_node)
|
||||
|
||||
|
||||
class CppCSEVariable(CSEVariable):
|
||||
def __init__(self, name, bounds: ValueRanges):
|
||||
super().__init__(name, bounds)
|
||||
self.is_vec = False
|
||||
self.dtype: Optional[torch.dtype] = None
|
||||
self.dependent_itervars: Set[sympy.Symbol] = set()
|
||||
|
||||
def update_on_args(self, name, args, kwargs):
|
||||
if name == "load":
|
||||
# args[1] is index
|
||||
self._set_dependent_itervars(args[1])
|
||||
else:
|
||||
# propagate relevant itervars and is_vec from args
|
||||
self.dependent_itervars.update(
|
||||
*[
|
||||
arg.dependent_itervars
|
||||
for arg in args
|
||||
if isinstance(arg, CppCSEVariable)
|
||||
]
|
||||
)
|
||||
if name == "index_expr":
|
||||
self._set_dependent_itervars(args[0])
|
||||
if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
|
||||
self.is_vec = True
|
||||
if (
|
||||
hasattr(V.interpreter, "current_node")
|
||||
and get_current_node_opt_ctx() is not None
|
||||
):
|
||||
self.dtype = get_current_node_opt_ctx().dtype
|
||||
|
||||
def _set_dependent_itervars(self, index: sympy.Expr):
|
||||
"""
|
||||
Set the relevant itervars for this variable based on the `index` expression.
|
||||
This includes the itervars directly used in the `index` as well as relevant itervars
|
||||
of other cse variables used in the `index`.
|
||||
"""
|
||||
for s in index.free_symbols:
|
||||
if s in V.kernel.itervars:
|
||||
self.dependent_itervars.add(s)
|
||||
elif s.name in V.kernel.cse.varname_map:
|
||||
self.dependent_itervars.update(
|
||||
V.kernel.cse.varname_map[s.name].dependent_itervars
|
||||
)
|
||||
|
||||
def depends_on(self, itervar: sympy.Symbol):
|
||||
return itervar in self.dependent_itervars
|
||||
|
||||
|
||||
class CppOverrides(OpOverrides):
|
||||
"""Map element-wise ops to C++"""
|
||||
|
||||
|
|
@ -672,22 +733,20 @@ class CppOverrides(OpOverrides):
|
|||
|
||||
@staticmethod
|
||||
def constant(val, dtype):
|
||||
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
|
||||
assert opt_ctx and opt_ctx.dtype is not None
|
||||
dtype = opt_ctx.dtype
|
||||
if dtype in DTYPE_LOWP_FP:
|
||||
# Since load promotes all half-precision inputs to float, constants
|
||||
# must be promoted as well
|
||||
dtype = torch.float32
|
||||
if val == float("inf"):
|
||||
return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
|
||||
elif val == float("-inf"):
|
||||
return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
|
||||
elif math.isnan(val):
|
||||
return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()"
|
||||
elif val is True or val is False:
|
||||
return ops.to_dtype(str(val).lower(), dtype)
|
||||
return ops.to_dtype(repr(val), dtype)
|
||||
return value_to_cpp(val, DTYPE_TO_CPP[dtype])
|
||||
|
||||
@staticmethod
|
||||
def index_expr(expr, dtype):
|
||||
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
|
||||
assert opt_ctx and opt_ctx.dtype is not None
|
||||
dtype = opt_ctx.dtype
|
||||
return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -704,19 +763,7 @@ class CppOverrides(OpOverrides):
|
|||
V.kernel.compute.splice(code)
|
||||
|
||||
# Use the lambda's return type as the type of other
|
||||
type = f"decltype({body_var}())"
|
||||
|
||||
if other == float("-inf"):
|
||||
other_code = f"-std::numeric_limits<{type}>::infinity()"
|
||||
elif other == float("inf"):
|
||||
other_code = f"std::numeric_limits<{type}>::infinity()"
|
||||
elif isinstance(other, bool):
|
||||
other_code = f"static_cast<{type}>({str(other).lower()})"
|
||||
elif math.isnan(other):
|
||||
other_code = f"std::numeric_limits<{type}>::quiet_NaN()"
|
||||
else:
|
||||
other_code = f"static_cast<{type}>({repr(other)})"
|
||||
|
||||
other_code = value_to_cpp(other, f"decltype({body_var}())")
|
||||
return f"{mask} ? {body_var}() : {other_code}"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -794,6 +841,54 @@ class CppOverrides(OpOverrides):
|
|||
class CppVecOverrides(CppOverrides):
|
||||
"""Map element-wise ops to aten vectorization C++"""
|
||||
|
||||
def __new__(cls, *args, **kargs):
|
||||
self = super().__new__(cls)
|
||||
|
||||
def wrap(func):
|
||||
# `CppVecKernel` generates both scalar ops and vector ops according to
|
||||
# whether the inputs are scalars or vectors while all ops in `CppVecOverrides`
|
||||
# (except for "masked") assume the inputs are vectors. We wrap the ops in
|
||||
# `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to
|
||||
# `CppOverrides` when all inputs are scalars.
|
||||
#
|
||||
# Inputs to ops.masked are handled separately in its own function due to
|
||||
# the need of recurive handling of masked body.
|
||||
def wrapper(*args, **kwargs):
|
||||
has_scalar = any(
|
||||
not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
|
||||
)
|
||||
has_vector = any(
|
||||
arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
|
||||
)
|
||||
new_args = list(args)
|
||||
if has_scalar and has_vector:
|
||||
# broadcast scalar args to vector if needed
|
||||
new_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, CppCSEVariable) and not arg.is_vec:
|
||||
assert isinstance(V.kernel, CppVecKernel)
|
||||
new_arg = V.kernel.broadcast(arg)
|
||||
new_args.append(new_arg)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
if has_vector:
|
||||
return func(*new_args, **kwargs)
|
||||
else:
|
||||
# fallback to scalar ops
|
||||
scalar_ops = super(CppVecOverrides, self)
|
||||
scalar_func = getattr(
|
||||
scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined]
|
||||
)
|
||||
assert scalar_func is not None
|
||||
return scalar_func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
for name, method in vars(cls).items():
|
||||
if getattr(method, "__class__", None) == staticmethod and name != "masked":
|
||||
setattr(self, name, wrap(method.__func__))
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def add(a, b):
|
||||
return f"{a} + {b}"
|
||||
|
|
@ -1006,28 +1101,6 @@ class CppVecOverrides(CppOverrides):
|
|||
vec_one = f"decltype({x})(1)"
|
||||
return f"({x} + ({x}*{x} - {vec_one}).sqrt()).log()"
|
||||
|
||||
@staticmethod
|
||||
def constant(val, dtype):
|
||||
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
|
||||
assert opt_ctx
|
||||
proposed_dtype = opt_ctx.dtype
|
||||
assert proposed_dtype in [
|
||||
torch.float,
|
||||
torch.int32,
|
||||
]
|
||||
if val == float("inf"):
|
||||
quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
|
||||
elif val == float("-inf"):
|
||||
quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
|
||||
elif math.isnan(val):
|
||||
quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()"
|
||||
elif val is True or val is False:
|
||||
quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})"
|
||||
else:
|
||||
quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})"
|
||||
|
||||
return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})"
|
||||
|
||||
@staticmethod
|
||||
def relu(x):
|
||||
bug = config.cpp.inject_relu_bug_TESTING_ONLY
|
||||
|
|
@ -1159,32 +1232,24 @@ class CppVecOverrides(CppOverrides):
|
|||
code.writeline(";")
|
||||
V.kernel.compute.splice(code)
|
||||
|
||||
if other == float("-inf"):
|
||||
other_code = (
|
||||
"at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())"
|
||||
)
|
||||
elif other == float("inf"):
|
||||
other_code = (
|
||||
"at::vec::Vectorized<float>(std::numeric_limits<float>::infinity())"
|
||||
)
|
||||
elif math.isnan(other):
|
||||
other_code = (
|
||||
"at::vec::Vectorized<float>(std::numeric_limits<float>::quiet_NaN())"
|
||||
other_code = value_to_cpp(other, "float")
|
||||
other_code_vec = f"at::vec::Vectorized<float>({other_code})"
|
||||
|
||||
if result.is_vec:
|
||||
type = f"decltype({var}())"
|
||||
float_mask = f"to_float_mask({new_mask})"
|
||||
csevar = V.kernel.cse.generate(
|
||||
V.kernel.compute,
|
||||
f"{type}::blendv({other_code_vec}, {var}(), {float_mask})",
|
||||
)
|
||||
else:
|
||||
other_code = f"at::vec::Vectorized<float>({other!r})"
|
||||
type = f"decltype({var}())"
|
||||
float_mask = f"to_float_mask({new_mask})"
|
||||
return f"{type}::blendv({other_code}, {var}(), {float_mask})"
|
||||
|
||||
@staticmethod
|
||||
def index_expr(expr, dtype):
|
||||
assert dtype == torch.int64
|
||||
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
|
||||
assert opt_ctx
|
||||
assert opt_ctx.dtype == torch.int32
|
||||
assert opt_ctx.is_most_inner_loop_irrevelant
|
||||
return f"at::vec::Vectorized<int>(static_cast<int>({cexpr(V.kernel.rename_indexing(expr))}))"
|
||||
csevar = V.kernel.cse.generate(
|
||||
V.kernel.compute, f"{mask} ? {var}() : {other_code}"
|
||||
)
|
||||
# `result` is explicitly added to the args for correct propagation
|
||||
# of relevant itervars and vectorization status.
|
||||
csevar.update_on_args("masked", (mask, body, other, result), {})
|
||||
return csevar
|
||||
|
||||
|
||||
class CppKernel(Kernel):
|
||||
|
|
@ -1242,7 +1307,9 @@ class CppKernel(Kernel):
|
|||
line = f"{var}[{cexpr_index(index)}]"
|
||||
if V.graph.get_dtype(name) in [torch.float16]:
|
||||
line = f"static_cast<float>({line})"
|
||||
return self.cse.generate(self.loads, line)
|
||||
csevar = self.cse.generate(self.loads, line)
|
||||
csevar.update_on_args("load", (name, index), {})
|
||||
return csevar
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
assert "buf" in name
|
||||
|
|
@ -1472,6 +1539,9 @@ class CppKernel(Kernel):
|
|||
self.reduction_suffix.splice(self.stores)
|
||||
(self.loads, self.compute, self.stores, self.cse) = prior
|
||||
|
||||
def create_cse_var(self, *args, **kwargs):
|
||||
return CppCSEVariable(*args, **kwargs)
|
||||
|
||||
|
||||
class CppVecKernel(CppKernel):
|
||||
overrides = CppVecOverrides # type: ignore[assignment]
|
||||
|
|
@ -1506,7 +1576,11 @@ class CppVecKernel(CppKernel):
|
|||
non_contiguous = (
|
||||
not is_broadcast
|
||||
and stride_at(tiling_var, index) != 1
|
||||
or "tmp" in f"{index}"
|
||||
or any(
|
||||
self.cse.varname_map[s.name].depends_on(tiling_var)
|
||||
for s in index.free_symbols
|
||||
if s.name.startswith("tmp")
|
||||
)
|
||||
)
|
||||
var_expr = (
|
||||
f"{var}[{cexpr_index(index)}]"
|
||||
|
|
@ -1515,13 +1589,9 @@ class CppVecKernel(CppKernel):
|
|||
)
|
||||
loadbuf = "tmpbuf" if non_contiguous else var_expr
|
||||
if is_broadcast:
|
||||
# should always be broadcast as float for vectorization since we always use float to compute
|
||||
if is_mask:
|
||||
loadbuf = f"flag_to_float_scalar({loadbuf})"
|
||||
if dtype in DTYPE_LOWP_FP:
|
||||
line = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({loadbuf})"
|
||||
else:
|
||||
line = f"at::vec::Vectorized<float>(static_cast<float>({loadbuf}))"
|
||||
csevar = super().load(name, index)
|
||||
csevar.dtype = dtype
|
||||
return csevar
|
||||
elif dtype in [torch.uint8] and opt_ctx.is_load_uint8_as_float:
|
||||
line = (
|
||||
f"masked_load({loadbuf}, {load_mask})"
|
||||
|
|
@ -1563,7 +1633,11 @@ class CppVecKernel(CppKernel):
|
|||
tmpbufdefine += f"tmpbuf[{inner}] = {rhs};"
|
||||
line = f"([&]() {{ {tmpbufdeclare} {tmpbufdefine} return {line}; }})()"
|
||||
|
||||
return self.cse.generate(self.loads, line)
|
||||
csevar = self.cse.generate(self.loads, line)
|
||||
csevar.update_on_args("load", (name, index), {})
|
||||
assert isinstance(csevar, CppCSEVariable)
|
||||
csevar.is_vec = True
|
||||
return csevar
|
||||
|
||||
def get_vec_store_line(self, value, var, index, dtype):
|
||||
"""
|
||||
|
|
@ -1572,6 +1646,11 @@ class CppVecKernel(CppKernel):
|
|||
:param var: buffer to store into.
|
||||
:index: index into the `var`.
|
||||
"""
|
||||
# when value's type is str (e.g., welford reduction), caller should make sure
|
||||
# it is a vector
|
||||
assert isinstance(value, str) or (
|
||||
isinstance(value, CppCSEVariable) and value.is_vec
|
||||
), value
|
||||
tiling_var = self.itervars[self.tiling_idx]
|
||||
assert index.has(tiling_var)
|
||||
var_expr = f"{var} + {cexpr_index(index)}"
|
||||
|
|
@ -1600,6 +1679,10 @@ class CppVecKernel(CppKernel):
|
|||
def store(self, name, index, value, mode=None):
|
||||
assert "buf" in name
|
||||
assert mode is None
|
||||
assert isinstance(value, CppCSEVariable), value
|
||||
if not value.is_vec:
|
||||
# this happens when we store a scalar into a vectorized buffer like "fill"
|
||||
value = self.broadcast(value)
|
||||
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
|
||||
var = self.args.output(name)
|
||||
index = self.rename_indexing(index)
|
||||
|
|
@ -1622,6 +1705,7 @@ class CppVecKernel(CppKernel):
|
|||
}
|
||||
assert dtype == torch.float
|
||||
assert src_dtype == torch.float
|
||||
assert isinstance(value, CppCSEVariable) and value.is_vec, value
|
||||
|
||||
vec_ns = "at::vec"
|
||||
vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
|
||||
|
|
@ -1740,6 +1824,27 @@ initializer(omp_priv={{{reduction_init_vec(reduction_type, dtype)}}})
|
|||
]
|
||||
self.reduction_suffix.writelines(store_lines)
|
||||
|
||||
def broadcast(self, scalar_var: CppCSEVariable):
|
||||
assert (
|
||||
not scalar_var.is_vec
|
||||
and self.itervars[self.tiling_idx] not in scalar_var.dependent_itervars
|
||||
)
|
||||
if scalar_var.dtype == torch.bool:
|
||||
vec_var = self.cse.generate(
|
||||
self.compute, f"to_float_mask({scalar_var.name})"
|
||||
)
|
||||
else:
|
||||
assert scalar_var.dtype is not None
|
||||
vec_var = self.cse.generate(
|
||||
self.compute,
|
||||
f"at::vec::Vectorized<{DTYPE_TO_CPP[scalar_var.dtype]}>({scalar_var.name})",
|
||||
)
|
||||
assert isinstance(vec_var, CppCSEVariable)
|
||||
vec_var.dtype = scalar_var.dtype
|
||||
vec_var.dependent_itervars = scalar_var.dependent_itervars
|
||||
vec_var.is_vec = True
|
||||
return vec_var
|
||||
|
||||
|
||||
class CppTile2DKernel(CppVecKernel):
|
||||
"""
|
||||
|
|
@ -1849,7 +1954,11 @@ class CppTile2DKernel(CppVecKernel):
|
|||
line = f"at::vec::Vectorized<uint8_t>::loadu_one_fourth({loadbuf})"
|
||||
else:
|
||||
line = f"at::vec::Vectorized<float>::loadu({loadbuf})"
|
||||
return self.cse.generate(self.loads, line)
|
||||
csevar = self.cse.generate(self.loads, line)
|
||||
csevar.update_on_args("load", (name, index), {})
|
||||
assert isinstance(csevar, CppCSEVariable)
|
||||
csevar.is_vec = True
|
||||
return csevar
|
||||
else:
|
||||
new_index = self.scale_index_with_offset(
|
||||
index,
|
||||
|
|
@ -1950,10 +2059,6 @@ class CppVecKernelChecker(CppVecKernel):
|
|||
schedule_log.debug("Disabled vectorization: %s", msg)
|
||||
self.simd_vec = False
|
||||
|
||||
def could_vec(self, name: str, index: sympy.Expr):
|
||||
assert self.itervars is not None
|
||||
return len(self.itervars) > 0
|
||||
|
||||
def is_mask(self, name: str, users: Dict[torch.fx.Node, None]):
|
||||
load_type = V.graph.get_dtype(name)
|
||||
if load_type == torch.bool:
|
||||
|
|
@ -2036,6 +2141,10 @@ class CppVecKernelChecker(CppVecKernel):
|
|||
|
||||
var = self.cse.newvar()
|
||||
|
||||
if len(self.itervars) == 0:
|
||||
self.disable_vec("not a loop")
|
||||
return var
|
||||
|
||||
if load_dtype in [torch.bool, torch.uint8] and not (
|
||||
opt_ctx.is_load_as_mask or opt_ctx.is_load_uint8_as_float
|
||||
):
|
||||
|
|
@ -2046,18 +2155,21 @@ class CppVecKernelChecker(CppVecKernel):
|
|||
return var
|
||||
|
||||
if (
|
||||
load_dtype not in self.load_supported_dtypes
|
||||
) and not self.is_load_integer_scalar_tensor(name, index):
|
||||
(load_dtype not in self.load_supported_dtypes)
|
||||
and not self.is_load_integer_scalar_tensor(name, index)
|
||||
and index.has(self.itervars[self.tiling_idx])
|
||||
):
|
||||
self.disable_vec(f"{load_dtype} not supported by load")
|
||||
return var
|
||||
|
||||
index = self.rename_indexing(index)
|
||||
if self.simd_vec and not self.could_vec(name, index):
|
||||
self.disable_vec(f"not a loop: {index}")
|
||||
return var
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
with RecordOptimizationContext(__name__) as node_ctx:
|
||||
if len(self.itervars) == 0:
|
||||
self.disable_vec("not a loop")
|
||||
return self.simd_vec
|
||||
|
||||
store_dtype = V.graph.get_dtype(name)
|
||||
|
||||
opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
|
||||
|
|
@ -2085,8 +2197,6 @@ class CppVecKernelChecker(CppVecKernel):
|
|||
|
||||
if index.is_number:
|
||||
self.disable_vec(f"constant store index: {index}")
|
||||
if self.simd_vec and not self.could_vec(name, index):
|
||||
self.disable_vec(f"not a loop: {index}")
|
||||
return self.simd_vec
|
||||
|
||||
def reduction(self, dtype, src_dtype, reduction_type, value):
|
||||
|
|
|
|||
|
|
@ -401,4 +401,10 @@ template <>
|
|||
inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<float> src) {
|
||||
return src;
|
||||
}
|
||||
|
||||
inline at::vec::Vectorized<float> to_float_mask(int src) {
|
||||
float mask;
|
||||
*(uint32_t*)&mask = src ? 0xFFFFFFFF : 0;
|
||||
return at::vec::Vectorized<float>(mask);
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user