clang-format aten/src/ATen/cpu/vec/*.h (#150426)

I got a complaint about indentation on #150380. Make the machines fix it for us.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150426
Approved by: https://github.com/aditew01, https://github.com/cyyever, https://github.com/frost-intel, https://github.com/Skylion007
Scott Wolchok 2025-04-02 14:18:53 -07:00 committed by PyTorch MergeBot
parent bd9c42ebfb
commit ed0fd2fa7a
9 changed files with 735 additions and 385 deletions

View File

@@ -55,6 +55,7 @@ init_command = [
 code = 'CLANGFORMAT'
 include_patterns = [
     'aten/src/ATen/*.h',
+    'aten/src/ATen/cpu/vec/*.h',
     'aten/src/ATen/mps/**/*.mm',
     'aten/src/ATen/mps/**/*.h',
     'aten/src/ATen/xpu/**/*.h',
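Note (not part of the diff): adding 'aten/src/ATen/cpu/vec/*.h' to include_patterns is what brings the vec headers under the CLANGFORMAT linter; the remaining files in this commit are the mechanical reformatting that results from running it.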

View File

@@ -29,16 +29,21 @@ inline scalar_t vec_reduce_all(
 template <typename scalar_t, typename Op>
 struct VecReduceAllSIMD {
-  static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
+  static inline scalar_t apply(
+      const Op& vec_fun,
+      const Vectorized<scalar_t>& acc_vec) {
     return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
   }
 };
 
-#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
+#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \
+    !defined(C10_MOBILE)
 #if defined(CPU_CAPABILITY_AVX2)
 template <typename Op>
 struct VecReduceAllSIMD<float, Op> {
-  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
+  static inline float apply(
+      const Op& vec_fun,
+      const Vectorized<float>& acc_vec) {
     using Vec = Vectorized<float>;
     Vec v = acc_vec;
     // 128-bit shuffle
@@ -57,7 +62,9 @@ struct VecReduceAllSIMD<float, Op> {
 #if defined(CPU_CAPABILITY_AVX512)
 template <typename Op>
 struct VecReduceAllSIMD<float, Op> {
-  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
+  static inline float apply(
+      const Op& vec_fun,
+      const Vectorized<float>& acc_vec) {
     using Vec = Vectorized<float>;
     Vec v = acc_vec;
     // 256-bit shuffle
@@ -76,25 +83,33 @@ struct VecReduceAllSIMD<float, Op> {
   }
 };
 #endif // defined(CPU_CAPABILITY_AVX512)
-#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
+#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
+       // !defined(C10_MOBILE)
 
-#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
+#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
+    !defined(CPU_CAPABILITY_SVE)
 template <typename Op>
 struct VecReduceAllSIMD<float, Op> {
-  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
+  static inline float apply(
+      const Op& vec_fun,
+      const Vectorized<float>& acc_vec) {
     using Vec = Vectorized<float>;
     Vec v = acc_vec;
-    // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
+    // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7,
+    // a4+a8, a1+a5, a2+a6, -, -, -, -]
     float32x4_t v1_1 = vextq_f32(v, v, 2);
     Vec v1 = v1_1;
     // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
     v = vec_fun(v, v1);
-    // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
+    // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -,
+    // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -,
+    // -]
     v1_1 = vrev64q_f32(v);
     v1 = v1_1;
-    // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
+    // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8,
+    // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
     v = vec_fun(v, v1);
     return v[0];
@@ -102,10 +117,13 @@ struct VecReduceAllSIMD<float, Op> {
 };
 #endif // defined(__aarch64__)
 
-#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256)
+#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
+    defined(CPU_CAPABILITY_SVE256)
 template <typename Op>
 struct VecReduceAllSIMD<float, Op> {
-  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
+  static inline float apply(
+      const Op& vec_fun,
+      const Vectorized<float>& acc_vec) {
     using Vec = Vectorized<float>;
     Vec v = acc_vec;
     // 128-bit shuffle
@@ -125,15 +143,21 @@ struct VecReduceAllSIMD<float, Op> {
 };
 #endif // defined(__aarch64__)
 
 template <typename scalar_t, typename Op>
-inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
+inline scalar_t vec_reduce_all(
+    const Op& vec_fun,
+    const Vectorized<scalar_t>& acc_vec) {
   return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline scalar_t reduce_all(
+    const Op& vec_fun,
+    const scalar_t* data,
+    int64_t size) {
   using Vec = vec::Vectorized<scalar_t>;
   if (size < Vec::size())
     return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
@@ -151,16 +175,22 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size
 }
 
 // similar to reduce_all, but reduces into two outputs
-template <typename scalar_t, typename Op1, typename Op2,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
-    const scalar_t* data, int64_t size) {
+template <
+    typename scalar_t,
+    typename Op1,
+    typename Op2,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline std::pair<scalar_t, scalar_t> reduce2_all(
+    const Op1& vec_fun1,
+    const Op2& vec_fun2,
+    const scalar_t* data,
+    int64_t size) {
   using Vec = vec::Vectorized<scalar_t>;
   if (size < Vec::size()) {
     auto loaded_data = Vec::loadu(data, size);
     return std::pair<scalar_t, scalar_t>(
         vec_reduce_all(vec_fun1, loaded_data, size),
         vec_reduce_all(vec_fun2, loaded_data, size));
   }
   int64_t d = Vec::size();
   Vec acc_vec1 = Vec::loadu(data);
@@ -176,12 +206,14 @@ inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2&
     acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
   }
   return std::pair<scalar_t, scalar_t>(
-      vec_reduce_all(vec_fun1, acc_vec1),
-      vec_reduce_all(vec_fun2, acc_vec2));
+      vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2));
 }
 
-template <typename scalar_t, typename MapOp, typename ReduceOp,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename MapOp,
+    typename ReduceOp,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline scalar_t map_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
@@ -205,8 +237,11 @@ inline scalar_t map_reduce_all(
   return vec_reduce_all(red_fun, acc_vec);
 }
 
-template <typename scalar_t, typename MapOp, typename ReduceOp,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename MapOp,
+    typename ReduceOp,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline scalar_t map2_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
@@ -237,8 +272,11 @@ inline scalar_t map2_reduce_all(
   return vec_reduce_all(red_fun, acc_vec);
 }
 
-template <typename scalar_t, typename MapOp, typename ReduceOp,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename MapOp,
+    typename ReduceOp,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline scalar_t map3_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
@@ -274,8 +312,10 @@ inline scalar_t map3_reduce_all(
   return vec_reduce_all(red_fun, acc_vec);
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -293,8 +333,10 @@ inline void map(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map2(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -317,8 +359,10 @@ inline void map2(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map3(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -344,8 +388,10 @@ inline void map3(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map4(
     const Op& vec_fun,
     scalar_t* output_data,
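Aside (not part of the commit): the helpers reformatted above are ATen's generic SIMD map/reduce primitives. A minimal usage sketch, assuming a plain float buffer:

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Reduce a float buffer to its sum and its max with the primitives above;
// this mirrors how ATen CPU kernels typically call reduce_all.
inline float sum_all(const float* data, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  return at::vec::reduce_all<float>(
      [](Vec& x, Vec& y) { return x + y; }, data, n);
}

inline float max_all(const float* data, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  return at::vec::reduce_all<float>(
      [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, data, n);
}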

View File

@@ -8,86 +8,120 @@
 namespace at::vec {
 
 // BFloat16 specification
-template <typename scalar_t> struct VecScalarType { using type = scalar_t; };
-template <> struct VecScalarType<BFloat16> { using type = float; };
-template <> struct VecScalarType<Half> { using type = float; };
+template <typename scalar_t>
+struct VecScalarType {
+  using type = scalar_t;
+};
+template <>
+struct VecScalarType<BFloat16> {
+  using type = float;
+};
+template <>
+struct VecScalarType<Half> {
+  using type = float;
+};
 
 // This is different from at::acc_type since we only need to specialize BFloat16
 template <typename scalar_t>
 using vec_scalar_t = typename VecScalarType<scalar_t>::type;
 
 // Vector conversion between float and bfloat16/half
-template <typename scalar_t,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(const Vectorized<scalar_t>&);
+template <
+    typename scalar_t,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(
+    const Vectorized<scalar_t>&);
 
 template <>
-inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<BFloat16> (const Vectorized<BFloat16>& a) {
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<
+    BFloat16>(const Vectorized<BFloat16>& a) {
   return convert_bfloat16_float(a);
 }
 
 template <>
-inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half> (const Vectorized<Half>& a) {
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half>(
+    const Vectorized<Half>& a) {
   return convert_half_float(a);
 }
 
-template <typename scalar_t,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline Vectorized<scalar_t> convert_from_float(const Vectorized<float>&, const Vectorized<float>&);
+template <
+    typename scalar_t,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline Vectorized<scalar_t> convert_from_float(
+    const Vectorized<float>&,
+    const Vectorized<float>&);
 
 template <>
-inline Vectorized<BFloat16> convert_from_float<BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
+inline Vectorized<BFloat16> convert_from_float<BFloat16>(
+    const Vectorized<float>& a,
+    const Vectorized<float>& b) {
   return convert_float_bfloat16(a, b);
 }
 
 template <>
-inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, const Vectorized<float>& b) {
+inline Vectorized<Half> convert_from_float<Half>(
+    const Vectorized<float>& a,
+    const Vectorized<float>& b) {
   return convert_float_half(a, b);
 }
 
-template <typename scalar_t,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline void load_to_float(const scalar_t *data, Vectorized<float> &out1, Vectorized<float> &out2);
+template <
+    typename scalar_t,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void load_to_float(
+    const scalar_t* data,
+    Vectorized<float>& out1,
+    Vectorized<float>& out2);
 
 template <>
-inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out1, Vectorized<float> &out2) {
+inline void load_to_float<BFloat16>(
+    const BFloat16* data,
+    Vectorized<float>& out1,
+    Vectorized<float>& out2) {
   load_fp32_from_bf16(data, out1, out2);
 }
 
 template <>
-inline void load_to_float<Half> (const Half *data, Vectorized<float> &out1, Vectorized<float> &out2) {
+inline void load_to_float<Half>(
+    const Half* data,
+    Vectorized<float>& out1,
+    Vectorized<float>& out2) {
   load_fp32_from_fp16(data, out1, out2);
 }
 
-template <typename scalar_t,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline void load_to_float(const scalar_t *data, Vectorized<float> &out);
+template <
+    typename scalar_t,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void load_to_float(const scalar_t* data, Vectorized<float>& out);
 
 template <>
-inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out) {
+inline void load_to_float<BFloat16>(
+    const BFloat16* data,
+    Vectorized<float>& out) {
   load_fp32_from_bf16(data, out);
 }
 
 template <>
-inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
+inline void load_to_float<Half>(const Half* data, Vectorized<float>& out) {
   load_fp32_from_fp16(data, out);
 }
 
-// Note that we already have specialized member of Vectorized<scalar_t> for BFloat16
-// so the following functions would run smoothly:
+// Note that we already have specialized member of Vectorized<scalar_t> for
+// BFloat16 so the following functions would run smoothly:
 //   using Vec = Vectorized<BFloat16>;
 //   Vec one = Vec(BFloat16(1));
 //   vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
 //
 // Then why we still need to specialize "functional"?
-// If we do specialization at Vectorized<> level, the above example would need 3 pairs of
-// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/".
-// If we do specialization at vec::map<>() level, we have only 1 pair of conversion
-// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only.
+// If we do specialization at Vectorized<> level, the above example would need
+// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and
+// "/". If we do specialization at vec::map<>() level, we have only 1 pair of
+// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16
+// vector only.
 //
-// The following BFloat16 functionality will only do data type conversion for input
-// and output vector (reduce functionality will only convert the final scalar back to bf16).
-// Compared to Vectorized<> specialization,
+// The following BFloat16 functionality will only do data type conversion for
+// input and output vector (reduce functionality will only convert the final
+// scalar back to bf16). Compared to Vectorized<> specialization,
 // 1. better performance since we have less data type conversion;
 // 2. less rounding error since immediate results are kept in fp32;
 // 3. accumulation done on data type of fp32.
@@ -95,8 +129,10 @@ inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
 // If you plan to extend this file, please ensure adding unit tests at
 //   aten/src/ATen/test/vec_test_all_types.cpp
 //
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
   using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
@@ -104,7 +140,8 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
     bVec data_bvec = bVec::loadu(data, size);
     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
     if (size > fVec::size()) {
-      data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
+      data_fvec0 = fVec::set(
+          data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
       return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
     } else {
       return vec_reduce_all<float>(vec_fun, data_fvec0, size);
@@ -124,27 +161,37 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
     if (size - d > fVec::size()) {
       acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
-      acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+      acc_fvec1 = fVec::set(
+          acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
     } else {
-      acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
+      acc_fvec0 =
+          fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
     }
   }
   acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
   return vec_reduce_all<float>(vec_fun, acc_fvec0);
 }
 
-template <typename scalar_t, typename Op1, typename Op2,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
-    const scalar_t* data, int64_t size) {
+template <
+    typename scalar_t,
+    typename Op1,
+    typename Op2,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline std::pair<float, float> reduce2_all(
+    const Op1& vec_fun1,
+    const Op2& vec_fun2,
+    const scalar_t* data,
+    int64_t size) {
   using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   if (size < bVec::size()) {
     bVec data_bvec = bVec::loadu(data, size);
     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
     if (size > fVec::size()) {
-      fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
-      fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
+      fVec acc1_fvec = fVec::set(
+          data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
+      fVec acc2_fvec = fVec::set(
+          data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
       return std::pair<scalar_t, scalar_t>(
           vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
           vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
@@ -171,12 +218,20 @@ inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_f
     auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
     if (size - d > fVec::size()) {
       acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
-      acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size());
+      acc1_fvec1 = fVec::set(
+          acc1_fvec1,
+          vec_fun1(acc1_fvec1, data_fvec1),
+          size - d - fVec::size());
       acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
-      acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size());
+      acc2_fvec1 = fVec::set(
+          acc2_fvec1,
+          vec_fun2(acc2_fvec1, data_fvec1),
+          size - d - fVec::size());
     } else {
-      acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
-      acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
+      acc1_fvec0 =
+          fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
+      acc2_fvec0 =
+          fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
     }
   }
   acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
@@ -186,8 +241,11 @@ inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_f
       vec_reduce_all<float>(vec_fun2, acc2_fvec0));
 }
 
-template <typename scalar_t, typename MapOp, typename ReduceOp,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename MapOp,
+    typename ReduceOp,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline float map_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
@@ -201,7 +259,8 @@ inline float map_reduce_all(
     if (size > fVec::size()) {
       data_fvec0 = map_fun(data_fvec0);
       data_fvec1 = map_fun(data_fvec1);
-      data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
+      data_fvec0 = fVec::set(
+          data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
       return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
     } else {
       data_fvec0 = map_fun(data_fvec0);
@@ -228,18 +287,23 @@ inline float map_reduce_all(
       data_fvec0 = map_fun(data_fvec0);
       data_fvec1 = map_fun(data_fvec1);
       acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
-      acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+      acc_fvec1 = fVec::set(
+          acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
     } else {
       data_fvec0 = map_fun(data_fvec0);
-      acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
+      acc_fvec0 =
+          fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
     }
   }
   acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
   return vec_reduce_all<float>(red_fun, acc_fvec0);
 }
 
-template <typename scalar_t, typename MapOp, typename ReduceOp,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename MapOp,
+    typename ReduceOp,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline float map2_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
@@ -256,7 +320,8 @@ inline float map2_reduce_all(
     if (size > fVec::size()) {
       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
       data_fvec1 = map_fun(data_fvec1, data2_fvec1);
-      data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
+      data_fvec0 = fVec::set(
+          data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
       return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
     } else {
       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
@@ -289,18 +354,23 @@ inline float map2_reduce_all(
       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
       data_fvec1 = map_fun(data_fvec1, data2_fvec1);
       acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
-      acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+      acc_fvec1 = fVec::set(
+          acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
     } else {
       data_fvec0 = map_fun(data_fvec0, data2_fvec0);
-      acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
+      acc_fvec0 =
+          fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
     }
   }
   acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
   return vec_reduce_all<float>(red_fun, acc_fvec0);
 }
 
-template <typename scalar_t, typename MapOp, typename ReduceOp,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename MapOp,
+    typename ReduceOp,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline float map3_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
@@ -320,7 +390,8 @@ inline float map3_reduce_all(
     if (size > fVec::size()) {
       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
       data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
-      data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
+      data_fvec0 = fVec::set(
+          data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
       return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
     } else {
       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
@@ -359,18 +430,22 @@ inline float map3_reduce_all(
       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
       data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
       acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
-      acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+      acc_fvec1 = fVec::set(
+          acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
     } else {
       data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
-      acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
+      acc_fvec0 =
+          fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
     }
   }
   acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
   return vec_reduce_all<float>(red_fun, acc_fvec0);
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -397,8 +472,10 @@ inline void map(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -419,7 +496,8 @@ inline void map(
     fVec data_fvec0, data_fvec1;
     if (size - d > fVec::size()) {
       data_fvec0 = fVec::loadu(input_data + d);
-      data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
+      data_fvec1 =
+          fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
     } else {
       // choose to align with behaviour of bVec::loadu(ptr, size),
       // which leaves data_fvec1 uninitialized
@@ -432,8 +510,10 @@ inline void map(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map2(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -465,8 +545,10 @@ inline void map2(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map3(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -503,8 +585,10 @@ inline void map3(
   }
 }
 
-template <typename scalar_t, typename Op,
-          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+template <
+    typename scalar_t,
+    typename Op,
+    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
 inline void map4(
     const Op& vec_fun,
     scalar_t* output_data,
@@ -525,8 +609,10 @@ inline void map4(
     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
     bVec data4_bvec = bVec::loadu(input_data4 + d);
     auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
-    fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
-    fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
+    fVec output_fvec0 =
+        vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
+    fVec output_fvec1 =
+        vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
     output_bvec.store(output_data + d);
   }
@@ -539,8 +625,10 @@ inline void map4(
     auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
     bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
     auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
-    fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
-    fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
+    fVec output_fvec0 =
+        vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
+    fVec output_fvec1 =
+        vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
     bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
     output_bvec.store(output_data + d, size - d);
   }
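Aside (not part of the commit): the comment block above is the key design note in this header — the reduced-precision overloads convert bf16/fp16 to fp32 once on load and back once on store, so the user-supplied functor always sees Vectorized<float>. A minimal sketch of that usage, assuming a BFloat16 buffer:

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>
#include <cstdint>

// Halve every element of a BFloat16 buffer. The lambda operates on
// Vectorized<float>; conversion to/from BFloat16 happens only at the
// load/store boundary, as described in the comment above.
inline void halve_bf16(c10::BFloat16* out, const c10::BFloat16* in, int64_t n) {
  using fVec = at::vec::Vectorized<float>;
  at::vec::map([](fVec x) { return x * fVec(0.5f); }, out, in, n);
}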

View File

@@ -13,10 +13,14 @@
 /* Microsoft C/C++-compatible compiler */
 #include <intrin.h>
 #if _MSC_VER <= 1900
-#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
-#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
-#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
-#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
+#define _mm256_extract_epi64(X, Y) \
+  (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
+#define _mm256_extract_epi32(X, Y) \
+  (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
+#define _mm256_extract_epi16(X, Y) \
+  (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
+#define _mm256_extract_epi8(X, Y) \
+  (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
 #endif
 #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
 /* GCC-compatible compiler, targeting ARM with NEON */
@@ -25,9 +29,9 @@
 /* GCC-compatible compiler, targeting ARM with SVE */
 #include <arm_sve.h>
 #endif
-#if defined (MISSING_ARM_VLD1)
+#if defined(MISSING_ARM_VLD1)
 #include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
-#elif defined (MISSING_ARM_VST1)
+#elif defined(MISSING_ARM_VST1)
 #include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
 #endif
 #elif defined(__GNUC__) && defined(__IWMMXT__)
@@ -36,8 +40,8 @@
 #elif defined(__s390x__)
 // targets Z/architecture
 // we will include vecintrin later
-#elif (defined(__GNUC__) || defined(__xlC__)) &&  \
-  (defined(__VEC__) || defined(__ALTIVEC__))
+#elif (defined(__GNUC__) || defined(__xlC__)) && \
+    (defined(__VEC__) || defined(__ALTIVEC__))
 /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
 #include <altivec.h>
 /* We need to undef those tokens defined by <altivec.h> to avoid conflicts

View File

@@ -28,21 +28,30 @@ inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
 }
 
 template <>
-inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
+inline Vectorized<bool> Vectorized<bool>::loadu(
+    const void* ptr,
+    int64_t count) {
   // See NOTE [Loading boolean values]
   return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
 }
 
 template <typename VT>
-struct VecHoldType { using hold_type = typename VT::value_type; };
+struct VecHoldType {
+  using hold_type = typename VT::value_type;
+};
 
 template <>
-struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; };
+struct VecHoldType<Vectorized<BFloat16>> {
+  using hold_type = BFloat16;
+};
 
 template <>
-struct VecHoldType<Vectorized<Half>> {using hold_type = Half; };
+struct VecHoldType<Vectorized<Half>> {
+  using hold_type = Half;
+};
 
 template <typename VT>
 using vechold_type = typename VecHoldType<VT>::hold_type;
 
-}} // namespace at::vec::CPU_CAPABILITY
+} // namespace CPU_CAPABILITY
+} // namespace at::vec
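Aside (not part of the commit): VecHoldType / vechold_type simply recover the scalar type held by a Vectorized specialization, e.g.:

#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>
#include <type_traits>

static_assert(
    std::is_same_v<at::vec::vechold_type<at::vec::Vectorized<float>>, float>,
    "Vectorized<float> holds float");
static_assert(
    std::is_same_v<
        at::vec::vechold_type<at::vec::Vectorized<c10::BFloat16>>,
        c10::BFloat16>,
    "Vectorized<BFloat16> holds BFloat16");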

File diff suppressed because it is too large.

View File

@@ -28,8 +28,8 @@ struct VecConvert {
 };
 
 template <typename dst_t, typename src_t>
-inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>>
-convert(const Vectorized<src_t>& src) {
+inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>> convert(
+    const Vectorized<src_t>& src) {
   return src;
 }

View File

@@ -103,7 +103,9 @@ static inline void transpose_pad_2x32_block(
     _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
   }
 #else
-  TORCH_CHECK(false, "transpose_pad_2x32_block is only supported when avx512 is supported")
+  TORCH_CHECK(
+      false,
+      "transpose_pad_2x32_block is only supported when avx512 is supported")
 #endif
 }
@@ -124,28 +126,31 @@ static inline void pack_vnni2(
   for (; bk < _K; bk += 2) {
     int64_t bn = 0;
     for (; bn < _N; bn += 32) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
     }
     int64_t nrem = N - bn;
     if (nrem > 0) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
     }
   }
   if (K % 2 == 1) {
     int64_t bn = 0;
     for (; bn < _N; bn += 32) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
    }
     int64_t nrem = N - bn;
     if (nrem > 0) {
-      transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
+      transpose_pad_2x32_block(
+          src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
     }
   }
 #else
   TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
 #endif
 }
 
 } // namespace CPU_CAPABILITY
 } // namespace at::vec
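Aside (not part of the commit, and an assumption for illustration only): pack_vnni2 is not described in prose here, but the call pattern above is consistent with the usual VNNI pairing that takes a [K, N] source to a [ceil(K/2), N, 2] destination, zero-padding the last row when K is odd; transpose_pad_2x32_block would then handle one 2x32 block at a time with AVX-512. A hypothetical scalar reference under that assumption:

#include <cstdint>

// Hypothetical scalar reference for the assumed VNNI2 layout: two adjacent
// K elements of each column are stored next to each other in the destination.
template <typename T>
void pack_vnni2_reference(
    const T* src, T* dst, int64_t ld_src, int64_t K, int64_t N) {
  for (int64_t k = 0; k < K; k += 2) {
    for (int64_t n = 0; n < N; ++n) {
      // Row pair k/2, column n, pair index 0/1.
      dst[k * N + 2 * n] = src[k * ld_src + n];
      dst[k * N + 2 * n + 1] = (k + 1 < K) ? src[(k + 1) * ld_src + n] : T(0);
    }
  }
}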

View File

@@ -68,7 +68,12 @@ struct VecMaskTo {
   }
 };
 
-template <typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void>
+template <
+    typename dst_t,
+    int dst_n,
+    typename src_t,
+    int src_n,
+    typename Enabled = void>
 struct VecMaskCast {
   static inline VecMask<dst_t, dst_n> apply(
       const VecMask<src_t, src_n>& vec_mask) {
@@ -88,15 +93,17 @@ struct VecMaskCheck {
   static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
     __at_align__ T mask[VectorizedN<T, N>::size()];
     vec_mask.store(mask);
-    return std::all_of(
-        mask, mask + VectorizedN<T, N>::size(), [](T m) { return m == static_cast<T>(0); });
+    return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
+      return m == static_cast<T>(0);
+    });
   }
 
   static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
     __at_align__ T mask[VectorizedN<T, N>::size()];
     vec_mask.store(mask);
-    return std::all_of(
-        mask, mask + VectorizedN<T, N>::size(), [](T m) { return m != static_cast<T>(0); });
+    return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
+      return m != static_cast<T>(0);
+    });
   }
 
   static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
@@ -159,13 +166,11 @@ class VecMask {
   }
 
   static VecMask<T, N> blendv(
       const VecMask<T, N>& c,
       const VecMask<T, N>& b,
       const VecMask<T, N>& a) {
     VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
-        VectorizedN<T, N>(c),
-        VectorizedN<T, N>(b),
-        VectorizedN<T, N>(a));
+        VectorizedN<T, N>(c), VectorizedN<T, N>(b), VectorizedN<T, N>(a));
     return result;
   }
@@ -174,14 +179,14 @@
       const VecMask<T, N>& b,
       int64_t count = size()) {
     VectorizedN<T, N> result = VectorizedN<T, N>::set(
-        VectorizedN<T, N>(a),
-        VectorizedN<T, N>(b),
-        count);
+        VectorizedN<T, N>(a), VectorizedN<T, N>(b), count);
     return result;
   }
 
   void store(bool* b, int count = size()) {
-    constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1)/ Vectorized<bool>::size();
+    constexpr int L =
+        (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1) /
+        Vectorized<bool>::size();
     auto res = this->to<bool, L>();
     res.store(b, count);
     return;