mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Pytorch] Enable autovec on aarch64 for type conversion (#166049)
Summary: Implementing autovec template for type conversions on aarch64-NEON Generated code can be seen here: https://godbolt.org/z/1K6T1d9TE We've seen significant performance improvements for converting to and from bytes, compiling using clang with -march=armv9-a+sve2: Before float->uint8->float ===> 683.212us float->int8->float ===> 687.846us int32->uint8->int32 ===> 497.121us int32->int8->int32 ===> 481.889us After: float->uint8->float ===> 198.204us ----> 245% higher throughput float->int8->float ===> 200.241us ----> 244% higher throughput int32->uint8->int32 ===> 197.970us ----> 151% higher throughput int32->int8->int32 ===> 198.206us ----> 143% higher throughput Test Plan: buck2 test mode/opt //caffe2/test:test_ops buck2 test mode/opt //caffe2/test:torch Differential Revision: D85213420 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166049 Approved by: https://github.com/ezyang, https://github.com/mcfi, https://github.com/aditew01
This commit is contained in:
parent
2efcf3ca98
commit
b31bad1b8f
|
|
@ -5,6 +5,114 @@
|
|||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
|
||||
// Enable auto-vectorization for GCC-13+ and clang-17+
|
||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
|
||||
|
||||
template <typename from_type, typename to_type>
|
||||
inline void convertImpl(
|
||||
const from_type* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
dst[i] = static_cast<to_type>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONVERT_TEMPLATE(from_type, to_type) \
|
||||
template <> \
|
||||
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
|
||||
return convertImpl<from_type, to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int8_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int16_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int32_t)
|
||||
CONVERT_TEMPLATE(uint8_t, int64_t)
|
||||
CONVERT_TEMPLATE(uint8_t, float)
|
||||
CONVERT_TEMPLATE(uint8_t, double)
|
||||
CONVERT_TEMPLATE(int8_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int8_t)
|
||||
CONVERT_TEMPLATE(int8_t, int16_t)
|
||||
CONVERT_TEMPLATE(int8_t, int32_t)
|
||||
CONVERT_TEMPLATE(int8_t, int64_t)
|
||||
CONVERT_TEMPLATE(int8_t, float)
|
||||
CONVERT_TEMPLATE(int8_t, double)
|
||||
CONVERT_TEMPLATE(int16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int8_t)
|
||||
CONVERT_TEMPLATE(int16_t, int16_t)
|
||||
CONVERT_TEMPLATE(int16_t, int32_t)
|
||||
CONVERT_TEMPLATE(int16_t, int64_t)
|
||||
CONVERT_TEMPLATE(int16_t, float)
|
||||
CONVERT_TEMPLATE(int16_t, double)
|
||||
CONVERT_TEMPLATE(int32_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int8_t)
|
||||
CONVERT_TEMPLATE(int32_t, int16_t)
|
||||
CONVERT_TEMPLATE(int32_t, int32_t)
|
||||
CONVERT_TEMPLATE(int32_t, int64_t)
|
||||
CONVERT_TEMPLATE(int32_t, float)
|
||||
CONVERT_TEMPLATE(int32_t, double)
|
||||
CONVERT_TEMPLATE(int64_t, uint8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int8_t)
|
||||
CONVERT_TEMPLATE(int64_t, int16_t)
|
||||
CONVERT_TEMPLATE(int64_t, int32_t)
|
||||
CONVERT_TEMPLATE(int64_t, int64_t)
|
||||
CONVERT_TEMPLATE(int64_t, float)
|
||||
CONVERT_TEMPLATE(int64_t, double)
|
||||
CONVERT_TEMPLATE(float, uint8_t)
|
||||
CONVERT_TEMPLATE(float, int8_t)
|
||||
CONVERT_TEMPLATE(float, int16_t)
|
||||
CONVERT_TEMPLATE(float, int32_t)
|
||||
CONVERT_TEMPLATE(float, int64_t)
|
||||
CONVERT_TEMPLATE(float, float)
|
||||
CONVERT_TEMPLATE(float, double)
|
||||
CONVERT_TEMPLATE(double, uint8_t)
|
||||
CONVERT_TEMPLATE(double, int8_t)
|
||||
CONVERT_TEMPLATE(double, int16_t)
|
||||
CONVERT_TEMPLATE(double, int32_t)
|
||||
CONVERT_TEMPLATE(double, int64_t)
|
||||
CONVERT_TEMPLATE(double, float)
|
||||
CONVERT_TEMPLATE(double, double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_TEMPLATE(float16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int8_t)
|
||||
CONVERT_TEMPLATE(float16_t, int16_t)
|
||||
CONVERT_TEMPLATE(float16_t, int32_t)
|
||||
CONVERT_TEMPLATE(float16_t, int64_t)
|
||||
CONVERT_TEMPLATE(float16_t, float16_t)
|
||||
CONVERT_TEMPLATE(float16_t, float)
|
||||
CONVERT_TEMPLATE(float16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int8_t, float16_t)
|
||||
CONVERT_TEMPLATE(int16_t, float16_t)
|
||||
CONVERT_TEMPLATE(int32_t, float16_t)
|
||||
CONVERT_TEMPLATE(int64_t, float16_t)
|
||||
CONVERT_TEMPLATE(float, float16_t)
|
||||
CONVERT_TEMPLATE(double, float16_t)
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
|
|
|||
|
|
@ -540,42 +540,6 @@ inline Vectorized<float> Vectorized<float>::le(
|
|||
return (*this <= other) & Vectorized<float>(1.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const float* src, int32_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int32_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int32_t* src, float* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||
i += Vectorized<float>::size()) {
|
||||
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<float> inline fmadd(
|
||||
const Vectorized<float>& a,
|
||||
|
|
|
|||
|
|
@ -569,46 +569,6 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
|
|||
return (*this <= other) & Vectorized<c10::Half>(1);
|
||||
}
|
||||
|
||||
// These are global functions, so the defaults in vec_base.h should
|
||||
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
template <>
|
||||
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<int16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
|
||||
int64_t i;
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||
i += Vectorized<c10::Half>::size()) {
|
||||
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
|
||||
}
|
||||
#ifndef __msvc_cl__
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (; i < n; i++) {
|
||||
dst[i] = static_cast<float16_t>(src[i]);
|
||||
}
|
||||
}
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <>
|
||||
Vectorized<c10::Half> inline fmadd(
|
||||
const Vectorized<c10::Half>& a,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user