[inductor cpp] vectorize embedding lookup (#114062)

Embedding lookup involves indirect indexing where the indices are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization as long as the related index_exprs are invariant to the vectorized itervar.
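As a rough illustration of the access pattern (hypothetical shapes, not tied to the UT below), the indirect index depends only on the outer itervar, so it can stay scalar while the inner loop is vectorized:

```python
import torch

# Hypothetical shapes, for illustration only.
weight = torch.randn(64, 128)          # embedding table
idx = torch.randint(0, 64, (4,))       # indices: invariant to the inner itervar
x = torch.randn(4, 128)
out = torch.empty(4, 128)

for x0 in range(idx.numel()):          # outer itervar
    row = int(idx[x0])                 # indirect index: depends on x0 only
    for x1 in range(weight.size(1)):   # inner itervar: the one to vectorize
        out[x0, x1] = weight[row, x1] + x[x0, x1]

assert torch.allclose(out, torch.nn.functional.embedding(idx, weight) + x)
```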

This PR adds that support by lazily broadcasting scalar values (`index_expr` and `constant`) to vectors: `CppVecKernel` generates a vector operation only when at least one of the inputs is a vector; otherwise it generates scalar ops. The CSE variable in the cpp backend is now represented by `CppCSEVariable`, which bookkeeps the itervars the variable depends on and carries a flag marking whether it is a scalar or a vector. `CppVecOverrides` is extended to propagate these states as the ops are executed.
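Below is a minimal sketch of the lazy-broadcast idea, using simplified stand-ins rather than the real `CppCSEVariable`/`CppVecKernel` API: each CSE value carries an `is_vec` flag plus its dependent itervars, a vector op is emitted only when at least one input is already a vector, and all-scalar inputs fall back to scalar codegen.

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class Var:
    # Simplified stand-in for CppCSEVariable (illustrative, not the real API).
    expr: str
    is_vec: bool = False
    dependent_itervars: Set[str] = field(default_factory=set)


def broadcast(v: Var) -> Var:
    # Scalar -> vector, mirroring at::vec::Vectorized<T>(scalar) in the generated C++.
    return Var(f"at::vec::Vectorized<float>({v.expr})", True, set(v.dependent_itervars))


def add(a: Var, b: Var) -> Var:
    if a.is_vec or b.is_vec:
        # Lazily broadcast any scalar input, then emit the vector op.
        a, b = (v if v.is_vec else broadcast(v) for v in (a, b))
        is_vec = True
    else:
        is_vec = False  # all inputs are scalars: fall back to scalar codegen
    return Var(f"{a.expr} + {b.expr}", is_vec, a.dependent_itervars | b.dependent_itervars)


# tmp0 depends only on x0 (invariant to the vectorized itervar x1); tmp5 is a vector load.
tmp0 = Var("tmp0", dependent_itervars={"x0"})
tmp5 = Var("tmp5", is_vec=True, dependent_itervars={"x0", "x1"})
print(add(tmp0, tmp5).expr)  # at::vec::Vectorized<float>(tmp0) + tmp5
```

In the actual change, `CppVecOverrides.__new__` wraps each op to perform this check, calling `CppVecKernel.broadcast` for scalar inputs and deferring to the scalar `CppOverrides` when no input is a vector.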

For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
            {
                #pragma GCC ivdep
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(x0)];
                    auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
                    auto tmp1 = decltype(tmp0)(tmp0 + 64);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
                    auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
                    auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
                    out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
                }
            }
        }
    }
}
```

After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(x0)];
                    auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
                    auto tmp1 = decltype(tmp0)(tmp0 + 64);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
                    auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
                    auto tmp6 = tmp4 + tmp5;
                    tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
                }
            }
        }
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114062
Approved by: https://github.com/jansel
Author: Jiong Gong, 2023-11-22 16:00:38 +08:00, committed by PyTorch MergeBot
parent 3e1abde46d
commit a0e3321f0c
3 changed files with 225 additions and 91 deletions


@@ -1299,6 +1299,7 @@ class CPUReproTests(TestCase):
                cpp_op_list.append(k)

        diff = [
            "constant",
            "index_expr",
            "signbit",
            "isinf",
@@ -2612,6 +2613,23 @@ class CPUReproTests(TestCase):
        x = torch.randn(1, 39, 1, 18, 17)
        self.common(m, (x,))

    def test_embedding_vec(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(64, 128)

            def forward(self, idx, x):
                return self.emb(idx) + x

        idx = torch.randint(0, 64, (4, 32))
        x = torch.randn(4, 32, 128)
        m = M().eval()
        with torch.no_grad():
            metrics.reset()
            self.common(m, (idx, x))
            assert metrics.generated_cpp_vec_kernel_count == 1


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests


@@ -7,7 +7,7 @@ import math
import re
import sys
from copy import copy, deepcopy
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import sympy
@@ -133,6 +133,19 @@ DTYPE_LOWP_FP = [
]


def value_to_cpp(value, cpp_type):
    if value == float("-inf"):
        return f"-std::numeric_limits<{cpp_type}>::infinity()"
    elif value == float("inf"):
        return f"std::numeric_limits<{cpp_type}>::infinity()"
    elif isinstance(value, bool):
        return f"static_cast<{cpp_type}>({str(value).lower()})"
    elif math.isnan(value):
        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
    else:
        return f"static_cast<{cpp_type}>({repr(value)})"


def reduction_init(reduction_type, dtype):
    if dtype in DTYPE_LOWP_FP:
        # Since load promotes all half-precision inputs to float, the initial
@@ -436,6 +449,54 @@ def get_current_node_opt_ctx() -> OptimizationContext:
    return get_opt_ctx(V.interpreter.current_node)


class CppCSEVariable(CSEVariable):
    def __init__(self, name, bounds: ValueRanges):
        super().__init__(name, bounds)
        self.is_vec = False
        self.dtype: Optional[torch.dtype] = None
        self.dependent_itervars: Set[sympy.Symbol] = set()

    def update_on_args(self, name, args, kwargs):
        if name == "load":
            # args[1] is index
            self._set_dependent_itervars(args[1])
        else:
            # propagate relevant itervars and is_vec from args
            self.dependent_itervars.update(
                *[
                    arg.dependent_itervars
                    for arg in args
                    if isinstance(arg, CppCSEVariable)
                ]
            )
            if name == "index_expr":
                self._set_dependent_itervars(args[0])
            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
                self.is_vec = True
        if (
            hasattr(V.interpreter, "current_node")
            and get_current_node_opt_ctx() is not None
        ):
            self.dtype = get_current_node_opt_ctx().dtype

    def _set_dependent_itervars(self, index: sympy.Expr):
        """
        Set the relevant itervars for this variable based on the `index` expression.
        This includes the itervars directly used in the `index` as well as relevant itervars
        of other cse variables used in the `index`.
        """
        for s in index.free_symbols:
            if s in V.kernel.itervars:
                self.dependent_itervars.add(s)
            elif s.name in V.kernel.cse.varname_map:
                self.dependent_itervars.update(
                    V.kernel.cse.varname_map[s.name].dependent_itervars
                )

    def depends_on(self, itervar: sympy.Symbol):
        return itervar in self.dependent_itervars


class CppOverrides(OpOverrides):
    """Map element-wise ops to C++"""
@@ -672,22 +733,20 @@ class CppOverrides(OpOverrides):
    @staticmethod
    def constant(val, dtype):
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        assert opt_ctx and opt_ctx.dtype is not None
        dtype = opt_ctx.dtype
        if dtype in DTYPE_LOWP_FP:
            # Since load promotes all half-precision inputs to float, constants
            # must be promoted as well
            dtype = torch.float32
        if val == float("inf"):
            return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
        elif val == float("-inf"):
            return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
        elif math.isnan(val):
            return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()"
        elif val is True or val is False:
            return ops.to_dtype(str(val).lower(), dtype)
        return ops.to_dtype(repr(val), dtype)
        return value_to_cpp(val, DTYPE_TO_CPP[dtype])

    @staticmethod
    def index_expr(expr, dtype):
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        assert opt_ctx and opt_ctx.dtype is not None
        dtype = opt_ctx.dtype
        return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)

    @staticmethod
@@ -704,19 +763,7 @@ class CppOverrides(OpOverrides):
        V.kernel.compute.splice(code)

        # Use the lambda's return type as the type of other
        type = f"decltype({body_var}())"
        if other == float("-inf"):
            other_code = f"-std::numeric_limits<{type}>::infinity()"
        elif other == float("inf"):
            other_code = f"std::numeric_limits<{type}>::infinity()"
        elif isinstance(other, bool):
            other_code = f"static_cast<{type}>({str(other).lower()})"
        elif math.isnan(other):
            other_code = f"std::numeric_limits<{type}>::quiet_NaN()"
        else:
            other_code = f"static_cast<{type}>({repr(other)})"
        other_code = value_to_cpp(other, f"decltype({body_var}())")
        return f"{mask} ? {body_var}() : {other_code}"

    @staticmethod
@@ -794,6 +841,54 @@ class CppOverrides(OpOverrides):
class CppVecOverrides(CppOverrides):
    """Map element-wise ops to aten vectorization C++"""

    def __new__(cls, *args, **kargs):
        self = super().__new__(cls)

        def wrap(func):
            # `CppVecKernel` generates both scalar ops and vector ops according to
            # whether the inputs are scalars or vectors while all ops in `CppVecOverrides`
            # (except for "masked") assume the inputs are vectors. We wrap the ops in
            # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to
            # `CppOverrides` when all inputs are scalars.
            #
            # Inputs to ops.masked are handled separately in its own function due to
            # the need of recurive handling of masked body.
            def wrapper(*args, **kwargs):
                has_scalar = any(
                    not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
                )
                has_vector = any(
                    arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
                )
                new_args = list(args)
                if has_scalar and has_vector:
                    # broadcast scalar args to vector if needed
                    new_args = []
                    for arg in args:
                        if isinstance(arg, CppCSEVariable) and not arg.is_vec:
                            assert isinstance(V.kernel, CppVecKernel)
                            new_arg = V.kernel.broadcast(arg)
                            new_args.append(new_arg)
                        else:
                            new_args.append(arg)
                if has_vector:
                    return func(*new_args, **kwargs)
                else:
                    # fallback to scalar ops
                    scalar_ops = super(CppVecOverrides, self)
                    scalar_func = getattr(
                        scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__)  # type: ignore[attr-defined]
                    )
                    assert scalar_func is not None
                    return scalar_func(*args, **kwargs)

            return wrapper

        for name, method in vars(cls).items():
            if getattr(method, "__class__", None) == staticmethod and name != "masked":
                setattr(self, name, wrap(method.__func__))
        return self

    @staticmethod
    def add(a, b):
        return f"{a} + {b}"
@@ -1006,28 +1101,6 @@ class CppVecOverrides(CppOverrides):
        vec_one = f"decltype({x})(1)"
        return f"({x} + ({x}*{x} - {vec_one}).sqrt()).log()"

    @staticmethod
    def constant(val, dtype):
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        assert opt_ctx
        proposed_dtype = opt_ctx.dtype
        assert proposed_dtype in [
            torch.float,
            torch.int32,
        ]
        if val == float("inf"):
            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
        elif val == float("-inf"):
            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
        elif math.isnan(val):
            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()"
        elif val is True or val is False:
            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})"
        else:
            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})"
        return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})"

    @staticmethod
    def relu(x):
        bug = config.cpp.inject_relu_bug_TESTING_ONLY
@@ -1159,32 +1232,24 @@ class CppVecOverrides(CppOverrides):
        code.writeline(";")
        V.kernel.compute.splice(code)

        if other == float("-inf"):
            other_code = (
                "at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())"
            )
        elif other == float("inf"):
            other_code = (
                "at::vec::Vectorized<float>(std::numeric_limits<float>::infinity())"
            )
        elif math.isnan(other):
            other_code = (
                "at::vec::Vectorized<float>(std::numeric_limits<float>::quiet_NaN())"
        other_code = value_to_cpp(other, "float")
        other_code_vec = f"at::vec::Vectorized<float>({other_code})"
        if result.is_vec:
            type = f"decltype({var}())"
            float_mask = f"to_float_mask({new_mask})"
            csevar = V.kernel.cse.generate(
                V.kernel.compute,
                f"{type}::blendv({other_code_vec}, {var}(), {float_mask})",
            )
        else:
            other_code = f"at::vec::Vectorized<float>({other!r})"
        type = f"decltype({var}())"
        float_mask = f"to_float_mask({new_mask})"
        return f"{type}::blendv({other_code}, {var}(), {float_mask})"

    @staticmethod
    def index_expr(expr, dtype):
        assert dtype == torch.int64
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        assert opt_ctx
        assert opt_ctx.dtype == torch.int32
        assert opt_ctx.is_most_inner_loop_irrevelant
        return f"at::vec::Vectorized<int>(static_cast<int>({cexpr(V.kernel.rename_indexing(expr))}))"
            csevar = V.kernel.cse.generate(
                V.kernel.compute, f"{mask} ? {var}() : {other_code}"
            )
        # `result` is explicitly added to the args for correct propagation
        # of relevant itervars and vectorization status.
        csevar.update_on_args("masked", (mask, body, other, result), {})
        return csevar


class CppKernel(Kernel):
@@ -1242,7 +1307,9 @@ class CppKernel(Kernel):
        line = f"{var}[{cexpr_index(index)}]"
        if V.graph.get_dtype(name) in [torch.float16]:
            line = f"static_cast<float>({line})"
        return self.cse.generate(self.loads, line)
        csevar = self.cse.generate(self.loads, line)
        csevar.update_on_args("load", (name, index), {})
        return csevar

    def store(self, name, index, value, mode=None):
        assert "buf" in name
@@ -1472,6 +1539,9 @@ class CppKernel(Kernel):
        self.reduction_suffix.splice(self.stores)
        (self.loads, self.compute, self.stores, self.cse) = prior

    def create_cse_var(self, *args, **kwargs):
        return CppCSEVariable(*args, **kwargs)


class CppVecKernel(CppKernel):
    overrides = CppVecOverrides  # type: ignore[assignment]
@@ -1506,7 +1576,11 @@ class CppVecKernel(CppKernel):
        non_contiguous = (
            not is_broadcast
            and stride_at(tiling_var, index) != 1
            or "tmp" in f"{index}"
            or any(
                self.cse.varname_map[s.name].depends_on(tiling_var)
                for s in index.free_symbols
                if s.name.startswith("tmp")
            )
        )
        var_expr = (
            f"{var}[{cexpr_index(index)}]"
@@ -1515,13 +1589,9 @@ class CppVecKernel(CppKernel):
        )
        loadbuf = "tmpbuf" if non_contiguous else var_expr
        if is_broadcast:
            # should always be broadcast as float for vectorization since we always use float to compute
            if is_mask:
                loadbuf = f"flag_to_float_scalar({loadbuf})"
            if dtype in DTYPE_LOWP_FP:
                line = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({loadbuf})"
            else:
                line = f"at::vec::Vectorized<float>(static_cast<float>({loadbuf}))"
            csevar = super().load(name, index)
            csevar.dtype = dtype
            return csevar
        elif dtype in [torch.uint8] and opt_ctx.is_load_uint8_as_float:
            line = (
                f"masked_load({loadbuf}, {load_mask})"
@@ -1563,7 +1633,11 @@ class CppVecKernel(CppKernel):
            tmpbufdefine += f"tmpbuf[{inner}] = {rhs};"
            line = f"([&]() {{ {tmpbufdeclare} {tmpbufdefine} return {line}; }})()"

        return self.cse.generate(self.loads, line)
        csevar = self.cse.generate(self.loads, line)
        csevar.update_on_args("load", (name, index), {})
        assert isinstance(csevar, CppCSEVariable)
        csevar.is_vec = True
        return csevar

    def get_vec_store_line(self, value, var, index, dtype):
        """
@@ -1572,6 +1646,11 @@
        :param var: buffer to store into.
        :index: index into the `var`.
        """
        # when value's type is str (e.g., welford reduction), caller should make sure
        # it is a vector
        assert isinstance(value, str) or (
            isinstance(value, CppCSEVariable) and value.is_vec
        ), value
        tiling_var = self.itervars[self.tiling_idx]
        assert index.has(tiling_var)
        var_expr = f"{var} + {cexpr_index(index)}"
@@ -1600,6 +1679,10 @@ class CppVecKernel(CppKernel):
    def store(self, name, index, value, mode=None):
        assert "buf" in name
        assert mode is None
        assert isinstance(value, CppCSEVariable), value
        if not value.is_vec:
            # this happens when we store a scalar into a vectorized buffer like "fill"
            value = self.broadcast(value)
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        var = self.args.output(name)
        index = self.rename_indexing(index)
@@ -1622,6 +1705,7 @@ class CppVecKernel(CppKernel):
        }
        assert dtype == torch.float
        assert src_dtype == torch.float
        assert isinstance(value, CppCSEVariable) and value.is_vec, value
        vec_ns = "at::vec"
        vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
@@ -1740,6 +1824,27 @@ initializer(omp_priv={{{reduction_init_vec(reduction_type, dtype)}}})
        ]
        self.reduction_suffix.writelines(store_lines)

    def broadcast(self, scalar_var: CppCSEVariable):
        assert (
            not scalar_var.is_vec
            and self.itervars[self.tiling_idx] not in scalar_var.dependent_itervars
        )
        if scalar_var.dtype == torch.bool:
            vec_var = self.cse.generate(
                self.compute, f"to_float_mask({scalar_var.name})"
            )
        else:
            assert scalar_var.dtype is not None
            vec_var = self.cse.generate(
                self.compute,
                f"at::vec::Vectorized<{DTYPE_TO_CPP[scalar_var.dtype]}>({scalar_var.name})",
            )
        assert isinstance(vec_var, CppCSEVariable)
        vec_var.dtype = scalar_var.dtype
        vec_var.dependent_itervars = scalar_var.dependent_itervars
        vec_var.is_vec = True
        return vec_var


class CppTile2DKernel(CppVecKernel):
    """
@@ -1849,7 +1954,11 @@ class CppTile2DKernel(CppVecKernel):
                line = f"at::vec::Vectorized<uint8_t>::loadu_one_fourth({loadbuf})"
            else:
                line = f"at::vec::Vectorized<float>::loadu({loadbuf})"
            return self.cse.generate(self.loads, line)
            csevar = self.cse.generate(self.loads, line)
            csevar.update_on_args("load", (name, index), {})
            assert isinstance(csevar, CppCSEVariable)
            csevar.is_vec = True
            return csevar
        else:
            new_index = self.scale_index_with_offset(
                index,
@@ -1950,10 +2059,6 @@ class CppVecKernelChecker(CppVecKernel):
        schedule_log.debug("Disabled vectorization: %s", msg)
        self.simd_vec = False

    def could_vec(self, name: str, index: sympy.Expr):
        assert self.itervars is not None
        return len(self.itervars) > 0

    def is_mask(self, name: str, users: Dict[torch.fx.Node, None]):
        load_type = V.graph.get_dtype(name)
        if load_type == torch.bool:
@@ -2036,6 +2141,10 @@ class CppVecKernelChecker(CppVecKernel):
            var = self.cse.newvar()

            if len(self.itervars) == 0:
                self.disable_vec("not a loop")
                return var

            if load_dtype in [torch.bool, torch.uint8] and not (
                opt_ctx.is_load_as_mask or opt_ctx.is_load_uint8_as_float
            ):
@@ -2046,18 +2155,21 @@ class CppVecKernelChecker(CppVecKernel):
                return var

            if (
                load_dtype not in self.load_supported_dtypes
            ) and not self.is_load_integer_scalar_tensor(name, index):
                (load_dtype not in self.load_supported_dtypes)
                and not self.is_load_integer_scalar_tensor(name, index)
                and index.has(self.itervars[self.tiling_idx])
            ):
                self.disable_vec(f"{load_dtype} not supported by load")
                return var

            index = self.rename_indexing(index)
            if self.simd_vec and not self.could_vec(name, index):
                self.disable_vec(f"not a loop: {index}")
            return var

    def store(self, name, index, value, mode=None):
        with RecordOptimizationContext(__name__) as node_ctx:
            if len(self.itervars) == 0:
                self.disable_vec("not a loop")
                return self.simd_vec

            store_dtype = V.graph.get_dtype(name)
            opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
@@ -2085,8 +2197,6 @@ class CppVecKernelChecker(CppVecKernel):
            if index.is_number:
                self.disable_vec(f"constant store index: {index}")
            if self.simd_vec and not self.could_vec(name, index):
                self.disable_vec(f"not a loop: {index}")
            return self.simd_vec

    def reduction(self, dtype, src_dtype, reduction_type, value):


@@ -401,4 +401,10 @@ template <>
inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<float> src) {
  return src;
}

inline at::vec::Vectorized<float> to_float_mask(int src) {
  float mask;
  *(uint32_t*)&mask = src ? 0xFFFFFFFF : 0;
  return at::vec::Vectorized<float>(mask);
}
#endif