[inductor][cpp] vectorization support for int32/int64 (#119001)

This pull request completes most of the vectorization support for the int32 and int64 data types, except for indirect indexing and masks. Basic data type support for uint32 and uint64 is also added, but without vectorization. More vectorized conversion functions between integer and float are added. To support int64 vectors, a new VectorizedN class is introduced to handle vectors of arbitrary length. Below are the details:
1. Complete most of the int32 and int64 vectorization support, including load, store, reduction, constant and conversion. Indirect indexing and masks will be addressed in follow-up PRs, after which the legality-checking logic in `CppVecKernelChecker` can be further simplified.
2. Add utility functions for conversion between integer and float vectors (in cpp_prefix.h and ATen vec). Ideally, the ones in cpp_prefix.h should be moved into ATen vec to keep cpp_prefix.h simple; this will be addressed in follow-up PRs.
3. Introduce a new template class `VectorizedN` that handles vectors of arbitrary length by encapsulating multiple `Vectorized<T>` instances. It supports most of the operations of `Vectorized<T>` and makes int64 vectorization simpler (see the sketch after this list). I will also apply it to bf16/fp16/int8 in follow-up PRs for better efficiency. For example, bf16 currently uses only half of the vector lanes; with `VectorizedN`, we can use all of the lanes and map a bf16 vector to `VectorizedN<float, 2>` on conversion.
4. Add basic data type support for uint32 and uint64 (in graph.py). Vectorization support will be added later; it is not a high priority given the limited usage of these types.
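
To illustrate the new `VectorizedN`-based int64 path, here is a minimal hand-written sketch (the function name and loop structure are made up for illustration; the actual kernels are emitted by the cpp backend and also use the conversion helpers from `cpp_prefix.h`). With the float tiling factor on AVX2/AVX512, int64 maps to `VectorizedN<int64_t, 2>`, i.e. two `Vectorized<int64_t>` covering the same number of lanes as one `Vectorized<float>`:

```cpp
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec_n.h>

// Squares an int64 buffer, roughly mirroring what a vectorized pointwise
// kernel does: a vectorized main loop plus a scalar tail loop.
void square_int64(const int64_t* in, int64_t* out, int64_t n) {
  using Vec = at::vec::VectorizedN<int64_t, 2>;  // 2 * Vectorized<int64_t>::size() lanes
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    auto x = Vec::loadu(in + i);   // loads both underlying Vectorized<int64_t>
    auto y = x * x;                // ops delegate to Vectorized<int64_t> with a loop over N
    y.store(out + i);
  }
  for (; i < n; ++i) {             // scalar tail
    out[i] = in[i] * in[i];
  }
}
```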

Next steps:

- [ ] Refactor the vector mask handling to support data types other than float. Currently vector masks are implemented with float vectors.
- [ ] Fully utilize vector lanes for bfloat16/float16/int8.
- [ ] Support indirect indexing with vectorized index via scalarization.
- [ ] Clean up `CppVecKernelChecker`.
- [ ] Simplify `cpp_prefix.h` including refactoring vector conversion logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119001
Approved by: https://github.com/peterbell10, https://github.com/jansel
Jiong Gong 2024-02-08 22:28:10 +08:00 committed by PyTorch MergeBot
parent 8182fce769
commit 896cf9d1ce
8 changed files with 813 additions and 134 deletions

View File

@ -143,6 +143,24 @@ inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return _mm256_cvttps_epi32(src);
}
// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
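// The constant 0x0018000000000000 equals 2^52 + 2^51 as a double: adding the
// int64 to the bit pattern of that double leaves the exponent untouched and
// stores the integer in the mantissa, so subtracting the same constant back
// as a double yields the converted value (hence the input range above).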
template<>
Vectorized<double>
inline convert_to_fp_of_same_size<double>(const Vectorized<int64_t> &src) {
auto x = _mm256_add_epi64(src, _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000)));
return _mm256_sub_pd(
_mm256_castsi256_pd(x),
_mm256_set1_pd(0x0018000000000000)
);
}
template<>
Vectorized<float>
inline convert_to_fp_of_same_size<float>(const Vectorized<int32_t> &src) {
return _mm256_cvtepi32_ps(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>

View File

@ -127,6 +127,18 @@ inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return _mm512_cvttps_epi32(src);
}
template<>
Vectorized<double>
inline convert_to_fp_of_same_size<double>(const Vectorized<int64_t> &src) {
return _mm512_cvtepi64_pd(src);
}
template<>
Vectorized<float>
inline convert_to_fp_of_same_size<float>(const Vectorized<int32_t> &src) {
return _mm512_cvtepi32_ps(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>

View File

@ -622,6 +622,12 @@ template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const
return c;
}
template <class T,
typename std::enable_if<!is_floating_point_v<T>, int>::type = 0>
Vectorized<T> inline operator%(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
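  // Integer remainder via truncated division: a % b == a - (a / b) * b.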
return a - a / b * b;
}
template <class T> Vectorized<T> inline operator||(
const Vectorized<T> &a, const Vectorized<T> &b) {
Vectorized<T> c;
@ -989,6 +995,19 @@ inline Vectorized<IntType> convert_to_int_of_same_size(const Vectorized<T>& src)
return Vectorized<IntType>::loadu(static_cast<const void*>(buffer.data()));
}
template <typename T, typename IntType = int_same_size_t<T>>
inline Vectorized<T> convert_to_fp_of_same_size(const Vectorized<IntType>& src) {
static_assert(sizeof(T) == sizeof(IntType));
static constexpr int size = Vectorized<T>::size();
std::array<IntType, size> src_arr;
src.store(static_cast<void*>(src_arr.data()));
std::array<T, size> buffer;
std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
[](const IntType& x) { return static_cast<T>(x); });
return Vectorized<T>::loadu(static_cast<const void*>(buffer.data()));
}
// Example inputs for AVX512:
// a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}

View File

@ -0,0 +1,344 @@
#include <ATen/cpu/vec/vec_base.h>
#include <array>
namespace at::vec {
inline namespace CPU_CAPABILITY {
/**
* @brief A class template representing a vectorized type with
* `N * Vectorized<T>::size()` elements, aiming to support vectors of
* arbitrary size. A specific use case of it is to represent vectors
* converted from data types with different sizes but with the same
* number of vector elements, e.g., `VectorizedN<float, 2>` can be
* a vector converted from two `Vectorized<bfloat16>`, `VectorizedN<int64_t, 2>`
* can be a vector converted from two `Vectorized<int32_t>` etc.
*
* It supports most of the operations of `Vectorized<T>`
* and the implementation delegates to `Vectorized<T>` with loops over `N`.
*
* @tparam T The underlying type of the vectorized elements.
* @tparam N The number of underlying `Vectorized<T>`.
*/
template <typename T, int N>
class VectorizedN {
public:
using value_type = T;
using size_type = int;
static constexpr size_type size_T = sizeof(T);
static constexpr size_type size() {
return Vectorized<T>::size() * N;
}
private:
std::array<Vectorized<T>, N> values;
public:
// methods not implemented yet:
// variadic constructor, operator T*, as_bytes, zero_mask
#define VECTORIZEDN_DEFINE_UNARY_OP(op) \
VectorizedN<T, N> op() const { \
return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
}
#define VECTORIZEDN_DEFINE_BINARY_OP(op) \
VectorizedN<T, N> op(const VectorizedN<T, N>& other) const { \
return binary_op( \
other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
return a.op(b); \
}); \
}
template <typename Op>
inline VectorizedN<T, N> unary_op(Op op) const {
VectorizedN<T, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result.values[i] = op(values[i]);
}
return result;
}
template <typename Op>
inline VectorizedN<T, N> binary_op(const VectorizedN<T, N>& other, Op op)
const {
VectorizedN<T, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result.values[i] = op(values[i], other.values[i]);
}
return result;
}
VectorizedN() = default;
explicit VectorizedN(T val) {
for (int i = 0; i < N; ++i) {
values[i] = Vectorized<T>(val);
}
}
const Vectorized<T>& operator[](int i) const {
return values[i];
}
Vectorized<T>& operator[](int i) {
return values[i];
}
template <int64_t mask>
static VectorizedN<T, N> blend(
const VectorizedN<T, N>& a,
const VectorizedN<T, N>& b) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::blend<mask>(a.values[i], b.values[i]);
}
return result;
}
static VectorizedN<T, N> blendv(
const VectorizedN<T, N>& a,
const VectorizedN<T, N>& b,
const VectorizedN<T, N>& mask) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] =
Vectorized<T>::blendv(a.values[i], b.values[i], mask.values[i]);
}
return result;
}
template <typename step_t>
static VectorizedN<T, N> arange(
T base = static_cast<T>(0),
step_t step = static_cast<step_t>(1)) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::arange(base, step);
base += step * Vectorized<T>::size();
}
return result;
}
static VectorizedN<T, N> set(
const VectorizedN<T, N>& a,
const VectorizedN<T, N>& b,
int64_t count = size()) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] =
Vectorized<T>::set(a.values[i], b.values[i], std::min(count, Vectorized<T>::size()));
count -= Vectorized<T>::size();
if (count <= 0) {
break;
}
}
return result;
}
static VectorizedN<T, N> loadu(const void* ptr) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::loadu(ptr);
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
}
return result;
}
static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] =
Vectorized<T>::loadu(ptr, std::min(count, Vectorized<T>::size()));
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
count -= Vectorized<T>::size();
if (count <= 0) {
break;
}
}
return result;
}
void store(void* ptr) const {
for (int i = 0; i < N; ++i) {
values[i].store(ptr);
ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
}
}
void store(void* ptr, int count) const {
for (int i = 0; i < N; ++i) {
values[i].store(ptr, std::min(count, Vectorized<T>::size()));
ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
count -= Vectorized<T>::size();
if (count <= 0) {
break;
}
}
}
bool has_inf_nan() const {
for (int i = 0; i < N; ++i) {
if (values[i].has_inf_nan()) {
return true;
}
}
return false;
}
VectorizedN<T, N> map(T (*const f)(T)) const {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = values[i].map(f);
}
return result;
}
VectorizedN<T, N> map(T (*const f)(const T&)) const {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = values[i].map(f);
}
return result;
}
VECTORIZEDN_DEFINE_UNARY_OP(abs)
VECTORIZEDN_DEFINE_UNARY_OP(sgn)
VECTORIZEDN_DEFINE_UNARY_OP(angle)
VECTORIZEDN_DEFINE_UNARY_OP(real)
VECTORIZEDN_DEFINE_UNARY_OP(imag)
VECTORIZEDN_DEFINE_UNARY_OP(conj)
VECTORIZEDN_DEFINE_UNARY_OP(acos)
VECTORIZEDN_DEFINE_UNARY_OP(acosh)
VECTORIZEDN_DEFINE_UNARY_OP(asin)
VECTORIZEDN_DEFINE_UNARY_OP(atan)
VECTORIZEDN_DEFINE_UNARY_OP(atanh)
VECTORIZEDN_DEFINE_BINARY_OP(atan2)
VECTORIZEDN_DEFINE_BINARY_OP(copysign)
VECTORIZEDN_DEFINE_UNARY_OP(erf)
VECTORIZEDN_DEFINE_UNARY_OP(erfc)
VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
VECTORIZEDN_DEFINE_UNARY_OP(exp)
VECTORIZEDN_DEFINE_UNARY_OP(exp2)
VECTORIZEDN_DEFINE_UNARY_OP(expm1)
VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
VECTORIZEDN_DEFINE_UNARY_OP(frac)
VECTORIZEDN_DEFINE_BINARY_OP(fmod)
VECTORIZEDN_DEFINE_UNARY_OP(log)
VECTORIZEDN_DEFINE_UNARY_OP(log10)
VECTORIZEDN_DEFINE_UNARY_OP(log1p)
VECTORIZEDN_DEFINE_UNARY_OP(log2)
VECTORIZEDN_DEFINE_UNARY_OP(ceil)
VECTORIZEDN_DEFINE_UNARY_OP(cos)
VECTORIZEDN_DEFINE_UNARY_OP(cosh)
VECTORIZEDN_DEFINE_UNARY_OP(floor)
VECTORIZEDN_DEFINE_BINARY_OP(hypot)
VECTORIZEDN_DEFINE_UNARY_OP(i0)
VECTORIZEDN_DEFINE_UNARY_OP(i0e)
VECTORIZEDN_DEFINE_UNARY_OP(digamma)
VECTORIZEDN_DEFINE_BINARY_OP(igamma)
VECTORIZEDN_DEFINE_BINARY_OP(igammac)
VECTORIZEDN_DEFINE_UNARY_OP(neg)
VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
VECTORIZEDN_DEFINE_UNARY_OP(round)
VECTORIZEDN_DEFINE_UNARY_OP(sin)
VECTORIZEDN_DEFINE_UNARY_OP(sinh)
VECTORIZEDN_DEFINE_UNARY_OP(tan)
VECTORIZEDN_DEFINE_UNARY_OP(tanh)
VECTORIZEDN_DEFINE_UNARY_OP(trunc)
VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
VECTORIZEDN_DEFINE_BINARY_OP(pow)
VECTORIZEDN_DEFINE_BINARY_OP(operator==)
VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
VECTORIZEDN_DEFINE_BINARY_OP(operator>)
VECTORIZEDN_DEFINE_BINARY_OP(operator<)
VECTORIZEDN_DEFINE_BINARY_OP(eq)
VECTORIZEDN_DEFINE_BINARY_OP(ne)
VECTORIZEDN_DEFINE_BINARY_OP(gt)
VECTORIZEDN_DEFINE_BINARY_OP(ge)
VECTORIZEDN_DEFINE_BINARY_OP(lt)
VECTORIZEDN_DEFINE_BINARY_OP(le)
#undef VECTORIZEDN_DEFINE_UNARY_OP
#undef VECTORIZEDN_DEFINE_BINARY_OP
};
#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \
template <typename T, int N> \
inline VectorizedN<T, N> op(const VectorizedN<T, N>& a) { \
return a.unary_op([](const Vectorized<T>& a) { return op(a); }); \
}
#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \
template <typename T, int N> \
inline VectorizedN<T, N> op( \
const VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
return a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
return op(a, b); \
}); \
}
#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \
template <typename T, int N> \
inline VectorizedN<T, N>& op( \
VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
a = a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
return op(a, b); \
}); \
return a; \
}
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
template <typename T, int N, typename OpVec>
inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
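  // Combine the N sub-vectors into a single Vectorized<T>, then delegate to
  // the existing Vectorized<T> overload of vec_reduce_all.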
Vectorized<T> vec_result = acc_vec[0];
for (int i = 1; i < N; i++) {
vec_result = vec_fun(vec_result, acc_vec[i]);
}
return vec_reduce_all(vec_fun, vec_result);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -1831,14 +1831,14 @@ class CPUReproTests(TestCase):
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min
)
self.assertFalse(vec_checker.simd_vec)
self.assertTrue(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max
)
self.assertFalse(vec_checker.simd_vec)
self.assertTrue(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
@ -2772,8 +2772,7 @@ class CPUReproTests(TestCase):
with torch.no_grad():
metrics.reset()
self.common(fn, (x, y, mode))
# TODO: support vectorization for int div
assert metrics.generated_cpp_vec_kernel_count == 0
assert metrics.generated_cpp_vec_kernel_count == 1
def test_uint8_add(self):
# https://github.com/pytorch/pytorch/issues/113016
@ -2960,6 +2959,118 @@ class CPUReproTests(TestCase):
self.common(fn, (100, y))
assert metrics.generated_cpp_vec_kernel_count == 2
def test_int32_pointwise_vec(self):
def fn(x):
return x * x
x = torch.randint(0, 100, (32, 32), dtype=torch.int32)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_int32_reduction_vec(self):
def fn(x):
return x.sum(dim=1)
x = torch.randint(0, 100, (32, 32), dtype=torch.int32)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_uint32_pointwise_vec(self):
def fn(x):
return x * x
x = torch.randint(0, 100, (32, 32), dtype=torch.uint32)
metrics.reset()
self.common(fn, (x,))
# TODO(jgong5): change to 1 with vectorized uint32 load
assert metrics.generated_cpp_vec_kernel_count == 0
def test_uint32_reduction_vec(self):
def fn(x):
return x.sum(dim=1)
x = torch.randint(0, 100, (32, 32), dtype=torch.uint32)
metrics.reset()
self.common(fn, (x,))
# TODO(jgong5): change to 1 with vectorized uint32/uint64 load
assert metrics.generated_cpp_vec_kernel_count == 0
def test_int64_pointwise_vec(self):
def fn(x):
return x * x
x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_int64_reduction_vec(self):
def fn(x):
return x.sum(dim=1)
x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_uint64_pointwise_vec(self):
def fn(x):
return x * x
x = torch.randint(0, 100, (32, 32), dtype=torch.uint64)
metrics.reset()
self.common(fn, (x,))
# TODO(jgong5): change to 1 with vectorized uint64 load
assert metrics.generated_cpp_vec_kernel_count == 0
def test_uint64_reduction_vec(self):
def fn(x):
return x.sum(dim=1)
x = torch.randint(0, 100, (32, 32), dtype=torch.uint64)
metrics.reset()
self.common(fn, (x,))
# TODO(jgong5): change to 1 with vectorized uint64 load
assert metrics.generated_cpp_vec_kernel_count == 0
def test_convert_int32_to_int64_vec(self):
def fn(x):
return x.to(torch.int64)
x = torch.randint(0, 100, (32, 32), dtype=torch.int32)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_convert_int64_to_int32_vec(self):
def fn(x):
return x.to(torch.int32)
x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_convert_fp32_to_int64_vec(self):
def fn(x):
return x.to(torch.int64)
x = torch.rand(32, 32)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_convert_int64_to_fp32_vec(self):
def fn(x):
return x.to(torch.float32)
x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -64,6 +64,8 @@ DTYPE_TO_CPP = {
torch.uint32: "unsigned int",
torch.uint16: "unsigned short",
torch.uint8: "unsigned char",
torch.uint32: "unsigned int",
torch.uint64: "unsigned long",
torch.bool: "bool",
torch.bfloat16: "bfloat16",
torch.complex64: "complex64",
@ -83,6 +85,8 @@ DTYPE_TO_ATEN = {
torch.uint32: "at::kUInt32",
torch.uint16: "at::kUInt16",
torch.uint8: "at::kByte",
torch.uint32: "at::kUInt32",
torch.uint64: "at::kUInt64",
torch.bool: "at::kBool",
torch.bfloat16: "at::kBFloat16",
torch.complex32: "at::kComplexHalf",
@ -187,17 +191,6 @@ def reduction_init(reduction_type, dtype):
raise AssertionError(reduction_type)
def reduction_init_vec(reduction_type, dtype):
scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]]
vec_type = f"at::vec::Vectorized<{scalar_type}>"
if is_welford_reduction(reduction_type):
return f"Welford<{vec_type}>()"
scalar_init = reduction_init(reduction_type, dtype)
return f"{vec_type}({scalar_init})"
def reduction_acc_type(reduction_type, dtype):
assert reduction_type not in {"argmin", "argmax"}
scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]]
@ -207,16 +200,6 @@ def reduction_acc_type(reduction_type, dtype):
return scalar_type
def reduction_acc_type_vec(reduction_type, dtype):
assert reduction_type not in {"argmin", "argmax"}
scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]]
vec_type = f"at::vec::Vectorized<{scalar_type}>"
if is_welford_reduction(reduction_type):
return f"Welford<{vec_type}>"
return vec_type
def reduction_combine(reduction_type, var, next_value):
if reduction_type == "sum":
return f"{var} + {next_value}"
@ -239,31 +222,6 @@ def reduction_combine(reduction_type, var, next_value):
raise AssertionError(reduction_type)
def reduction_combine_vec(reduction_type, var, next_value):
if reduction_type == "max":
return f"at::vec::maximum({var}, {next_value})"
elif reduction_type == "min":
return f"at::vec::minimum({var}, {next_value})"
elif reduction_type == "sum":
return f"{var} + {next_value}"
elif reduction_type == "prod":
return f"{var} * {next_value}"
elif reduction_type == "xor_sum":
return f"{var} ^ {next_value}"
elif reduction_type == "welford_reduce":
return f"welford_combine({var}, {next_value})"
elif reduction_type == "welford_combine":
if isinstance(next_value, tuple):
# When reading a value from Inductor IR we have a tuple of variable names
mean, m2, weight = next_value
else:
# When combining intermediate accumulators we have a Welford<T> struct
mean, m2, weight = reduction_project(reduction_type, next_value)
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
else:
raise NotImplementedError()
def reduction_project(reduction_type, acc):
if is_welford_reduction(reduction_type):
return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight"
@ -567,6 +525,10 @@ def get_current_node_opt_ctx() -> OptimizationContext:
return get_opt_ctx(V.interpreter.current_node)
class CppVecUnsupportedError(Exception):
pass
class CppCSEVariable(CSEVariable):
def __init__(self, name, bounds: ValueRanges[Any]):
super().__init__(name, bounds)
@ -1003,24 +965,37 @@ class CppVecOverrides(CppOverrides):
# needs to further analyze the dependency of the index expression on
# the tiling itervar.
def wrapper(*args, **kwargs):
has_scalar = any(
not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
)
has_vector = any(
arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)
)
scalars = [
arg
for arg in args
if isinstance(arg, CppCSEVariable) and not arg.is_vec
]
vectors = [
arg
for arg in args
if isinstance(arg, CppCSEVariable) and arg.is_vec
]
new_args = list(args)
if has_scalar and has_vector:
if scalars and vectors:
# broadcast scalar args to vector if needed
new_args = []
vec_dtype = vectors[0].dtype
for arg in args:
if isinstance(arg, CppCSEVariable) and not arg.is_vec:
assert isinstance(V.kernel, CppVecKernel)
# align scalar data type to the vector for binary ops
if len(args) == 2 and arg.dtype != vec_dtype:
arg = ops.to_dtype(arg, vec_dtype)
arg = arg.value if isinstance(arg, OpsValue) else arg
# See NOTE [dtype of CppCSEVariable]: we have to fix arg.dtype since
# the dtype from optimization context could be wrong.
assert isinstance(arg, CppCSEVariable)
arg.dtype = vec_dtype
new_arg = V.kernel.broadcast(arg)
new_args.append(new_arg)
else:
new_args.append(arg)
if has_vector:
if vectors:
return func(*new_args, **kwargs)
else:
# fallback to scalar ops
@ -1281,8 +1256,9 @@ class CppVecOverrides(CppOverrides):
# a and b are integer type
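# C++ integer division truncates toward zero; floor division steps the
# quotient down by one when the signs differ and the remainder is non-zero.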
_t = f"decltype({a})"
quot = f"{a} / {b}"
rem = f"{a} % {b}"
return f"(({a} < {_t}(0)) != ({b} < {_t}(0)) ? ({rem} != {_t}(0) ? {quot} - {_t}(1) : {quot}) : {quot})"
has_rem = f"({a} % {b} != {_t}(0))"
is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))"
return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})"
@staticmethod
def truncdiv(a, b):
@ -1303,6 +1279,11 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def where(a, b, c):
assert isinstance(b, CppCSEVariable)
if b.dtype != torch.float:
raise CppVecUnsupportedError(
"where with non-float tensor is not supported in vectorized codegen"
)
return f"decltype({b})::blendv({c}, {b}, {a})"
@staticmethod
@ -1333,6 +1314,7 @@ class CppVecOverrides(CppOverrides):
torch.float16,
torch.uint8,
torch.int32,
torch.int64,
], f"{__name__} does not support {dtype}"
node: torch.fx.Node = V.interpreter.current_node
assert node and isinstance(node, torch.fx.Node)
@ -1344,6 +1326,8 @@ class CppVecOverrides(CppOverrides):
return f"mask_convert_to_float({x})"
if opt_ctx_x.dtype == torch.bool and dtype in DTYPE_LOWP_FP:
return f"mask_convert_to_lowp<{DTYPE_TO_CPP[dtype]}>({x})"
if opt_ctx_x.dtype == torch.bool and dtype == torch.int64:
return f"mask_convert_to_int64({x})"
if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype in DTYPE_LOWP_FP:
return f"cvt_fp32_to_lowp_fp<{DTYPE_TO_CPP[dtype]}>({x})"
if opt_ctx_x.dtype in DTYPE_LOWP_FP and dtype in (torch.float, torch.float32):
@ -1357,6 +1341,18 @@ class CppVecOverrides(CppOverrides):
# * Pattern match of quantization op in the loop body.
# * Skip the explicit saturation and clamp inside at::vec::convert_float_to_uint8.
return f"at::vec::convert_float_to_uint8({x})"
if opt_ctx_x.dtype == torch.int32 and dtype == torch.float:
return f"at::vec::convert_to_fp_of_same_size<float>({x})"
if opt_ctx_x.dtype == torch.float and dtype == torch.int32:
return f"at::vec::convert_to_int_of_same_size({x})"
if opt_ctx_x.dtype == torch.int64 and dtype == torch.float:
return f"cvt_int64_to_fp32({x})"
if opt_ctx_x.dtype == torch.float and dtype == torch.int64:
return f"cvt_fp32_to_int64({x})"
if opt_ctx_x.dtype == torch.int32 and dtype == torch.int64:
return f"cvt_int32_to_int64({x})"
if opt_ctx_x.dtype == torch.int64 and dtype == torch.int32:
return f"cvt_int64_to_int32({x})"
# TODO(jgong5): support conversion for other types
# currently we only allow load/store torch.uint8 and handle conversion there
return f"({x})"
@ -1375,6 +1371,7 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def masked(mask, body, other):
assert isinstance(V.kernel, CppVecKernel)
code = BracesBuffer()
var = V.kernel.cse.newvar()
with V.kernel.masked(mask) as new_mask:
@ -1387,10 +1384,12 @@ class CppVecOverrides(CppOverrides):
body_code = f"{var}()"
body_code_vec = (
body_code if result.is_vec else f"at::vec::Vectorized<float>({body_code})"
body_code
if result.is_vec
else f"{V.kernel._get_vec_type(torch.float)}({body_code})"
)
other_code = value_to_cpp(other, "float")
other_code_vec = f"at::vec::Vectorized<float>({other_code})"
other_code_vec = f"{V.kernel._get_vec_type(torch.float)}({other_code})"
assert isinstance(new_mask, CppCSEVariable), new_mask
if new_mask.is_vec or result.is_vec:
type = f"decltype({body_code_vec})"
@ -1782,12 +1781,26 @@ class CppVecKernel(CppKernel):
tiling_dtype=torch.float,
):
super().__init__(args, num_threads)
assert codecache.pick_vec_isa()
self.vec_isa = codecache.pick_vec_isa()
assert self.vec_isa
if tiling_factor == 0:
tiling_factor = codecache.pick_vec_isa().nelements(dtype=tiling_dtype)
tiling_factor = self.vec_isa.nelements(dtype=tiling_dtype)
self.tiling_factor = tiling_factor
self.tiling_idx = tiling_idx
metrics.generated_cpp_vec_kernel_count += 1
def _get_num_vectors(self, dtype: torch.dtype) -> int:
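# Number of Vectorized<T> needed to cover one tile of `tiling_factor` elements,
# e.g. 2 for int64 when the tiling factor is picked for float (16 on AVX512, 8 on AVX2).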
num_vectors = math.ceil(
self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
)
assert num_vectors >= 1
return num_vectors
def _get_vec_type(self, dtype: torch.dtype) -> str:
num_vectors = self._get_num_vectors(dtype)
if num_vectors == 1:
return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>"
else:
return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
def _get_vec_load_line(
self,
@ -1810,6 +1823,7 @@ class CppVecKernel(CppKernel):
load_mask_str = f"to_float_mask({load_mask})" if load_mask else None
loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var
if dtype == torch.uint8 and opt_ctx.is_load_uint8_as_float:
assert self._get_num_vectors(torch.uint8) == 1
line = (
f"masked_load({loadbuf}, {load_mask_str})"
if load_mask_str
@ -1821,13 +1835,13 @@ class CppVecKernel(CppKernel):
line = (
f"masked_load({loadbuf}, {load_mask_str})"
if load_mask_str
else f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu({loadbuf}, {self.tiling_factor})"
else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.tiling_factor})"
)
else:
line = (
f"masked_load({loadbuf}, {load_mask_str})"
if load_mask_str
else f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu({loadbuf})"
else f"{self._get_vec_type(dtype)}::loadu({loadbuf})"
)
return line
@ -1856,8 +1870,10 @@ class CppVecKernel(CppKernel):
buffer = self.loads
def get_result_size(dtype: torch.dtype) -> int:
assert dtype.itemsize <= 4
return self.tiling_factor * (4 // dtype.itemsize)
if dtype.itemsize < 4:
return self.tiling_factor * (4 // dtype.itemsize)
else:
return self.tiling_factor
def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
assert vec_var.is_vec
@ -1987,7 +2003,7 @@ class CppVecKernel(CppKernel):
isinstance(value, CppCSEVariable) and value.is_vec
), value
tiling_var = self.itervars[self.tiling_idx]
assert index.has(tiling_var)
assert index.has(tiling_var), f"index: {index}, tiling_var: {tiling_var}"
var_expr = f"{var} + {cexpr_index(index)}"
stride = stride_at_vec_range(index, tiling_var, self.tiling_factor)
non_contiguous = stride != 1 or self.index_indirect_depends_on(
@ -2041,17 +2057,15 @@ class CppVecKernel(CppKernel):
"welford_reduce",
"welford_combine",
}
assert dtype == torch.float
assert src_dtype == torch.float
assert dtype == src_dtype
assert dtype in [torch.float, torch.int64]
assert isinstance(value, CppCSEVariable), value
if not value.is_vec:
value = self.broadcast(value)
vec_ns = "at::vec"
vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
acc_type = reduction_acc_type(reduction_type, dtype)
acc_type_vec = reduction_acc_type_vec(reduction_type, dtype)
acc_type_vec = self.reduction_acc_type_vec(reduction_type, dtype)
if (reduction_type, acc_type) not in self.reduction_omp_dec:
if RTYPE_TO_CPP[reduction_type] not in NATIVE_OMP_RTYPES:
@ -2073,8 +2087,8 @@ initializer(omp_priv={{{reduction_init(reduction_type, dtype)}}})
f"""\
#pragma omp declare reduction(\
{RTYPE_TO_CPP[reduction_type]}:{acc_type_vec}:\
omp_out = {reduction_combine_vec(reduction_type, "omp_out", "omp_in")}) \
initializer(omp_priv={{{reduction_init_vec(reduction_type, dtype)}}})
omp_out = {self.reduction_combine_vec(reduction_type, "omp_out", "omp_in")}) \
initializer(omp_priv={{{self.reduction_init_vec(reduction_type, dtype)}}})
"""
)
self.reduction_omp_dec[reduction_type, acc_type_vec] = RTYPE_TO_CPP[
@ -2095,24 +2109,28 @@ initializer(omp_priv={{{reduction_init_vec(reduction_type, dtype)}}})
f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};"
)
self.reduction_prefix.writeline(
f"{acc_type_vec} {acc_vec} = {reduction_init_vec(reduction_type, dtype)};"
f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};"
)
self.stores.writeline(
f"{acc_vec} = {reduction_combine_vec(reduction_type, acc_vec, value)};"
f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};"
)
tmpvar: Union[str, CSEVariable]
if self.tiling_idx >= self.reduction_depth:
# Horizontal reduction
if is_welford_reduction(reduction_type):
assert (
self._get_num_vectors(dtype) == 1
), "Welford reduction does not support VectorizedN (N>1)"
next_value = f"welford_vec_reduce_all({acc_vec})"
else:
reduce_all_body = (
"{ return "
+ reduction_combine_vec(reduction_type, "x", "y")
+ self.reduction_combine_vec(reduction_type, "x", "y")
+ "; }"
)
vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>"
vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})"
self.reduction_suffix.writeline(
@ -2175,7 +2193,7 @@ initializer(omp_priv={{{reduction_init_vec(reduction_type, dtype)}}})
assert scalar_var.dtype is not None
vec_var = self.cse.generate(
self.compute,
f"at::vec::Vectorized<{DTYPE_TO_CPP[scalar_var.dtype]}>({scalar_var.name})",
f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})",
)
assert isinstance(vec_var, CppCSEVariable)
vec_var.dtype = scalar_var.dtype
@ -2192,13 +2210,57 @@ initializer(omp_priv={{{reduction_init_vec(reduction_type, dtype)}}})
assert isinstance(index, CppCSEVariable)
assert not index.is_vec
csevar = self.cse.generate(
self.compute, f"at::vec::Vectorized<int32_t>::arange({index}, {stride})"
self.compute,
f"{self._get_vec_type(torch.int32)}::arange({index}, {stride})",
)
assert isinstance(csevar, CppCSEVariable)
csevar.dtype = torch.int32
csevar.is_vec = True
return csevar
def reduction_init_vec(self, reduction_type, dtype):
scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
vec_type = self._get_vec_type(scalar_type)
if is_welford_reduction(reduction_type):
return f"Welford<{vec_type}>()"
scalar_init = reduction_init(reduction_type, dtype)
return f"{vec_type}({scalar_init})"
def reduction_acc_type_vec(self, reduction_type, dtype):
assert reduction_type not in {"argmin", "argmax"}
scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
vec_type = self._get_vec_type(scalar_type)
if is_welford_reduction(reduction_type):
return f"Welford<{vec_type}>"
return vec_type
def reduction_combine_vec(self, reduction_type, var, next_value):
if reduction_type == "max":
return f"at::vec::maximum({var}, {next_value})"
elif reduction_type == "min":
return f"at::vec::minimum({var}, {next_value})"
elif reduction_type == "sum":
return f"{var} + {next_value}"
elif reduction_type == "prod":
return f"{var} * {next_value}"
elif reduction_type == "xor_sum":
return f"{var} ^ {next_value}"
elif reduction_type == "welford_reduce":
return f"welford_combine({var}, {next_value})"
elif reduction_type == "welford_combine":
if isinstance(next_value, tuple):
# When reading a value from Inductor IR we have a tuple of variable names
mean, m2, weight = next_value
else:
# When combining intermediate accumulators we have a Welford<T> struct
mean, m2, weight = reduction_project(reduction_type, next_value)
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
else:
raise NotImplementedError()
class CppTile2DKernel(CppVecKernel):
"""
@ -2368,7 +2430,6 @@ class CppVecKernelChecker(CppVecKernel):
# Since this kernel is only for checker but does not generate any
# code, so we need to decrease the kernel count.
metrics.generated_kernel_count -= 1
metrics.generated_cpp_vec_kernel_count -= 1
# Used to record the graph wrapper code as the wrapper_code status could be
# changed during graph run.
@ -2389,12 +2450,16 @@ class CppVecKernelChecker(CppVecKernel):
torch.float16,
torch.bool,
torch.uint8,
torch.int32,
torch.int64,
]
self.store_supported_dtypes: List[torch.dtype] = [
torch.float,
torch.bfloat16,
torch.float16,
torch.uint8,
torch.int32,
torch.int64,
]
# Cache the dtypes of the store operation. If the store is mixing dtypes, the
# vectorization would not support it as it is hard to determine the vec dtype
@ -2549,8 +2614,8 @@ class CppVecKernelChecker(CppVecKernel):
def reduction(self, dtype, src_dtype, reduction_type, value):
if (
dtype == torch.float
and src_dtype == torch.float
(dtype == torch.float and src_dtype == torch.float)
or (dtype == torch.int64 and src_dtype == torch.int64)
and reduction_type in VECTORIZABLE_RTYPES
):
pass
@ -2673,6 +2738,7 @@ class CppVecKernelChecker(CppVecKernel):
supported_dtypes = [
torch.float32,
torch.int32,
torch.int64,
torch.bfloat16,
torch.float16,
torch.bool,
@ -2784,21 +2850,11 @@ class CppVecKernelChecker(CppVecKernel):
torch.bfloat16,
torch.float,
torch.uint8,
torch.int32,
torch.int64,
]:
# Convert from dtype to torch.float
pass
elif (
dtype in [torch.int32, torch.int64]
and input_value.target == "load"
):
buffer = V.graph.get_buffer(input_value.args[1]) # type: ignore[arg-type]
# Check if load of a scalar tensor of integer
if not (
isinstance(buffer, TensorBox)
and isinstance(buffer.data, StorageBox)
and len(buffer.data.layout.size) == 0
):
self.disable_vec(f"to_dtype: dtype {dtype}")
else:
self.disable_vec(f"to_dtype: dtype {dtype}")
elif dtype in DTYPE_LOWP_FP:
@ -2836,6 +2892,8 @@ class CppVecKernelChecker(CppVecKernel):
)
if not (is_to_uint8_and_store or is_to_uint8_and_to_float):
self.disable_vec(f"to_dtype: dtype {dtype}")
elif dtype in [torch.int64, torch.int32]:
pass
else:
self.disable_vec(f"to_dtype: dtype {dtype}")
return x
@ -3115,12 +3173,11 @@ class CppKernelProxy(CppKernel):
def codegen_kernel(cls, *args):
with kernel_group.new_kernel(cls, *args) as kernel:
run(kernel)
# Ugly hack to maintain the metrics kernel count since
# we only count in CppKernelProxy, not those contained in it
metrics.generated_kernel_count -= 1
run(kernel)
return kernel
def run(kernel):
@ -3225,46 +3282,52 @@ class CppKernelProxy(CppKernel):
with torch._inductor.config.patch(inplace_buffers=False):
tiling_factors, tiling_indices = select_tiling(vec_dtype)
assert len(tiling_factors) == len(tiling_indices)
if len(tiling_indices) == 1:
main_loop, tail_loop = self.loop_nest.split_with_tiling(
tiling_indices[0], factor=tiling_factors[0]
)
main_loop.set_kernel(
codegen_kernel(
try:
if len(tiling_indices) == 1:
vec_kernel = codegen_kernel(
CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype
)
)
tail_loop.set_kernel(scalar_kernel)
main_loop.simd_vec = True
tail_loop.simd_omp = True
# We chop the loop into two cubes by the nelements - main loop and tail loop.
# Regarding the main loop, it is straightforward that it could be vectorized with
# nelements. But for the tail loop, it still could be vectorized. For example,
# if the nelements is 8(256bits), then the tail loop still could be vectorized
# as 4(128bits).
tail_loop.simd_nelements = tiling_factors[0] // 2
elif len(tiling_indices) == 2:
assert (
tiling_indices[1] == len(self.itervars) - 1
and tiling_factors[0] == tiling_factors[1]
)
outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling(
tiling_indices[0], factor=tiling_factors[0]
)
outer_tail_loop.set_kernel(scalar_kernel)
inner_main_loop, inner_tail_loop = outer_main_loop.split_with_tiling(
tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
)
inner_main_loop.set_kernel(
codegen_kernel(
metrics.generated_cpp_vec_kernel_count += 1
main_loop, tail_loop = self.loop_nest.split_with_tiling(
tiling_indices[0], factor=tiling_factors[0]
)
main_loop.set_kernel(vec_kernel)
tail_loop.set_kernel(scalar_kernel)
main_loop.simd_vec = True
tail_loop.simd_omp = True
# We chop the loop into two cubes by the nelements - main loop and tail loop.
# Regarding the main loop, it is straightforward that it could be vectorized with
# nelements. But for the tail loop, it still could be vectorized. For example,
# if the nelements is 8(256bits), then the tail loop still could be vectorized
# as 4(128bits).
tail_loop.simd_nelements = tiling_factors[0] // 2
elif len(tiling_indices) == 2:
assert (
tiling_indices[1] == len(self.itervars) - 1
and tiling_factors[0] == tiling_factors[1]
)
tile2d_kernel = codegen_kernel(
CppTile2DKernel, tiling_factors[0], tiling_indices, vec_dtype
)
)
inner_tail_loop.set_kernel(
codegen_kernel(
vec_kernel = codegen_kernel(
CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype
)
)
metrics.generated_cpp_vec_kernel_count += 2
outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling(
tiling_indices[0], factor=tiling_factors[0]
)
outer_tail_loop.set_kernel(scalar_kernel)
(
inner_main_loop,
inner_tail_loop,
) = outer_main_loop.split_with_tiling(
tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
)
inner_main_loop.set_kernel(tile2d_kernel)
inner_tail_loop.set_kernel(vec_kernel)
except CppVecUnsupportedError as e:
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug("Disabled vectorization: %s", e)
def codegen_loops(self, code, worksharing):
self.codegen_loops_impl(self.loop_nest, code, worksharing)

View File

@ -28,6 +28,7 @@
#if INDUCTOR_USE_VECTOR_TYPES()
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec_n.h>
#endif
typedef at::Half half;
@ -477,4 +478,113 @@ inline bool vector_lane_mask_check(at::vec::Vectorized<float> src, int lane) {
# endif
}
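// Helpers for converting between Vectorized<float>/Vectorized<int32_t> and
// VectorizedN<int64_t, 2>: two Vectorized<int64_t> cover the same number of
// lanes as a single float/int32 vector.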
inline at::vec::Vectorized<float> cvt_int64_to_fp32(at::vec::VectorizedN<int64_t,2> src) {
# if defined(CPU_CAPABILITY_AVX512)
auto low = _mm512_cvtepi64_ps(src[0]);
auto high = _mm512_cvtepi64_ps(src[1]);
return _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1);
# elif defined(CPU_CAPABILITY_AVX2)
auto low_double = at::vec::convert_to_fp_of_same_size<double>(src[0]);
auto low = _mm256_cvtpd_ps(low_double);
auto high_double = at::vec::convert_to_fp_of_same_size<double>(src[1]);
auto high = _mm256_cvtpd_ps(high_double);
return _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1);
# else
constexpr int float_vec_size = at::vec::Vectorized<float>::size();
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
__at_align__ float result[float_vec_size];
__at_align__ int64_t src_buf[int64_vec_size];
for (int i = 0; i < 2; i++) {
src[i].store(src_buf + i * int64_vec_size);
for (int j = 0; j < int64_vec_size; j++) {
result[i * int64_vec_size + j] = static_cast<float>(src_buf[i * int64_vec_size + j]);
}
}
return at::vec::Vectorized<float>::loadu(result);
# endif
}
inline at::vec::VectorizedN<int64_t,2> cvt_fp32_to_int64(at::vec::Vectorized<float> src) {
at::vec::VectorizedN<int64_t,2> result;
# if defined(CPU_CAPABILITY_AVX512)
result[0] = _mm512_cvt_roundps_epi64(_mm512_castps512_ps256(src), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC);
result[1] = _mm512_cvt_roundps_epi64(_mm512_extractf32x8_ps(src, 1), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC);
# elif defined(CPU_CAPABILITY_AVX2)
auto int32_vec = at::vec::convert_to_int_of_same_size(src);
result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(int32_vec));
result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(int32_vec, 1));
# else
constexpr int float_vec_size = at::vec::Vectorized<float>::size();
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
__at_align__ float src_buf[float_vec_size];
__at_align__ int64_t result_buf[int64_vec_size];
src.store(src_buf);
for (int i = 0; i < 2; i++) {
for (int j = 0; j < int64_vec_size; j++) {
result_buf[j] = static_cast<int64_t>(src_buf[i * int64_vec_size + j]);
}
result[i] = at::vec::Vectorized<int64_t>::loadu(result_buf);
}
# endif
return result;
}
inline at::vec::Vectorized<int32_t> cvt_int64_to_int32(at::vec::VectorizedN<int64_t,2> src) {
# if defined(CPU_CAPABILITY_AVX512)
auto low = _mm512_cvtepi64_epi32(src[0]);
auto high = _mm512_cvtepi64_epi32(src[1]);
return _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1);
# elif defined(CPU_CAPABILITY_AVX2)
auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0));
auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0));
auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0));
auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0));
return _mm256_blend_epi32(low_perm, high_perm, 0xF0);
# else
constexpr int int32_vec_size = at::vec::Vectorized<int32_t>::size();
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
__at_align__ int32_t result[int32_vec_size];
__at_align__ int64_t src_buf[int64_vec_size];
for (int i = 0; i < 2; i++) {
src[i].store(src_buf + i * int64_vec_size);
for (int j = 0; j < int64_vec_size; j++) {
result[i * int64_vec_size + j] = static_cast<int32_t>(src_buf[i * int64_vec_size + j]);
}
}
return at::vec::Vectorized<int32_t>::loadu(result);
# endif
}
inline at::vec::VectorizedN<int64_t,2> cvt_int32_to_int64(at::vec::Vectorized<int32_t> src) {
at::vec::VectorizedN<int64_t,2> result;
# if defined(CPU_CAPABILITY_AVX512)
result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src));
result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src, 1));
# elif defined(CPU_CAPABILITY_AVX2)
result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src));
result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src, 1));
#else
constexpr int int32_vec_size = at::vec::Vectorized<int32_t>::size();
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
__at_align__ int32_t src_buf[int32_vec_size];
__at_align__ int64_t result_buf[int64_vec_size];
src.store(src_buf);
for (int i = 0; i < 2; i++) {
for (int j = 0; j < int64_vec_size; j++) {
result_buf[j] = static_cast<int64_t>(src_buf[i * int64_vec_size + j]);
}
result[i] = at::vec::Vectorized<int64_t>::loadu(result_buf);
}
# endif
return result;
}
inline at::vec::VectorizedN<int64_t,2> mask_convert_to_int64(at::vec::Vectorized<float> src) {
return cvt_fp32_to_int64(mask_convert_to_float(src));
}
inline at::vec::Vectorized<float> to_float_mask(at::vec::VectorizedN<int64_t,2> src) {
return to_float_mask(cvt_int64_to_int32(src));
}
#endif

View File

@ -219,6 +219,8 @@ dtype_abbrs = {
torch.int64: 'i64',
torch.bool: 'b8',
torch.uint8: 'u8',
torch.uint32: 'u32',
torch.uint64: 'u64',
}
@compatibility(is_backward_compatible=True)