diff --git a/.lintrunner.toml b/.lintrunner.toml index 376d916e3c6..ec62529d1f4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -55,6 +55,7 @@ init_command = [ code = 'CLANGFORMAT' include_patterns = [ 'aten/src/ATen/*.h', + 'aten/src/ATen/cpu/vec/*.h', 'aten/src/ATen/mps/**/*.mm', 'aten/src/ATen/mps/**/*.h', 'aten/src/ATen/xpu/**/*.h', diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index 4d1d05ea8d3..e7429d18712 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -29,16 +29,21 @@ inline scalar_t vec_reduce_all( template struct VecReduceAllSIMD { - static inline scalar_t apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline scalar_t apply( + const Op& vec_fun, + const Vectorized& acc_vec) { return vec_reduce_all(vec_fun, acc_vec, Vectorized::size()); } }; -#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) +#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \ + !defined(C10_MOBILE) #if defined(CPU_CAPABILITY_AVX2) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 128-bit shuffle @@ -57,7 +62,9 @@ struct VecReduceAllSIMD { #if defined(CPU_CAPABILITY_AVX512) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 256-bit shuffle @@ -76,25 +83,33 @@ struct VecReduceAllSIMD { } }; #endif // defined(CPU_CAPABILITY_AVX512) -#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) +#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && + // !defined(C10_MOBILE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; - // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -] + // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, + // a4+a8, a1+a5, a2+a6, -, -, -, -] float32x4_t v1_1 = vextq_f32(v, v, 2); Vec v1 = v1_1; // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] v = vec_fun(v, v1); - // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -] + // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, + // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, + // -] v1_1 = vrev64q_f32(v); v1 = v1_1; - // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] + // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, + // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] v = vec_fun(v, v1); return v[0]; @@ -102,10 +117,13 @@ struct VecReduceAllSIMD { }; #endif // defined(__aarch64__) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) 
&& defined(CPU_CAPABILITY_SVE256) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + defined(CPU_CAPABILITY_SVE256) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 128-bit shuffle @@ -125,15 +143,21 @@ struct VecReduceAllSIMD { }; #endif // defined(__aarch64__) - template -inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized& acc_vec) { +inline scalar_t vec_reduce_all( + const Op& vec_fun, + const Vectorized& acc_vec) { return VecReduceAllSIMD::apply(vec_fun, acc_vec); } -template , int> = 0> -inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline scalar_t reduce_all( + const Op& vec_fun, + const scalar_t* data, + int64_t size) { using Vec = vec::Vectorized; if (size < Vec::size()) return vec_reduce_all(vec_fun, Vec::loadu(data, size), size); @@ -151,16 +175,22 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size } // similar to reduce_all, but reduces into two outputs -template , int> = 0> -inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, - const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { using Vec = vec::Vectorized; if (size < Vec::size()) { auto loaded_data = Vec::loadu(data, size); return std::pair( - vec_reduce_all(vec_fun1, loaded_data, size), - vec_reduce_all(vec_fun2, loaded_data, size)); + vec_reduce_all(vec_fun1, loaded_data, size), + vec_reduce_all(vec_fun2, loaded_data, size)); } int64_t d = Vec::size(); Vec acc_vec1 = Vec::loadu(data); @@ -176,12 +206,14 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d); } return std::pair( - vec_reduce_all(vec_fun1, acc_vec1), - vec_reduce_all(vec_fun2, acc_vec2)); + vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2)); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -205,8 +237,11 @@ inline scalar_t map_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map2_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -237,8 +272,11 @@ inline scalar_t map2_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map3_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -274,8 +312,10 @@ inline scalar_t map3_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -293,8 +333,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename 
std::enable_if_t, int> = 0> inline void map2( const Op& vec_fun, scalar_t* output_data, @@ -317,8 +359,10 @@ inline void map2( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map3( const Op& vec_fun, scalar_t* output_data, @@ -344,8 +388,10 @@ inline void map3( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map4( const Op& vec_fun, scalar_t* output_data, diff --git a/aten/src/ATen/cpu/vec/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h index 3bd22b3820f..d4a40acaeef 100644 --- a/aten/src/ATen/cpu/vec/functional_bfloat16.h +++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h @@ -8,86 +8,120 @@ namespace at::vec { // BFloat16 specification -template struct VecScalarType { using type = scalar_t; }; -template <> struct VecScalarType { using type = float; }; -template <> struct VecScalarType { using type = float; }; +template +struct VecScalarType { + using type = scalar_t; +}; +template <> +struct VecScalarType { + using type = float; +}; +template <> +struct VecScalarType { + using type = float; +}; // This is different from at::acc_type since we only need to specialize BFloat16 template using vec_scalar_t = typename VecScalarType::type; // Vector conversion between float and bfloat16/half -template , int> = 0> -inline std::tuple, Vectorized> convert_to_float(const Vectorized&); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized&); template <> -inline std::tuple, Vectorized> convert_to_float (const Vectorized& a) { +inline std::tuple, Vectorized> convert_to_float< + BFloat16>(const Vectorized& a) { return convert_bfloat16_float(a); } template <> -inline std::tuple, Vectorized> convert_to_float (const Vectorized& a) { - return convert_half_float(a); +inline std::tuple, Vectorized> convert_to_float( + const Vectorized& a) { + return convert_half_float(a); } -template , int> = 0> -inline Vectorized convert_from_float(const Vectorized&, const Vectorized&); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline Vectorized convert_from_float( + const Vectorized&, + const Vectorized&); template <> -inline Vectorized convert_from_float(const Vectorized& a, const Vectorized& b) { +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { return convert_float_bfloat16(a, b); } template <> -inline Vectorized convert_from_float(const Vectorized& a, const Vectorized& b) { +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { return convert_float_half(a, b); } -template , int> = 0> -inline void load_to_float(const scalar_t *data, Vectorized &out1, Vectorized &out2); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float( + const scalar_t* data, + Vectorized& out1, + Vectorized& out2); template <> -inline void load_to_float (const BFloat16 *data, Vectorized &out1, Vectorized &out2) { +inline void load_to_float( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { load_fp32_from_bf16(data, out1, out2); } template <> -inline void load_to_float (const Half *data, Vectorized &out1, Vectorized &out2) { +inline void load_to_float( + const Half* data, + Vectorized& out1, + Vectorized& out2) { load_fp32_from_fp16(data, out1, out2); } -template , int> = 0> -inline void load_to_float(const scalar_t *data, Vectorized 
&out); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float(const scalar_t* data, Vectorized& out); template <> -inline void load_to_float (const BFloat16 *data, Vectorized &out) { +inline void load_to_float( + const BFloat16* data, + Vectorized& out) { load_fp32_from_bf16(data, out); } template <> -inline void load_to_float (const Half *data, Vectorized &out) { +inline void load_to_float(const Half* data, Vectorized& out) { load_fp32_from_fp16(data, out); } -// Note that we already have specialized member of Vectorized for BFloat16 -// so the following functions would run smoothly: +// Note that we already have specialized member of Vectorized for +// BFloat16 so the following functions would run smoothly: // using Vec = Vectorized; // Vec one = Vec(BFloat16(1)); // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); // // Then why we still need to specialize "functional"? -// If we do specialization at Vectorized<> level, the above example would need 3 pairs of -// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". -// If we do specialization at vec::map<>() level, we have only 1 pair of conversion -// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. +// If we do specialization at Vectorized<> level, the above example would need +// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and +// "/". If we do specialization at vec::map<>() level, we have only 1 pair of +// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16 +// vector only. // -// The following BFloat16 functionality will only do data type conversion for input -// and output vector (reduce functionality will only convert the final scalar back to bf16). -// Compared to Vectorized<> specialization, +// The following BFloat16 functionality will only do data type conversion for +// input and output vector (reduce functionality will only convert the final +// scalar back to bf16). Compared to Vectorized<> specialization, // 1. better performance since we have less data type conversion; // 2. less rounding error since immediate results are kept in fp32; // 3. accumulation done on data type of fp32. 
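
To make the note above concrete, here is a minimal usage sketch of the single-conversion pattern it describes, kept outside the patch itself. It assumes the ATen vec headers shown in this diff; the function name `sigmoid_bf16` is illustrative only and not part of the codebase.

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

void sigmoid_bf16(at::BFloat16* out, const at::BFloat16* in, int64_t n) {
  using fVec = at::vec::Vectorized<float>;
  fVec one(1.0f);
  // The BFloat16 specialization of vec::map does one bf16->fp32 conversion on
  // load and one fp32->bf16 conversion on store; exp(), "+" and "/" all run
  // on fp32 vectors inside the lambda, so intermediates keep fp32 precision.
  at::vec::map(
      [one](fVec x) { return one / (one + x.exp()); }, out, in, n);
}
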
@@ -95,8 +129,10 @@ inline void load_to_float (const Half *data, Vectorized &out) { // If you plan to extend this file, please ensure adding unit tests at // aten/src/ATen/test/vec_test_all_types.cpp // -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { using bVec = vec::Vectorized; using fVec = vec::Vectorized; @@ -104,7 +140,8 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { bVec data_bvec = bVec::loadu(data, size); auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size > fVec::size()) { - data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(vec_fun, data_fvec0, fVec::size()); } else { return vec_reduce_all(vec_fun, data_fvec0, size); @@ -124,27 +161,37 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size - d > fVec::size()) { acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { - acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(vec_fun, acc_fvec0); } -template , int> = 0> -inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, - const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { using bVec = vec::Vectorized; using fVec = vec::Vectorized; if (size < bVec::size()) { bVec data_bvec = bVec::loadu(data, size); auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size > fVec::size()) { - fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); - fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc1_fvec = fVec::set( + data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc2_fvec = fVec::set( + data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); return std::pair( vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()), vec_reduce_all(vec_fun2, acc2_fvec, fVec::size())); @@ -171,12 +218,20 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_f auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size - d > fVec::size()) { acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); - acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size()); + acc1_fvec1 = fVec::set( + acc1_fvec1, + vec_fun1(acc1_fvec1, data_fvec1), + size - d - fVec::size()); acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); - acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size()); + acc2_fvec1 = fVec::set( + acc2_fvec1, + vec_fun2(acc2_fvec1, data_fvec1), + size - d - fVec::size()); } else { - acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); - acc2_fvec0 = 
fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); + acc1_fvec0 = + fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); + acc2_fvec0 = + fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); } } acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); @@ -186,8 +241,11 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_f vec_reduce_all(vec_fun2, acc2_fvec0)); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -201,7 +259,8 @@ inline float map_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0); data_fvec1 = map_fun(data_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0); @@ -228,18 +287,23 @@ inline float map_reduce_all( data_fvec0 = map_fun(data_fvec0); data_fvec1 = map_fun(data_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map2_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -256,7 +320,8 @@ inline float map2_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0, data2_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0); @@ -289,18 +354,23 @@ inline float map2_reduce_all( data_fvec0 = map_fun(data_fvec0, data2_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map3_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -320,7 +390,8 @@ inline float map3_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - 
fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); @@ -359,18 +430,22 @@ inline float map3_reduce_all( data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -397,8 +472,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -419,7 +496,8 @@ inline void map( fVec data_fvec0, data_fvec1; if (size - d > fVec::size()) { data_fvec0 = fVec::loadu(input_data + d); - data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); + data_fvec1 = + fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); } else { // choose to align with behaviour of bVec::loadu(ptr, size), // which leaves data_fvec1 uninitialized @@ -432,8 +510,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map2( const Op& vec_fun, scalar_t* output_data, @@ -465,8 +545,10 @@ inline void map2( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map3( const Op& vec_fun, scalar_t* output_data, @@ -503,8 +585,10 @@ inline void map3( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map4( const Op& vec_fun, scalar_t* output_data, @@ -525,8 +609,10 @@ inline void map4( auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); bVec data4_bvec = bVec::loadu(input_data4 + d); auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); - fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); - fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); output_bvec.store(output_data + d); } @@ -539,8 +625,10 @@ inline void map4( auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); - fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); - fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + 
vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); output_bvec.store(output_data + d, size - d); } diff --git a/aten/src/ATen/cpu/vec/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h index 48b18793b07..f9086f7d3d0 100644 --- a/aten/src/ATen/cpu/vec/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -13,10 +13,14 @@ /* Microsoft C/C++-compatible compiler */ #include #if _MSC_VER <= 1900 -#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) -#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) -#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) -#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) +#define _mm256_extract_epi64(X, Y) \ + (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) +#define _mm256_extract_epi32(X, Y) \ + (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) +#define _mm256_extract_epi16(X, Y) \ + (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) +#define _mm256_extract_epi8(X, Y) \ + (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) #endif #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ @@ -25,9 +29,9 @@ /* GCC-compatible compiler, targeting ARM with SVE */ #include #endif -#if defined (MISSING_ARM_VLD1) +#if defined(MISSING_ARM_VLD1) #include -#elif defined (MISSING_ARM_VST1) +#elif defined(MISSING_ARM_VST1) #include #endif #elif defined(__GNUC__) && defined(__IWMMXT__) @@ -36,8 +40,8 @@ #elif defined(__s390x__) // targets Z/architecture // we will include vecintrin later -#elif (defined(__GNUC__) || defined(__xlC__)) && \ - (defined(__VEC__) || defined(__ALTIVEC__)) +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ #include /* We need to undef those tokens defined by to avoid conflicts diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index e4b0c4b95d8..0bfe65cd195 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -28,21 +28,30 @@ inline Vectorized Vectorized::loadu(const void* ptr) { } template <> -inline Vectorized Vectorized::loadu(const void* ptr, int64_t count) { +inline Vectorized Vectorized::loadu( + const void* ptr, + int64_t count) { // See NOTE [Loading boolean values] return convert_to_bool(Vectorized::loadu(ptr, count)); } template -struct VecHoldType { using hold_type = typename VT::value_type; }; +struct VecHoldType { + using hold_type = typename VT::value_type; +}; template <> -struct VecHoldType> { using hold_type = BFloat16; }; +struct VecHoldType> { + using hold_type = BFloat16; +}; template <> -struct VecHoldType> {using hold_type = Half; }; +struct VecHoldType> { + using hold_type = Half; +}; template using vechold_type = typename VecHoldType::hold_type; -}} // namespace at::vec::CPU_CAPABILITY +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 2591338881a..3e6124cbc50 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -1,5 +1,6 @@ #pragma once -#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && defined(__ARM_FEATURE_SVE) +#if defined(__GNUC__) && 
__GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ + defined(__ARM_FEATURE_SVE) // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 #pragma GCC optimize("no-tree-vectorize") #endif @@ -18,27 +19,27 @@ // See https://github.com/pytorch/pytorch/issues/37577 for an instance // of this bug in the past. -#include #include +#include #include +#include +#include #include #include -#include #include -#include +#include #include #include -#include -#include -#include -#include -#include #include -#include #include -#include +#include +#include +#include #include +#include +#include +#include #if defined(__GNUC__) #define __FORCE_INLINE __attribute__((always_inline)) inline @@ -66,7 +67,8 @@ Windows llvm will not have this definition. #endif #define VECTOR_WIDTH 64 #define int_vector __m512i -#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +#elif defined(__aarch64__) && \ + !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 // SVE code expects 256-vectors; leave that set for SVE? #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(16))) @@ -93,40 +95,43 @@ namespace at::vec { inline namespace CPU_CAPABILITY { // at::Half and at::BFloat16 should be treated as floating point template -struct is_floating_point: - std::integral_constant || - std::is_same_v || - std::is_same_v> { -}; +struct is_floating_point + : std::integral_constant< + bool, + std::is_floating_point_v || std::is_same_v || + std::is_same_v> {}; -template +template constexpr bool is_floating_point_v = is_floating_point::value; template -struct is_reduced_floating_point: - std::integral_constant || - std::is_same_v> { -}; +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; template -constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; template -struct is_8bit_integer: - std::integral_constant || - std::is_same_v> { +struct is_8bit_integer + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> { }; template constexpr bool is_8bit_integer_v = is_8bit_integer::value; -template struct int_of_size; +template +struct int_of_size; -#define DEFINE_INT_OF_SIZE(int_t) \ -template<> struct int_of_size { using type = int_t; } +#define DEFINE_INT_OF_SIZE(int_t) \ + template <> \ + struct int_of_size { \ + using type = int_t; \ + } DEFINE_INT_OF_SIZE(int64_t); DEFINE_INT_OF_SIZE(int32_t); @@ -142,14 +147,15 @@ using int_same_size_t = typename int_of_size::type; // emulates Vectorized types #if defined(__s390x__) -template +template #else template #endif struct Vectorized { -private: + private: __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; -public: + + public: using value_type = T; using size_type = int; @@ -163,11 +169,11 @@ public: values[i] = val; } } - template> - Vectorized(Args... vals) : values{vals...}{ - } - Vectorized(const T(&arr)[kSize]) { + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... 
vals) : values{vals...} {} + Vectorized(const T (&arr)[kSize]) { std::memcpy(values, arr, sizeof(values)); } // This also implies const T& operator[](int idx) const @@ -198,20 +204,23 @@ public: } // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 #if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) - static Vectorized __attribute__ ((optimize("-fno-tree-loop-vectorize"))) blendv(const Vectorized& a, + static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) + blendv( + const Vectorized& a, #else - static Vectorized blendv(const Vectorized& a, + static Vectorized blendv( + const Vectorized& a, #endif - const Vectorized& b, const Vectorized& mask) { + const Vectorized& b, + const Vectorized& mask) { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); #if defined(__clang__) && __ARM_FEATURE_SVE - #pragma clang loop vectorize(disable) +#pragma clang loop vectorize(disable) #endif for (const auto i : c10::irange(size())) { - if (buffer[i] & 0x01) - { + if (buffer[i] & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; @@ -219,15 +228,21 @@ public: } return vector; } - template // step sometimes requires a higher precision type (e.g., T=int, step_t=double) - static Vectorized arange(T base = static_cast(0), step_t step = static_cast(1)) { + template // step sometimes requires a higher precision type + // (e.g., T=int, step_t=double) + static Vectorized arange( + T base = static_cast(0), + step_t step = static_cast(1)) { Vectorized vector; for (const auto i : c10::irange(size())) { vector.values[i] = base + i * step; } return vector; } - static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) { + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { Vectorized vector; for (const auto i : c10::irange(size())) { if (i < count) { @@ -249,7 +264,9 @@ public: return vector; } static Vectorized loadu_one_fourth(const void* ptr) { - static_assert(std::is_same_v || std::is_same_v, "For byte types only"); + static_assert( + std::is_same_v || std::is_same_v, + "For byte types only"); return Vectorized::loadu(ptr, 8); } @@ -257,9 +274,10 @@ public: std::memcpy(ptr, values, count * sizeof(T)); } int zero_mask() const { - // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit int mask = 0; - for (int i = 0; i < size(); ++ i) { + for (int i = 0; i < size(); ++i) { if (values[i] == static_cast(0)) { mask |= (1 << i); } @@ -279,15 +297,18 @@ public: } bool has_inf_nan() const { for (int64_t i = 0; i != size(); i++) { - if(_isnan(values[i]) || _isinf(values[i])) { + if (_isnan(values[i]) || _isinf(values[i])) { return true; } } return false; } -// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows Arm64 -// See https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 -#if defined(_WIN32) && defined(__aarch64__) && ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) +// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows +// Arm64 +// See +// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) && \ + ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) Vectorized map(T (*const f)(T)) const { Vectorized ret; for (int64_t i = 0; i < size(); i++) { 
@@ -322,38 +343,44 @@ public: return ret; } #endif - Vectorized map(T (*const f)(const T &)) const { + Vectorized map(T (*const f)(const T&)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } - T reduce(T (*const f)(const T &)) const { + T reduce(T (*const f)(const T&)) const { T ret = 0; for (int64_t i = 0; i != size(); i++) { ret = f(ret, values[i]); } return ret; } - template && !c10::is_complex::value, int> = 0> + template < + typename other_t_abs = T, + typename std::enable_if_t< + !is_floating_point_v && + !c10::is_complex::value, + int> = 0> Vectorized abs() const { // other_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_abs must be T"); return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); } - template , int> = 0> + template < + typename float_t_abs = T, + typename std::enable_if_t, int> = 0> Vectorized abs() const { // float_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "float_t_abs must be T"); - // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in - // 0.0) properly. + // Specifically deal with floating-point because the generic code above + // won't handle -0.0 (which should result in 0.0) properly. return map([](T x) -> T { return std::abs(x); }); } - template ::value, int> = 0> + template < + typename complex_t_abs = T, + typename std::enable_if_t::value, int> = 0> Vectorized abs() const { // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "complex_t_abs must be T"); @@ -361,66 +388,85 @@ public: return map([](T x) { return static_cast(std::abs(x)); }); } - template ::value, int> = 0> + template < + typename other_t_sgn = T, + typename std::enable_if_t::value, int> = 0> Vectorized sgn() const { return map(at::native::sgn_impl); } - template ::value, int> = 0> + template < + typename other_t_angle = T, + typename std::enable_if_t::value, int> = + 0> Vectorized angle() const { // other_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_angle must be T"); - return map(at::native::angle_impl); // compiler is unable to resolve the overload without + return map(at::native::angle_impl); // compiler is unable to resolve the + // overload without } - template ::value, int> = 0> + template < + typename complex_t_angle = T, + typename std::enable_if_t::value, int> = + 0> Vectorized angle() const { // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_angle must be T"); + static_assert( + std::is_same_v, "complex_t_angle must be T"); return map([](T x) { return static_cast(std::arg(x)); }); } - template ::value, int> = 0> + template < + typename other_t_real = T, + typename std::enable_if_t::value, int> = 0> Vectorized real() const { // other_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_real must be T"); return *this; } - template ::value, int> = 0> + template < + typename complex_t_real = T, + typename std::enable_if_t::value, int> = + 0> Vectorized real() const { // complex_t_real is for SFINAE and clarity. Make sure it is not changed. 
- static_assert(std::is_same_v, "complex_t_real must be T"); + static_assert( + std::is_same_v, "complex_t_real must be T"); return map([](T x) { return static_cast(x.real()); }); } - template ::value, int> = 0> + template < + typename other_t_imag = T, + typename std::enable_if_t::value, int> = 0> Vectorized imag() const { // other_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_imag must be T"); return Vectorized(0); } - template ::value, int> = 0> + template < + typename complex_t_imag = T, + typename std::enable_if_t::value, int> = + 0> Vectorized imag() const { // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_imag must be T"); + static_assert( + std::is_same_v, "complex_t_imag must be T"); return map([](T x) { return static_cast(x.imag()); }); } - template ::value, int> = 0> + template < + typename other_t_conj = T, + typename std::enable_if_t::value, int> = 0> Vectorized conj() const { // other_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_conj must be T"); return *this; } - template ::value, int> = 0> + template < + typename complex_t_conj = T, + typename std::enable_if_t::value, int> = + 0> Vectorized conj() const { // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_conj must be T"); + static_assert( + std::is_same_v, "complex_t_conj must be T"); return map([](T x) { return static_cast(std::conj(x)); }); } Vectorized acos() const { @@ -441,7 +487,7 @@ public: Vectorized atanh() const { return map(std::atanh); } - Vectorized atan2(const Vectorized &exp) const { + Vectorized atan2(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::atan2(values[i], exp[i]); @@ -449,9 +495,9 @@ public: return ret; } template < - typename U = T, - typename std::enable_if_t, int> = 0> - Vectorized copysign(const Vectorized &sign) const { + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized copysign(const Vectorized& sign) const { Vectorized ret; for (size_type i = 0; i < size(); i++) { ret[i] = c10::copysign(values[i], sign[i]); @@ -483,8 +529,8 @@ public: return *this - this->trunc(); } template < - typename U = T, - typename std::enable_if_t, int> = 0> + typename U = T, + typename std::enable_if_t, int> = 0> Vectorized fmod(const Vectorized& q) const { // U is for SFINAE purposes only. Make sure it is not changed. static_assert(std::is_same_v, "U must be T"); @@ -503,20 +549,24 @@ public: Vectorized log1p() const { return map(std::log1p); } - template ::value, int> = 0> + template < + typename other_t_log2 = T, + typename std::enable_if_t::value, int> = 0> Vectorized log2() const { // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_log2 must be T"); return map(std::log2); } - template ::value, int> = 0> + template < + typename complex_t_log2 = T, + typename std::enable_if_t::value, int> = + 0> Vectorized log2() const { // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. 
- static_assert(std::is_same_v, "complex_t_log2 must be T"); + static_assert( + std::is_same_v, "complex_t_log2 must be T"); const T log_2 = T(std::log(2.0)); - return Vectorized(map(std::log))/Vectorized(log_2); + return Vectorized(map(std::log)) / Vectorized(log_2); } Vectorized ceil() const { return map(at::native::ceil_impl); @@ -530,7 +580,7 @@ public: Vectorized floor() const { return map(at::native::floor_impl); } - Vectorized hypot(const Vectorized &b) const { + Vectorized hypot(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::hypot(values[i], b[i]); @@ -546,14 +596,14 @@ public: Vectorized digamma() const { return map(calc_digamma); } - Vectorized igamma(const Vectorized &x) const { + Vectorized igamma(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igamma(values[i], x[i]); } return ret; } - Vectorized igammac(const Vectorized &x) const { + Vectorized igammac(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igammac(values[i], x[i]); @@ -566,7 +616,7 @@ public: // promotion return map([](T x) -> T { return -x; }); } - Vectorized nextafter(const Vectorized &b) const { + Vectorized nextafter(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::nextafter(values[i], b[i]); @@ -574,7 +624,8 @@ public: return ret; } Vectorized round() const { - // We do not use std::round because we would like to round midway numbers to the nearest even integer. + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. return map(at::native::round_impl); } Vectorized sin() const { @@ -604,20 +655,21 @@ public: Vectorized rsqrt() const { return map([](T x) { return (T)1 / std::sqrt(x); }); } - Vectorized pow(const Vectorized &exp) const { + Vectorized pow(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::pow(values[i], exp[i]); } return ret; } - T reduce_add() const { + T reduce_add() const { return reduce([](T x, T y) -> T { return x + y; }); } T reduce_max() const { return reduce(std::max); } -private: + + private: template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. 
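
As a side note on the comparison helpers reformatted below: `operator==` and friends go through `binary_pred`, which fills each true lane with an all-ones bit pattern, while `eq()`/`ne()`/... go through `binary_pred_bool`, which yields a numeric 0/1 per lane. A small hedged sketch of the difference, assuming only the ATen vec header; `compare_demo` is an illustrative name, not part of the patch:

#include <ATen/cpu/vec/vec.h>

void compare_demo() {
  using Vec = at::vec::Vectorized<float>;
  Vec a(1.0f), b(2.0f);
  // Lane mask: all bits set where equal, zero otherwise. Suitable for blendv.
  Vec mask = (a == b);
  Vec picked = Vec::blendv(a, b, mask);
  // Numeric 0/1 per lane, usable directly in arithmetic (e.g. counting).
  Vec matches = a.eq(b);
  Vec mismatches = a.ne(b);
}
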
@@ -632,35 +684,61 @@ private: return vector; } -public: - Vectorized operator==(const Vectorized& other) const { return binary_pred(other, std::equal_to()); } - Vectorized operator!=(const Vectorized& other) const { return binary_pred(other, std::not_equal_to()); } - Vectorized operator>=(const Vectorized& other) const { return binary_pred(other, std::greater_equal()); } - Vectorized operator<=(const Vectorized& other) const { return binary_pred(other, std::less_equal()); } - Vectorized operator>(const Vectorized& other) const { return binary_pred(other, std::greater()); } - Vectorized operator<(const Vectorized& other) const { return binary_pred(other, std::less()); } + public: + Vectorized operator==(const Vectorized& other) const { + return binary_pred(other, std::equal_to()); + } + Vectorized operator!=(const Vectorized& other) const { + return binary_pred(other, std::not_equal_to()); + } + Vectorized operator>=(const Vectorized& other) const { + return binary_pred(other, std::greater_equal()); + } + Vectorized operator<=(const Vectorized& other) const { + return binary_pred(other, std::less_equal()); + } + Vectorized operator>(const Vectorized& other) const { + return binary_pred(other, std::greater()); + } + Vectorized operator<(const Vectorized& other) const { + return binary_pred(other, std::less()); + } -private: + private: template - inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { + inline Vectorized binary_pred_bool(const Vectorized& other, Op op) + const { // 1 if the pred is true, otherwise 0. Vectorized vector; - for (int i = 0; i != size(); ++ i) { + for (int i = 0; i != size(); ++i) { vector[i] = static_cast(op(values[i], other.values[i])); } return vector; } -public: - Vectorized eq(const Vectorized& other) const { return binary_pred_bool(other, std::equal_to()); } - Vectorized ne(const Vectorized& other) const { return binary_pred_bool(other, std::not_equal_to()); } - Vectorized gt(const Vectorized& other) const { return binary_pred_bool(other, std::greater()); } - Vectorized ge(const Vectorized& other) const { return binary_pred_bool(other, std::greater_equal()); } - Vectorized lt(const Vectorized& other) const { return binary_pred_bool(other, std::less()); } - Vectorized le(const Vectorized& other) const { return binary_pred_bool(other, std::less_equal()); } + public: + Vectorized eq(const Vectorized& other) const { + return binary_pred_bool(other, std::equal_to()); + } + Vectorized ne(const Vectorized& other) const { + return binary_pred_bool(other, std::not_equal_to()); + } + Vectorized gt(const Vectorized& other) const { + return binary_pred_bool(other, std::greater()); + } + Vectorized ge(const Vectorized& other) const { + return binary_pred_bool(other, std::greater_equal()); + } + Vectorized lt(const Vectorized& other) const { + return binary_pred_bool(other, std::less()); + } + Vectorized le(const Vectorized& other) const { + return binary_pred_bool(other, std::less_equal()); + } }; -template Vectorized inline operator+(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] + b[i]; @@ -668,7 +746,8 @@ template Vectorized inline operator+(const Vectorized &a, const return c; } -template Vectorized inline operator-(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); 
i++) { c[i] = a[i] - b[i]; @@ -676,7 +755,8 @@ template Vectorized inline operator-(const Vectorized &a, const return c; } -template Vectorized inline operator*(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] * b[i]; @@ -684,7 +764,9 @@ template Vectorized inline operator*(const Vectorized &a, const return c; } -template Vectorized inline operator/(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { +template +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] / b[i]; @@ -692,14 +774,16 @@ template Vectorized inline operator/(const Vectorized &a, const return c; } -template , int> = 0> -Vectorized inline operator%(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { +template , int> = 0> +Vectorized inline operator%(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { return a - a / b * b; } -template Vectorized inline operator||( - const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator||( + const Vectorized& a, + const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] || b[i]; @@ -709,9 +793,10 @@ template Vectorized inline operator||( // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. -template ::value, int> = 0> -Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; @@ -725,9 +810,10 @@ Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; @@ -743,9 +829,10 @@ Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. -template ::value, int> = 0> -Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; @@ -759,9 +846,10 @@ Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) < std::abs(b[i])) ? 
a[i] : b[i]; @@ -775,9 +863,13 @@ Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, const Vectorized &max_vec) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_vec, + const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); @@ -785,9 +877,12 @@ Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, return c; } -template ::value, int> = 0> -Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_vec) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; @@ -795,9 +890,12 @@ Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_ return c; } -template ::value, int> = 0> -Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_vec) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; @@ -809,14 +907,21 @@ struct Vectorizedi; #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template -static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { int_vector buffer; #if defined(CPU_CAPABILITY_AVX2) - int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); - int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); + int_vector a_buffer = + _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm256_load_si256(reinterpret_cast((const T*)b)); #elif defined(CPU_CAPABILITY_AVX512) - int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); - int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); + int_vector a_buffer = + _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm512_load_si512(reinterpret_cast((const T*)b)); #endif buffer = op(a_buffer, b_buffer); __at_align__ T results[Vectorized::size()]; @@ -829,31 +934,52 @@ static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vect return Vectorized::loadu(results); } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return 
_mm512_and_si512(a, b); }); #endif } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); #endif } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); #endif } @@ -866,12 +992,19 @@ auto load(char const* data) -> T { return ret; } -template -static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); __at_align__ intmax_t buffer[element_no]; - static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); - static_assert(sizeof(buffer) == sizeof(Vectorized), "sizeof(buffer) must match sizeof(Vectorized)"); + static_assert( + VECTOR_WIDTH % sizeof(intmax_t) == 0, + "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); + static_assert( + sizeof(buffer) == sizeof(Vectorized), + "sizeof(buffer) must match sizeof(Vectorized)"); // We should be using memcpy in order to respect the strict aliasing rule // see: https://github.com/pytorch/pytorch/issues/66119 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 @@ -889,34 +1022,50 @@ static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vect return Vectorized::loadu(buffer); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_and()); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_or()); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, 
std::bit_xor()); } #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator~(const Vectorized& a) { using int_t = int_same_size_t; - Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 + Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 return a ^ ones; } -template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { constexpr T max_shift = sizeof(T) * CHAR_BIT; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; - if ((static_cast>(shift) < 0) || (shift >= max_shift)) { + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { c[i] = 0; } else { c[i] = static_cast>(a[i]) << shift; @@ -925,13 +1074,17 @@ template Vectorized inline operator<<(const Vectorized &a, const return c; } -template Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { // right shift value to retain sign bit for signed and no bits for unsigned constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; - if ((static_cast>(shift) < 0) || (shift >= max_shift)) { + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { c[i] = a[i] >> max_shift; } else { c[i] = a[i] >> shift; @@ -941,50 +1094,56 @@ template Vectorized inline operator>>(const Vectorized &a, const } template -inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { a = a + b; return a; } template -inline Vectorized& operator -= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator-=(Vectorized& a, const Vectorized& b) { a = a - b; return a; } template -inline Vectorized& operator /= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { a = a / b; return a; } template -inline Vectorized& operator %= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { a = a % b; return a; } template -inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { a = a * b; return a; } template -inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { a = a << b; return a; } template -inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { a = a >> b; return a; } template -inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { +inline Vectorized fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { return a * b + c; } template -inline Vectorized fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) { +inline Vectorized fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { return a * b - c; } @@ -1000,8 +1159,10 @@ Vectorized inline operator&&( } template -std::enable_if_t> -inline gather(T const* base_addr, const Vectorized>& vindex) { +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + T>> inline gather(T 
@@ -1000,8 +1159,10 @@ Vectorized<T> inline operator&&(
 }

 template <int64_t scale = 1, typename T = void>
-std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
-inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
+std::enable_if_t<
+    scale == 1 || scale == 2 || scale == 4 || scale == 8,
+    Vectorized<
+        T>> inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
   static constexpr int size = Vectorized<T>::size();
   int_same_size_t<T> index_arr[size];
   vindex.store(static_cast<void*>(index_arr));
@@ -1013,36 +1174,39 @@ inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex)
 }

 template <int64_t scale = 1, typename T = void>
-std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
-inline mask_gather(const Vectorized<T>& src, T const* base_addr,
-                   const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
+std::
+    enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>> inline mask_gather(
+        const Vectorized<T>& src,
+        T const* base_addr,
+        const Vectorized<int_same_size_t<T>>& vindex,
+        Vectorized<T>& mask) {
   static constexpr int size = Vectorized<T>::size();
   T src_arr[size];
-  int_same_size_t<T> mask_arr[size];  // use int type so we can logical and
+  int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
   int_same_size_t<T> index_arr[size];
   src.store(static_cast<void*>(src_arr));
   mask.store(static_cast<void*>(mask_arr));
   vindex.store(static_cast<void*>(index_arr));
   T buffer[size];
   for (const auto i : c10::irange(size)) {
-    if (mask_arr[i] & 0x01) {  // check highest bit
+    if (mask_arr[i] & 0x01) { // check highest bit
       buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
     } else {
       buffer[i] = src_arr[i];
     }
   }
-  mask = Vectorized<T>(static_cast<T>(0));  // "zero out" mask
+  mask = Vectorized<T>(static_cast<T>(0)); // "zero out" mask
   return Vectorized<T>::loadu(static_cast<void*>(buffer));
 }
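
// Illustrative sketch (not part of the upstream patch): mask_gather above
// picks, per lane, between a gathered element and the pass-through value from
// `src`, testing bit 0x01 of the per-lane integer mask. A scalar model of a
// single lane with a hypothetical name (model_mask_gather_lane); `scale`
// counts bytes, mirroring the x86 gather intrinsics:
#include <cstdint>

template <int64_t scale, typename T>
T model_mask_gather_lane(
    const T* base_addr,
    T src_value,
    int64_t index,
    int64_t mask_bits) {
  if ((mask_bits & 0x01) == 0) {
    return src_value; // masked-off lanes keep their original value
  }
  // the index is scaled in bytes, then converted back to an element offset
  return base_addr[index * scale / static_cast<int64_t>(sizeof(T))];
}
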
 // Cast a given vector to another type without changing the bits representation.
 // So a Vectorized<double> of 512 bits containing all ones can be cast to a
-// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
-// A Vec<double> of 256 bits containing all ones can be cast to a
+// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative
+// 1s). A Vec<double> of 256 bits containing all ones can be cast to a
 // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
 // There is a struct here because we don't have static_if and I can't
 // partially specialize a templated function.
-template<typename dst_t, typename src_t>
+template <typename dst_t, typename src_t>
 struct CastImpl {
   static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
     src_t src_arr[Vectorized<src_t>::size()];
@@ -1051,44 +1215,51 @@ struct CastImpl {
   }
 };

-template<typename scalar_t>
+template <typename scalar_t>
 struct CastImpl<scalar_t, scalar_t> {
   static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
     return src;
   }
 };

-template<typename dst_t, typename src_t>
+template <typename dst_t, typename src_t>
 inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
   return CastImpl<dst_t, src_t>::apply(src);
 }

 template <typename T, typename IntType = int_same_size_t<T>>
-inline Vectorized<IntType> convert_to_int_of_same_size(const Vectorized<T>& src) {
+inline Vectorized<IntType> convert_to_int_of_same_size(
+    const Vectorized<T>& src) {
   static_assert(sizeof(T) == sizeof(IntType));
   static constexpr int size = Vectorized<T>::size();
   std::array<T, size> src_arr;
   src.store(static_cast<void*>(src_arr.data()));
   std::array<IntType, size> buffer;
-  std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
-                 [](const T& x) { return static_cast<IntType>(x); });
+  std::transform(
+      src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) {
+        return static_cast<IntType>(x);
+      });
   return Vectorized<IntType>::loadu(static_cast<void*>(buffer.data()));
 }

 template <typename T, typename IntType = int_same_size_t<T>>
-inline Vectorized<T> convert_to_fp_of_same_size(const Vectorized<IntType>& src) {
+inline Vectorized<T> convert_to_fp_of_same_size(
+    const Vectorized<IntType>& src) {
   static_assert(sizeof(T) == sizeof(IntType));
   static constexpr int size = Vectorized<T>::size();
   std::array<IntType, size> src_arr;
   src.store(static_cast<void*>(src_arr.data()));
   std::array<T, size> buffer;
-  std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
-                 [](const IntType& x) { return static_cast<T>(x); });
+  std::transform(
+      src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) {
+        return static_cast<T>(x);
+      });
   return Vectorized<T>::loadu(static_cast<void*>(buffer.data()));
 }

+// clang-format off
 // Example inputs for AVX512:
 // a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
 // b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
@@ -1099,8 +1270,11 @@ inline Vectorized<T> convert_to_fp_of_same_size(const Vectorized<IntType>& src)
 // b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
 // returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
 //          Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
+// clang-format on
 template <typename T>
-inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
+inline std::enable_if_t<
+    Vectorized<T>::size() % 2 == 0,
+    std::pair<Vectorized<T>, Vectorized<T>>>
 deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
   static constexpr int size = Vectorized<T>::size();
   static constexpr int half_size = size / 2;
@@ -1116,10 +1290,12 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
     buffer2[i] = a_arr[i * 2 + 1];
     buffer2[half_size + i] = b_arr[i * 2 + 1];
   }
-  return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
-                        Vectorized<T>::loadu(static_cast<void*>(buffer2)));
+  return std::make_pair(
+      Vectorized<T>::loadu(static_cast<void*>(buffer1)),
+      Vectorized<T>::loadu(static_cast<void*>(buffer2)));
 }
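
// Illustrative sketch (not part of the upstream patch): deinterleave2 above
// splits two {a, b, a, b, ...} inputs into an all-a vector and an all-b
// vector; interleave2 in the next hunk is its inverse. A scalar model over
// plain arrays, with a hypothetical name (model_deinterleave2):
#include <array>
#include <cstddef>
#include <utility>

template <typename T, std::size_t N>
std::pair<std::array<T, N>, std::array<T, N>> model_deinterleave2(
    const std::array<T, N>& a,
    const std::array<T, N>& b) {
  static_assert(N % 2 == 0, "mirrors the Vectorized<T>::size() % 2 == 0 gate");
  constexpr std::size_t half = N / 2;
  std::array<T, N> first{};  // collects a0, a1, ... from both inputs
  std::array<T, N> second{}; // collects b0, b1, ... from both inputs
  for (std::size_t i = 0; i < half; i++) {
    first[i] = a[i * 2];
    first[half + i] = b[i * 2];
    second[i] = a[i * 2 + 1];
    second[half + i] = b[i * 2 + 1];
  }
  return {first, second};
}
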
+// clang-format off
 // inverse operation of deinterleave2
 // Example inputs for AVX512:
 // a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
@@ -1131,8 +1307,11 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
 // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
 // returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
 //          Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
+// clang-format on
 template <typename T>
-inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
+inline std::enable_if_t<
+    Vectorized<T>::size() % 2 == 0,
+    std::pair<Vectorized<T>, Vectorized<T>>>
 interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
   static constexpr int size = Vectorized<T>::size();
   static constexpr int half_size = size / 2;
@@ -1148,14 +1327,15 @@ interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
     buffer2[i * 2] = a_arr[half_size + i];
     buffer2[i * 2 + 1] = b_arr[half_size + i];
   }
-  return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
-                        Vectorized<T>::loadu(static_cast<void*>(buffer2)));
+  return std::make_pair(
+      Vectorized<T>::loadu(static_cast<void*>(buffer1)),
+      Vectorized<T>::loadu(static_cast<void*>(buffer2)));
 }

 template <typename src_T, typename dst_T>
-inline void convert(const src_T *src, dst_T *dst, int64_t n) {
+inline void convert(const src_T* src, dst_T* dst, int64_t n) {
 #ifndef _MSC_VER
-# pragma unroll
+#pragma unroll
 #endif
   for ([[maybe_unused]] const auto i : c10::irange(n)) {
     *dst = c10::convert<dst_T>(c10::load(src));
@@ -1165,7 +1345,7 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) {
 }

 template <typename T>
-inline Vectorized<T> flip(const Vectorized<T> & data) {
+inline Vectorized<T> flip(const Vectorized<T>& data) {
   static constexpr int size = Vectorized<T>::size();
   T output[size];
   T buffer[size];
@@ -1176,25 +1356,37 @@ inline Vectorized<T> flip(const Vectorized<T> & data) {
   return Vectorized<T>::loadu(static_cast<void*>(output));
 }

-// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
-// dimension of `src` and `ld_dst` is the leading dimension of `dst`.
+// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer.
+// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading
+// dimension of `dst`.
 template <typename T>
-inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
+inline void transpose_mxn(
+    const T* src,
+    int64_t ld_src,
+    T* dst,
+    int64_t ld_dst,
+    int M,
+    int N) {
   for (int i = 0; i < M; i++) {
     for (int j = 0; j < N; j++) {
-      dst[j*ld_dst + i] = src[i*ld_src + j];
+      dst[j * ld_dst + i] = src[i * ld_src + j];
     }
   }
 }

 template <typename T, int M, int N>
-inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+inline void transpose_mxn(
+    const T* src,
+    int64_t ld_src,
+    T* dst,
+    int64_t ld_dst) {
   transpose_mxn<T>(src, ld_src, dst, ld_dst, M, N);
 }

-}} // namespace at::vec::CPU_CAPABILITY
+} // namespace CPU_CAPABILITY
+} // namespace at::vec

 // additional headers for more operations that depend on vec_base
-#include <ATen/cpu/vec/vec_n.h>
-#include <ATen/cpu/vec/vec_mask.h>
 #include <ATen/cpu/vec/vec_convert.h>
+#include <ATen/cpu/vec/vec_mask.h>
+#include <ATen/cpu/vec/vec_n.h>
diff --git a/aten/src/ATen/cpu/vec/vec_convert.h b/aten/src/ATen/cpu/vec/vec_convert.h
index a5cee03dabc..f5e5177908c 100644
--- a/aten/src/ATen/cpu/vec/vec_convert.h
+++ b/aten/src/ATen/cpu/vec/vec_convert.h
@@ -28,8 +28,8 @@ struct VecConvert {
 };

 template <typename dst_t, typename src_t>
-inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>>
-convert(const Vectorized<src_t>& src) {
+inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>> convert(
+    const Vectorized<src_t>& src) {
   return src;
 }
diff --git a/aten/src/ATen/cpu/vec/vec_half.h b/aten/src/ATen/cpu/vec/vec_half.h
index c7c90cc95b4..972d3ee3929 100644
--- a/aten/src/ATen/cpu/vec/vec_half.h
+++ b/aten/src/ATen/cpu/vec/vec_half.h
@@ -103,7 +103,9 @@ static inline void transpose_pad_2x32_block(
   _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
 }
 #else
-TORCH_CHECK(false, "transpose_pad_2x32_block is only supported when avx512 is supported")
+  TORCH_CHECK(
+      false,
+      "transpose_pad_2x32_block is only supported when avx512 is supported")
 #endif
 }
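
// Illustrative sketch (not part of the upstream patch): transpose_pad_2x32_block
// above and pack_vnni2 in the next hunk rearrange a row-major (K, N) matrix so
// that elements from two consecutive rows become adjacent pairs (the "VNNI2"
// layout used by bf16/fp16 matmul kernels). A scalar model with a hypothetical
// name (model_pack_vnni2), assuming K is even and ignoring the zero-padding
// the real routine applies to ragged edges:
#include <cstdint>

template <typename T>
void model_pack_vnni2(const T* src, T* dst, int64_t ld_src, int64_t K, int64_t N) {
  for (int64_t k = 0; k < K; k += 2) {
    for (int64_t n = 0; n < N; n++) {
      // the 2-row strip starting at row k becomes N consecutive pairs
      dst[k * N + n * 2 + 0] = src[(k + 0) * ld_src + n];
      dst[k * N + n * 2 + 1] = src[(k + 1) * ld_src + n];
    }
  }
}
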
@@ -124,28 +126,31 @@ static inline void pack_vnni2(
   for (; bk < _K; bk += 2) {
     int64_t bn = 0;
     for (; bn < _N; bn += 32) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
     }
     int64_t nrem = N - bn;
     if (nrem > 0) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
     }
   }
   if (K % 2 == 1) {
     int64_t bn = 0;
     for (; bn < _N; bn += 32) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
     }
     int64_t nrem = N - bn;
     if (nrem > 0) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
     }
   }
 #else
-TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
+  TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
 #endif
 }
-
 } // namespace CPU_CAPABILITY
 } // namespace at::vec
diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h
index c547e5911ec..e19d7f75388 100644
--- a/aten/src/ATen/cpu/vec/vec_mask.h
+++ b/aten/src/ATen/cpu/vec/vec_mask.h
@@ -68,7 +68,12 @@ struct VecMaskTo {
   }
 };

-template <typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void>
+template <
+    typename dst_t,
+    int dst_n,
+    typename src_t,
+    int src_n,
+    typename Enabled = void>
 struct VecMaskCast {
   static inline VecMask<dst_t, dst_n> apply(
       const VecMask<src_t, src_n>& vec_mask) {
@@ -88,15 +93,17 @@ struct VecMaskCheck {
   static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
     __at_align__ T mask[VectorizedN<T, N>::size()];
     vec_mask.store(mask);
-    return std::all_of(
-        mask, mask + VectorizedN<T, N>::size(), [](T m) { return m == static_cast<T>(0); });
+    return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
+      return m == static_cast<T>(0);
+    });
   }

   static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
     __at_align__ T mask[VectorizedN<T, N>::size()];
     vec_mask.store(mask);
-    return std::all_of(
-        mask, mask + VectorizedN<T, N>::size(), [](T m) { return m != static_cast<T>(0); });
+    return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
+      return m != static_cast<T>(0);
+    });
   }

   static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
@@ -159,13 +166,11 @@ class VecMask {
   }

   static VecMask<T, N> blendv(
-    const VecMask<T, N>& c,
-    const VecMask<T, N>& b,
-    const VecMask<T, N>& a) {
+      const VecMask<T, N>& c,
+      const VecMask<T, N>& b,
+      const VecMask<T, N>& a) {
     VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
-      VectorizedN<T, N>(c),
-      VectorizedN<T, N>(b),
-      VectorizedN<T, N>(a));
+        VectorizedN<T, N>(c), VectorizedN<T, N>(b), VectorizedN<T, N>(a));
     return result;
   }

@@ -174,14 +179,14 @@ class VecMask {
       const VecMask<T, N>& b,
       int64_t count = size()) {
     VectorizedN<T, N> result = VectorizedN<T, N>::set(
-      VectorizedN<T, N>(a),
-      VectorizedN<T, N>(b),
-      count);
+        VectorizedN<T, N>(a), VectorizedN<T, N>(b), count);
     return result;
   }

   void store(bool* b, int count = size()) {
-    constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1)/ Vectorized<bool>::size();
+    constexpr int L =
+        (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1) /
+        Vectorized<bool>::size();
     auto res = this->to<bool, L>();
     res.store(b, count);
     return;