mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
clang-format aten/src/ATen/cpu/vec/*.h (#150426)
I got a complaint about indentation on #150380. Make the machines fix it for us. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150426 Approved by: https://github.com/aditew01, https://github.com/cyyever, https://github.com/frost-intel, https://github.com/Skylion007
This commit is contained in:
parent
bd9c42ebfb
commit
ed0fd2fa7a
|
|
@ -55,6 +55,7 @@ init_command = [
|
||||||
code = 'CLANGFORMAT'
|
code = 'CLANGFORMAT'
|
||||||
include_patterns = [
|
include_patterns = [
|
||||||
'aten/src/ATen/*.h',
|
'aten/src/ATen/*.h',
|
||||||
|
'aten/src/ATen/cpu/vec/*.h',
|
||||||
'aten/src/ATen/mps/**/*.mm',
|
'aten/src/ATen/mps/**/*.mm',
|
||||||
'aten/src/ATen/mps/**/*.h',
|
'aten/src/ATen/mps/**/*.h',
|
||||||
'aten/src/ATen/xpu/**/*.h',
|
'aten/src/ATen/xpu/**/*.h',
|
||||||
|
|
|
||||||
|
|
@ -29,16 +29,21 @@ inline scalar_t vec_reduce_all(
|
||||||
|
|
||||||
template <typename scalar_t, typename Op>
|
template <typename scalar_t, typename Op>
|
||||||
struct VecReduceAllSIMD {
|
struct VecReduceAllSIMD {
|
||||||
static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
|
static inline scalar_t apply(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const Vectorized<scalar_t>& acc_vec) {
|
||||||
return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
|
return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
|
#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \
|
||||||
|
!defined(C10_MOBILE)
|
||||||
#if defined(CPU_CAPABILITY_AVX2)
|
#if defined(CPU_CAPABILITY_AVX2)
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct VecReduceAllSIMD<float, Op> {
|
struct VecReduceAllSIMD<float, Op> {
|
||||||
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
|
static inline float apply(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const Vectorized<float>& acc_vec) {
|
||||||
using Vec = Vectorized<float>;
|
using Vec = Vectorized<float>;
|
||||||
Vec v = acc_vec;
|
Vec v = acc_vec;
|
||||||
// 128-bit shuffle
|
// 128-bit shuffle
|
||||||
|
|
@ -57,7 +62,9 @@ struct VecReduceAllSIMD<float, Op> {
|
||||||
#if defined(CPU_CAPABILITY_AVX512)
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct VecReduceAllSIMD<float, Op> {
|
struct VecReduceAllSIMD<float, Op> {
|
||||||
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
|
static inline float apply(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const Vectorized<float>& acc_vec) {
|
||||||
using Vec = Vectorized<float>;
|
using Vec = Vectorized<float>;
|
||||||
Vec v = acc_vec;
|
Vec v = acc_vec;
|
||||||
// 256-bit shuffle
|
// 256-bit shuffle
|
||||||
|
|
@ -76,25 +83,33 @@ struct VecReduceAllSIMD<float, Op> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif // defined(CPU_CAPABILITY_AVX512)
|
#endif // defined(CPU_CAPABILITY_AVX512)
|
||||||
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
|
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
|
||||||
|
// !defined(C10_MOBILE)
|
||||||
|
|
||||||
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
|
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||||
|
!defined(CPU_CAPABILITY_SVE)
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct VecReduceAllSIMD<float, Op> {
|
struct VecReduceAllSIMD<float, Op> {
|
||||||
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
|
static inline float apply(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const Vectorized<float>& acc_vec) {
|
||||||
using Vec = Vectorized<float>;
|
using Vec = Vectorized<float>;
|
||||||
Vec v = acc_vec;
|
Vec v = acc_vec;
|
||||||
|
|
||||||
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
|
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7,
|
||||||
|
// a4+a8, a1+a5, a2+a6, -, -, -, -]
|
||||||
float32x4_t v1_1 = vextq_f32(v, v, 2);
|
float32x4_t v1_1 = vextq_f32(v, v, 2);
|
||||||
Vec v1 = v1_1;
|
Vec v1 = v1_1;
|
||||||
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
|
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
|
||||||
v = vec_fun(v, v1);
|
v = vec_fun(v, v1);
|
||||||
|
|
||||||
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
|
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -,
|
||||||
|
// -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -,
|
||||||
|
// -]
|
||||||
v1_1 = vrev64q_f32(v);
|
v1_1 = vrev64q_f32(v);
|
||||||
v1 = v1_1;
|
v1 = v1_1;
|
||||||
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
|
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8,
|
||||||
|
// a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
|
||||||
v = vec_fun(v, v1);
|
v = vec_fun(v, v1);
|
||||||
|
|
||||||
return v[0];
|
return v[0];
|
||||||
|
|
@ -102,10 +117,13 @@ struct VecReduceAllSIMD<float, Op> {
|
||||||
};
|
};
|
||||||
#endif // defined(__aarch64__)
|
#endif // defined(__aarch64__)
|
||||||
|
|
||||||
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256)
|
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||||
|
defined(CPU_CAPABILITY_SVE256)
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct VecReduceAllSIMD<float, Op> {
|
struct VecReduceAllSIMD<float, Op> {
|
||||||
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
|
static inline float apply(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const Vectorized<float>& acc_vec) {
|
||||||
using Vec = Vectorized<float>;
|
using Vec = Vectorized<float>;
|
||||||
Vec v = acc_vec;
|
Vec v = acc_vec;
|
||||||
// 128-bit shuffle
|
// 128-bit shuffle
|
||||||
|
|
@ -125,15 +143,21 @@ struct VecReduceAllSIMD<float, Op> {
|
||||||
};
|
};
|
||||||
#endif // defined(__aarch64__)
|
#endif // defined(__aarch64__)
|
||||||
|
|
||||||
|
|
||||||
template <typename scalar_t, typename Op>
|
template <typename scalar_t, typename Op>
|
||||||
inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
|
inline scalar_t vec_reduce_all(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const Vectorized<scalar_t>& acc_vec) {
|
||||||
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
|
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
|
typename Op,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline scalar_t reduce_all(
|
||||||
|
const Op& vec_fun,
|
||||||
|
const scalar_t* data,
|
||||||
|
int64_t size) {
|
||||||
using Vec = vec::Vectorized<scalar_t>;
|
using Vec = vec::Vectorized<scalar_t>;
|
||||||
if (size < Vec::size())
|
if (size < Vec::size())
|
||||||
return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
|
return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
|
||||||
|
|
@ -151,16 +175,22 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size
|
||||||
}
|
}
|
||||||
|
|
||||||
// similar to reduce_all, but reduces into two outputs
|
// similar to reduce_all, but reduces into two outputs
|
||||||
template <typename scalar_t, typename Op1, typename Op2,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
|
typename Op1,
|
||||||
const scalar_t* data, int64_t size) {
|
typename Op2,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline std::pair<scalar_t, scalar_t> reduce2_all(
|
||||||
|
const Op1& vec_fun1,
|
||||||
|
const Op2& vec_fun2,
|
||||||
|
const scalar_t* data,
|
||||||
|
int64_t size) {
|
||||||
using Vec = vec::Vectorized<scalar_t>;
|
using Vec = vec::Vectorized<scalar_t>;
|
||||||
if (size < Vec::size()) {
|
if (size < Vec::size()) {
|
||||||
auto loaded_data = Vec::loadu(data, size);
|
auto loaded_data = Vec::loadu(data, size);
|
||||||
return std::pair<scalar_t, scalar_t>(
|
return std::pair<scalar_t, scalar_t>(
|
||||||
vec_reduce_all(vec_fun1, loaded_data, size),
|
vec_reduce_all(vec_fun1, loaded_data, size),
|
||||||
vec_reduce_all(vec_fun2, loaded_data, size));
|
vec_reduce_all(vec_fun2, loaded_data, size));
|
||||||
}
|
}
|
||||||
int64_t d = Vec::size();
|
int64_t d = Vec::size();
|
||||||
Vec acc_vec1 = Vec::loadu(data);
|
Vec acc_vec1 = Vec::loadu(data);
|
||||||
|
|
@ -176,12 +206,14 @@ inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2&
|
||||||
acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
|
acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
|
||||||
}
|
}
|
||||||
return std::pair<scalar_t, scalar_t>(
|
return std::pair<scalar_t, scalar_t>(
|
||||||
vec_reduce_all(vec_fun1, acc_vec1),
|
vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2));
|
||||||
vec_reduce_all(vec_fun2, acc_vec2));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename MapOp, typename ReduceOp,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename MapOp,
|
||||||
|
typename ReduceOp,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline scalar_t map_reduce_all(
|
inline scalar_t map_reduce_all(
|
||||||
const MapOp& map_fun,
|
const MapOp& map_fun,
|
||||||
const ReduceOp& red_fun,
|
const ReduceOp& red_fun,
|
||||||
|
|
@ -205,8 +237,11 @@ inline scalar_t map_reduce_all(
|
||||||
return vec_reduce_all(red_fun, acc_vec);
|
return vec_reduce_all(red_fun, acc_vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename MapOp, typename ReduceOp,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename MapOp,
|
||||||
|
typename ReduceOp,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline scalar_t map2_reduce_all(
|
inline scalar_t map2_reduce_all(
|
||||||
const MapOp& map_fun,
|
const MapOp& map_fun,
|
||||||
const ReduceOp& red_fun,
|
const ReduceOp& red_fun,
|
||||||
|
|
@ -237,8 +272,11 @@ inline scalar_t map2_reduce_all(
|
||||||
return vec_reduce_all(red_fun, acc_vec);
|
return vec_reduce_all(red_fun, acc_vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename MapOp, typename ReduceOp,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename MapOp,
|
||||||
|
typename ReduceOp,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline scalar_t map3_reduce_all(
|
inline scalar_t map3_reduce_all(
|
||||||
const MapOp& map_fun,
|
const MapOp& map_fun,
|
||||||
const ReduceOp& red_fun,
|
const ReduceOp& red_fun,
|
||||||
|
|
@ -274,8 +312,10 @@ inline scalar_t map3_reduce_all(
|
||||||
return vec_reduce_all(red_fun, acc_vec);
|
return vec_reduce_all(red_fun, acc_vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map(
|
inline void map(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -293,8 +333,10 @@ inline void map(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map2(
|
inline void map2(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -317,8 +359,10 @@ inline void map2(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map3(
|
inline void map3(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -344,8 +388,10 @@ inline void map3(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map4(
|
inline void map4(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
|
||||||
|
|
@ -8,86 +8,120 @@
|
||||||
namespace at::vec {
|
namespace at::vec {
|
||||||
|
|
||||||
// BFloat16 specification
|
// BFloat16 specification
|
||||||
template <typename scalar_t> struct VecScalarType { using type = scalar_t; };
|
template <typename scalar_t>
|
||||||
template <> struct VecScalarType<BFloat16> { using type = float; };
|
struct VecScalarType {
|
||||||
template <> struct VecScalarType<Half> { using type = float; };
|
using type = scalar_t;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct VecScalarType<BFloat16> {
|
||||||
|
using type = float;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct VecScalarType<Half> {
|
||||||
|
using type = float;
|
||||||
|
};
|
||||||
|
|
||||||
// This is different from at::acc_type since we only need to specialize BFloat16
|
// This is different from at::acc_type since we only need to specialize BFloat16
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
using vec_scalar_t = typename VecScalarType<scalar_t>::type;
|
using vec_scalar_t = typename VecScalarType<scalar_t>::type;
|
||||||
|
|
||||||
// Vector conversion between float and bfloat16/half
|
// Vector conversion between float and bfloat16/half
|
||||||
template <typename scalar_t,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(const Vectorized<scalar_t>&);
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(
|
||||||
|
const Vectorized<scalar_t>&);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<BFloat16> (const Vectorized<BFloat16>& a) {
|
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<
|
||||||
|
BFloat16>(const Vectorized<BFloat16>& a) {
|
||||||
return convert_bfloat16_float(a);
|
return convert_bfloat16_float(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half> (const Vectorized<Half>& a) {
|
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half>(
|
||||||
return convert_half_float(a);
|
const Vectorized<Half>& a) {
|
||||||
|
return convert_half_float(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline Vectorized<scalar_t> convert_from_float(const Vectorized<float>&, const Vectorized<float>&);
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline Vectorized<scalar_t> convert_from_float(
|
||||||
|
const Vectorized<float>&,
|
||||||
|
const Vectorized<float>&);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Vectorized<BFloat16> convert_from_float<BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
inline Vectorized<BFloat16> convert_from_float<BFloat16>(
|
||||||
|
const Vectorized<float>& a,
|
||||||
|
const Vectorized<float>& b) {
|
||||||
return convert_float_bfloat16(a, b);
|
return convert_float_bfloat16(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
inline Vectorized<Half> convert_from_float<Half>(
|
||||||
|
const Vectorized<float>& a,
|
||||||
|
const Vectorized<float>& b) {
|
||||||
return convert_float_half(a, b);
|
return convert_float_half(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline void load_to_float(const scalar_t *data, Vectorized<float> &out1, Vectorized<float> &out2);
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline void load_to_float(
|
||||||
|
const scalar_t* data,
|
||||||
|
Vectorized<float>& out1,
|
||||||
|
Vectorized<float>& out2);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out1, Vectorized<float> &out2) {
|
inline void load_to_float<BFloat16>(
|
||||||
|
const BFloat16* data,
|
||||||
|
Vectorized<float>& out1,
|
||||||
|
Vectorized<float>& out2) {
|
||||||
load_fp32_from_bf16(data, out1, out2);
|
load_fp32_from_bf16(data, out1, out2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline void load_to_float<Half> (const Half *data, Vectorized<float> &out1, Vectorized<float> &out2) {
|
inline void load_to_float<Half>(
|
||||||
|
const Half* data,
|
||||||
|
Vectorized<float>& out1,
|
||||||
|
Vectorized<float>& out2) {
|
||||||
load_fp32_from_fp16(data, out1, out2);
|
load_fp32_from_fp16(data, out1, out2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline void load_to_float(const scalar_t *data, Vectorized<float> &out);
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline void load_to_float(const scalar_t* data, Vectorized<float>& out);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out) {
|
inline void load_to_float<BFloat16>(
|
||||||
|
const BFloat16* data,
|
||||||
|
Vectorized<float>& out) {
|
||||||
load_fp32_from_bf16(data, out);
|
load_fp32_from_bf16(data, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
|
inline void load_to_float<Half>(const Half* data, Vectorized<float>& out) {
|
||||||
load_fp32_from_fp16(data, out);
|
load_fp32_from_fp16(data, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that we already have specialized member of Vectorized<scalar_t> for BFloat16
|
// Note that we already have specialized member of Vectorized<scalar_t> for
|
||||||
// so the following functions would run smoothly:
|
// BFloat16 so the following functions would run smoothly:
|
||||||
// using Vec = Vectorized<BFloat16>;
|
// using Vec = Vectorized<BFloat16>;
|
||||||
// Vec one = Vec(BFloat16(1));
|
// Vec one = Vec(BFloat16(1));
|
||||||
// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
|
// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
|
||||||
//
|
//
|
||||||
// Then why we still need to specialize "functional"?
|
// Then why we still need to specialize "functional"?
|
||||||
// If we do specialization at Vectorized<> level, the above example would need 3 pairs of
|
// If we do specialization at Vectorized<> level, the above example would need
|
||||||
// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/".
|
// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and
|
||||||
// If we do specialization at vec::map<>() level, we have only 1 pair of conversion
|
// "/". If we do specialization at vec::map<>() level, we have only 1 pair of
|
||||||
// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only.
|
// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16
|
||||||
|
// vector only.
|
||||||
//
|
//
|
||||||
// The following BFloat16 functionality will only do data type conversion for input
|
// The following BFloat16 functionality will only do data type conversion for
|
||||||
// and output vector (reduce functionality will only convert the final scalar back to bf16).
|
// input and output vector (reduce functionality will only convert the final
|
||||||
// Compared to Vectorized<> specialization,
|
// scalar back to bf16). Compared to Vectorized<> specialization,
|
||||||
// 1. better performance since we have less data type conversion;
|
// 1. better performance since we have less data type conversion;
|
||||||
// 2. less rounding error since immediate results are kept in fp32;
|
// 2. less rounding error since immediate results are kept in fp32;
|
||||||
// 3. accumulation done on data type of fp32.
|
// 3. accumulation done on data type of fp32.
|
||||||
|
|
@ -95,8 +129,10 @@ inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
|
||||||
// If you plan to extend this file, please ensure adding unit tests at
|
// If you plan to extend this file, please ensure adding unit tests at
|
||||||
// aten/src/ATen/test/vec_test_all_types.cpp
|
// aten/src/ATen/test/vec_test_all_types.cpp
|
||||||
//
|
//
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
|
inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
|
||||||
using bVec = vec::Vectorized<scalar_t>;
|
using bVec = vec::Vectorized<scalar_t>;
|
||||||
using fVec = vec::Vectorized<float>;
|
using fVec = vec::Vectorized<float>;
|
||||||
|
|
@ -104,7 +140,8 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
|
||||||
bVec data_bvec = bVec::loadu(data, size);
|
bVec data_bvec = bVec::loadu(data, size);
|
||||||
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
||||||
if (size > fVec::size()) {
|
if (size > fVec::size()) {
|
||||||
data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
|
data_fvec0 = fVec::set(
|
||||||
|
data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
|
||||||
return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
|
return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
|
||||||
} else {
|
} else {
|
||||||
return vec_reduce_all<float>(vec_fun, data_fvec0, size);
|
return vec_reduce_all<float>(vec_fun, data_fvec0, size);
|
||||||
|
|
@ -124,27 +161,37 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
|
||||||
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
||||||
if (size - d > fVec::size()) {
|
if (size - d > fVec::size()) {
|
||||||
acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
|
acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
|
||||||
acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
acc_fvec1 = fVec::set(
|
||||||
|
acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
||||||
} else {
|
} else {
|
||||||
acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
|
acc_fvec0 =
|
||||||
|
fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
|
acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
|
||||||
return vec_reduce_all<float>(vec_fun, acc_fvec0);
|
return vec_reduce_all<float>(vec_fun, acc_fvec0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op1, typename Op2,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
|
typename Op1,
|
||||||
const scalar_t* data, int64_t size) {
|
typename Op2,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
|
inline std::pair<float, float> reduce2_all(
|
||||||
|
const Op1& vec_fun1,
|
||||||
|
const Op2& vec_fun2,
|
||||||
|
const scalar_t* data,
|
||||||
|
int64_t size) {
|
||||||
using bVec = vec::Vectorized<scalar_t>;
|
using bVec = vec::Vectorized<scalar_t>;
|
||||||
using fVec = vec::Vectorized<float>;
|
using fVec = vec::Vectorized<float>;
|
||||||
if (size < bVec::size()) {
|
if (size < bVec::size()) {
|
||||||
bVec data_bvec = bVec::loadu(data, size);
|
bVec data_bvec = bVec::loadu(data, size);
|
||||||
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
||||||
if (size > fVec::size()) {
|
if (size > fVec::size()) {
|
||||||
fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
|
fVec acc1_fvec = fVec::set(
|
||||||
fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
|
data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
|
||||||
|
fVec acc2_fvec = fVec::set(
|
||||||
|
data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
|
||||||
return std::pair<scalar_t, scalar_t>(
|
return std::pair<scalar_t, scalar_t>(
|
||||||
vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
|
vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
|
||||||
vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
|
vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
|
||||||
|
|
@ -171,12 +218,20 @@ inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_f
|
||||||
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
||||||
if (size - d > fVec::size()) {
|
if (size - d > fVec::size()) {
|
||||||
acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
|
acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
|
||||||
acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size());
|
acc1_fvec1 = fVec::set(
|
||||||
|
acc1_fvec1,
|
||||||
|
vec_fun1(acc1_fvec1, data_fvec1),
|
||||||
|
size - d - fVec::size());
|
||||||
acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
|
acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
|
||||||
acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size());
|
acc2_fvec1 = fVec::set(
|
||||||
|
acc2_fvec1,
|
||||||
|
vec_fun2(acc2_fvec1, data_fvec1),
|
||||||
|
size - d - fVec::size());
|
||||||
} else {
|
} else {
|
||||||
acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
|
acc1_fvec0 =
|
||||||
acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
|
fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
|
||||||
|
acc2_fvec0 =
|
||||||
|
fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
|
acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
|
||||||
|
|
@ -186,8 +241,11 @@ inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_f
|
||||||
vec_reduce_all<float>(vec_fun2, acc2_fvec0));
|
vec_reduce_all<float>(vec_fun2, acc2_fvec0));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename MapOp, typename ReduceOp,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename MapOp,
|
||||||
|
typename ReduceOp,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline float map_reduce_all(
|
inline float map_reduce_all(
|
||||||
const MapOp& map_fun,
|
const MapOp& map_fun,
|
||||||
const ReduceOp& red_fun,
|
const ReduceOp& red_fun,
|
||||||
|
|
@ -201,7 +259,8 @@ inline float map_reduce_all(
|
||||||
if (size > fVec::size()) {
|
if (size > fVec::size()) {
|
||||||
data_fvec0 = map_fun(data_fvec0);
|
data_fvec0 = map_fun(data_fvec0);
|
||||||
data_fvec1 = map_fun(data_fvec1);
|
data_fvec1 = map_fun(data_fvec1);
|
||||||
data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
data_fvec0 = fVec::set(
|
||||||
|
data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
||||||
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
||||||
} else {
|
} else {
|
||||||
data_fvec0 = map_fun(data_fvec0);
|
data_fvec0 = map_fun(data_fvec0);
|
||||||
|
|
@ -228,18 +287,23 @@ inline float map_reduce_all(
|
||||||
data_fvec0 = map_fun(data_fvec0);
|
data_fvec0 = map_fun(data_fvec0);
|
||||||
data_fvec1 = map_fun(data_fvec1);
|
data_fvec1 = map_fun(data_fvec1);
|
||||||
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
||||||
acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
acc_fvec1 = fVec::set(
|
||||||
|
acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
||||||
} else {
|
} else {
|
||||||
data_fvec0 = map_fun(data_fvec0);
|
data_fvec0 = map_fun(data_fvec0);
|
||||||
acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
acc_fvec0 =
|
||||||
|
fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
||||||
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename MapOp, typename ReduceOp,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename MapOp,
|
||||||
|
typename ReduceOp,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline float map2_reduce_all(
|
inline float map2_reduce_all(
|
||||||
const MapOp& map_fun,
|
const MapOp& map_fun,
|
||||||
const ReduceOp& red_fun,
|
const ReduceOp& red_fun,
|
||||||
|
|
@ -256,7 +320,8 @@ inline float map2_reduce_all(
|
||||||
if (size > fVec::size()) {
|
if (size > fVec::size()) {
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
||||||
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
||||||
data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
data_fvec0 = fVec::set(
|
||||||
|
data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
||||||
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
||||||
} else {
|
} else {
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
||||||
|
|
@ -289,18 +354,23 @@ inline float map2_reduce_all(
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
||||||
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
||||||
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
||||||
acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
acc_fvec1 = fVec::set(
|
||||||
|
acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
||||||
} else {
|
} else {
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
||||||
acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
acc_fvec0 =
|
||||||
|
fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
||||||
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename MapOp, typename ReduceOp,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename MapOp,
|
||||||
|
typename ReduceOp,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline float map3_reduce_all(
|
inline float map3_reduce_all(
|
||||||
const MapOp& map_fun,
|
const MapOp& map_fun,
|
||||||
const ReduceOp& red_fun,
|
const ReduceOp& red_fun,
|
||||||
|
|
@ -320,7 +390,8 @@ inline float map3_reduce_all(
|
||||||
if (size > fVec::size()) {
|
if (size > fVec::size()) {
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
||||||
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
||||||
data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
data_fvec0 = fVec::set(
|
||||||
|
data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
||||||
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
||||||
} else {
|
} else {
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
||||||
|
|
@ -359,18 +430,22 @@ inline float map3_reduce_all(
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
||||||
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
||||||
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
||||||
acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
acc_fvec1 = fVec::set(
|
||||||
|
acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
||||||
} else {
|
} else {
|
||||||
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
||||||
acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
acc_fvec0 =
|
||||||
|
fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
||||||
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map(
|
inline void map(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -397,8 +472,10 @@ inline void map(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map(
|
inline void map(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -419,7 +496,8 @@ inline void map(
|
||||||
fVec data_fvec0, data_fvec1;
|
fVec data_fvec0, data_fvec1;
|
||||||
if (size - d > fVec::size()) {
|
if (size - d > fVec::size()) {
|
||||||
data_fvec0 = fVec::loadu(input_data + d);
|
data_fvec0 = fVec::loadu(input_data + d);
|
||||||
data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
|
data_fvec1 =
|
||||||
|
fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
|
||||||
} else {
|
} else {
|
||||||
// choose to align with behaviour of bVec::loadu(ptr, size),
|
// choose to align with behaviour of bVec::loadu(ptr, size),
|
||||||
// which leaves data_fvec1 uninitialized
|
// which leaves data_fvec1 uninitialized
|
||||||
|
|
@ -432,8 +510,10 @@ inline void map(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map2(
|
inline void map2(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -465,8 +545,10 @@ inline void map2(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map3(
|
inline void map3(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -503,8 +585,10 @@ inline void map3(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, typename Op,
|
template <
|
||||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
typename scalar_t,
|
||||||
|
typename Op,
|
||||||
|
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||||
inline void map4(
|
inline void map4(
|
||||||
const Op& vec_fun,
|
const Op& vec_fun,
|
||||||
scalar_t* output_data,
|
scalar_t* output_data,
|
||||||
|
|
@ -525,8 +609,10 @@ inline void map4(
|
||||||
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
||||||
bVec data4_bvec = bVec::loadu(input_data4 + d);
|
bVec data4_bvec = bVec::loadu(input_data4 + d);
|
||||||
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
|
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
|
||||||
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
|
fVec output_fvec0 =
|
||||||
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
|
vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
|
||||||
|
fVec output_fvec1 =
|
||||||
|
vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
|
||||||
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
||||||
output_bvec.store(output_data + d);
|
output_bvec.store(output_data + d);
|
||||||
}
|
}
|
||||||
|
|
@ -539,8 +625,10 @@ inline void map4(
|
||||||
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
||||||
bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
|
bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
|
||||||
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
|
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
|
||||||
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
|
fVec output_fvec0 =
|
||||||
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
|
vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
|
||||||
|
fVec output_fvec1 =
|
||||||
|
vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
|
||||||
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
||||||
output_bvec.store(output_data + d, size - d);
|
output_bvec.store(output_data + d, size - d);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,14 @@
|
||||||
/* Microsoft C/C++-compatible compiler */
|
/* Microsoft C/C++-compatible compiler */
|
||||||
#include <intrin.h>
|
#include <intrin.h>
|
||||||
#if _MSC_VER <= 1900
|
#if _MSC_VER <= 1900
|
||||||
#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
|
#define _mm256_extract_epi64(X, Y) \
|
||||||
#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
|
(_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
|
||||||
#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
|
#define _mm256_extract_epi32(X, Y) \
|
||||||
#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
|
(_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
|
||||||
|
#define _mm256_extract_epi16(X, Y) \
|
||||||
|
(_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
|
||||||
|
#define _mm256_extract_epi8(X, Y) \
|
||||||
|
(_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
|
||||||
#endif
|
#endif
|
||||||
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
|
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
|
||||||
/* GCC-compatible compiler, targeting ARM with NEON */
|
/* GCC-compatible compiler, targeting ARM with NEON */
|
||||||
|
|
@ -25,9 +29,9 @@
|
||||||
/* GCC-compatible compiler, targeting ARM with SVE */
|
/* GCC-compatible compiler, targeting ARM with SVE */
|
||||||
#include <arm_sve.h>
|
#include <arm_sve.h>
|
||||||
#endif
|
#endif
|
||||||
#if defined (MISSING_ARM_VLD1)
|
#if defined(MISSING_ARM_VLD1)
|
||||||
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
|
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
|
||||||
#elif defined (MISSING_ARM_VST1)
|
#elif defined(MISSING_ARM_VST1)
|
||||||
#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
|
#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
|
||||||
#endif
|
#endif
|
||||||
#elif defined(__GNUC__) && defined(__IWMMXT__)
|
#elif defined(__GNUC__) && defined(__IWMMXT__)
|
||||||
|
|
@ -36,8 +40,8 @@
|
||||||
#elif defined(__s390x__)
|
#elif defined(__s390x__)
|
||||||
// targets Z/architecture
|
// targets Z/architecture
|
||||||
// we will include vecintrin later
|
// we will include vecintrin later
|
||||||
#elif (defined(__GNUC__) || defined(__xlC__)) && \
|
#elif (defined(__GNUC__) || defined(__xlC__)) && \
|
||||||
(defined(__VEC__) || defined(__ALTIVEC__))
|
(defined(__VEC__) || defined(__ALTIVEC__))
|
||||||
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
|
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
|
||||||
#include <altivec.h>
|
#include <altivec.h>
|
||||||
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
|
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
|
||||||
|
|
|
||||||
|
|
@ -28,21 +28,30 @@ inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
|
inline Vectorized<bool> Vectorized<bool>::loadu(
|
||||||
|
const void* ptr,
|
||||||
|
int64_t count) {
|
||||||
// See NOTE [Loading boolean values]
|
// See NOTE [Loading boolean values]
|
||||||
return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
|
return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename VT>
|
template <typename VT>
|
||||||
struct VecHoldType { using hold_type = typename VT::value_type; };
|
struct VecHoldType {
|
||||||
|
using hold_type = typename VT::value_type;
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; };
|
struct VecHoldType<Vectorized<BFloat16>> {
|
||||||
|
using hold_type = BFloat16;
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct VecHoldType<Vectorized<Half>> {using hold_type = Half; };
|
struct VecHoldType<Vectorized<Half>> {
|
||||||
|
using hold_type = Half;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename VT>
|
template <typename VT>
|
||||||
using vechold_type = typename VecHoldType<VT>::hold_type;
|
using vechold_type = typename VecHoldType<VT>::hold_type;
|
||||||
|
|
||||||
}} // namespace at::vec::CPU_CAPABILITY
|
} // namespace CPU_CAPABILITY
|
||||||
|
} // namespace at::vec
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -28,8 +28,8 @@ struct VecConvert {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename dst_t, typename src_t>
|
template <typename dst_t, typename src_t>
|
||||||
inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>>
|
inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>> convert(
|
||||||
convert(const Vectorized<src_t>& src) {
|
const Vectorized<src_t>& src) {
|
||||||
return src;
|
return src;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,9 @@ static inline void transpose_pad_2x32_block(
|
||||||
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
TORCH_CHECK(false, "transpose_pad_2x32_block is only supported when avx512 is supported")
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"transpose_pad_2x32_block is only supported when avx512 is supported")
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -124,28 +126,31 @@ static inline void pack_vnni2(
|
||||||
for (; bk < _K; bk += 2) {
|
for (; bk < _K; bk += 2) {
|
||||||
int64_t bn = 0;
|
int64_t bn = 0;
|
||||||
for (; bn < _N; bn += 32) {
|
for (; bn < _N; bn += 32) {
|
||||||
transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
|
transpose_pad_2x32_block(
|
||||||
|
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
|
||||||
}
|
}
|
||||||
int64_t nrem = N - bn;
|
int64_t nrem = N - bn;
|
||||||
if (nrem > 0) {
|
if (nrem > 0) {
|
||||||
transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
|
transpose_pad_2x32_block(
|
||||||
|
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (K % 2 == 1) {
|
if (K % 2 == 1) {
|
||||||
int64_t bn = 0;
|
int64_t bn = 0;
|
||||||
for (; bn < _N; bn += 32) {
|
for (; bn < _N; bn += 32) {
|
||||||
transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
|
transpose_pad_2x32_block(
|
||||||
|
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
|
||||||
}
|
}
|
||||||
int64_t nrem = N - bn;
|
int64_t nrem = N - bn;
|
||||||
if (nrem > 0) {
|
if (nrem > 0) {
|
||||||
transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
|
transpose_pad_2x32_block(
|
||||||
|
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
|
TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace CPU_CAPABILITY
|
} // namespace CPU_CAPABILITY
|
||||||
} // namespace at::vec
|
} // namespace at::vec
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,12 @@ struct VecMaskTo {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void>
|
template <
|
||||||
|
typename dst_t,
|
||||||
|
int dst_n,
|
||||||
|
typename src_t,
|
||||||
|
int src_n,
|
||||||
|
typename Enabled = void>
|
||||||
struct VecMaskCast {
|
struct VecMaskCast {
|
||||||
static inline VecMask<dst_t, dst_n> apply(
|
static inline VecMask<dst_t, dst_n> apply(
|
||||||
const VecMask<src_t, src_n>& vec_mask) {
|
const VecMask<src_t, src_n>& vec_mask) {
|
||||||
|
|
@ -88,15 +93,17 @@ struct VecMaskCheck {
|
||||||
static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
|
static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
|
||||||
__at_align__ T mask[VectorizedN<T, N>::size()];
|
__at_align__ T mask[VectorizedN<T, N>::size()];
|
||||||
vec_mask.store(mask);
|
vec_mask.store(mask);
|
||||||
return std::all_of(
|
return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
|
||||||
mask, mask + VectorizedN<T, N>::size(), [](T m) { return m == static_cast<T>(0); });
|
return m == static_cast<T>(0);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
|
static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
|
||||||
__at_align__ T mask[VectorizedN<T, N>::size()];
|
__at_align__ T mask[VectorizedN<T, N>::size()];
|
||||||
vec_mask.store(mask);
|
vec_mask.store(mask);
|
||||||
return std::all_of(
|
return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
|
||||||
mask, mask + VectorizedN<T, N>::size(), [](T m) { return m != static_cast<T>(0); });
|
return m != static_cast<T>(0);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
|
static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
|
||||||
|
|
@ -159,13 +166,11 @@ class VecMask {
|
||||||
}
|
}
|
||||||
|
|
||||||
static VecMask<T, N> blendv(
|
static VecMask<T, N> blendv(
|
||||||
const VecMask<T, N>& c,
|
const VecMask<T, N>& c,
|
||||||
const VecMask<T, N>& b,
|
const VecMask<T, N>& b,
|
||||||
const VecMask<T, N>& a) {
|
const VecMask<T, N>& a) {
|
||||||
VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
|
VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
|
||||||
VectorizedN<T, N>(c),
|
VectorizedN<T, N>(c), VectorizedN<T, N>(b), VectorizedN<T, N>(a));
|
||||||
VectorizedN<T, N>(b),
|
|
||||||
VectorizedN<T, N>(a));
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -174,14 +179,14 @@ class VecMask {
|
||||||
const VecMask<T, N>& b,
|
const VecMask<T, N>& b,
|
||||||
int64_t count = size()) {
|
int64_t count = size()) {
|
||||||
VectorizedN<T, N> result = VectorizedN<T, N>::set(
|
VectorizedN<T, N> result = VectorizedN<T, N>::set(
|
||||||
VectorizedN<T, N>(a),
|
VectorizedN<T, N>(a), VectorizedN<T, N>(b), count);
|
||||||
VectorizedN<T, N>(b),
|
|
||||||
count);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void store(bool* b, int count = size()) {
|
void store(bool* b, int count = size()) {
|
||||||
constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1)/ Vectorized<bool>::size();
|
constexpr int L =
|
||||||
|
(VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1) /
|
||||||
|
Vectorized<bool>::size();
|
||||||
auto res = this->to<bool, L>();
|
auto res = this->to<bool, L>();
|
||||||
res.store(b, count);
|
res.store(b, count);
|
||||||
return;
|
return;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user