diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h index 0ad0c892b06..d1269db5724 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h @@ -5,6 +5,114 @@ namespace at::vec { inline namespace CPU_CAPABILITY { #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) + +// Enable auto-vectorization for GCC-13+ and clang-17+ +// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 +#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17)) + +template +inline void convertImpl( + const from_type* __restrict src, + to_type* __restrict dst, + int64_t n) { + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + dst[i] = static_cast(src[i]); + } +} + +#define CONVERT_TEMPLATE(from_type, to_type) \ + template <> \ + inline void convert(const from_type* src, to_type* dst, int64_t n) { \ + return convertImpl(src, dst, n); \ + } + +CONVERT_TEMPLATE(uint8_t, uint8_t) +CONVERT_TEMPLATE(uint8_t, int8_t) +CONVERT_TEMPLATE(uint8_t, int16_t) +CONVERT_TEMPLATE(uint8_t, int32_t) +CONVERT_TEMPLATE(uint8_t, int64_t) +CONVERT_TEMPLATE(uint8_t, float) +CONVERT_TEMPLATE(uint8_t, double) +CONVERT_TEMPLATE(int8_t, uint8_t) +CONVERT_TEMPLATE(int8_t, int8_t) +CONVERT_TEMPLATE(int8_t, int16_t) +CONVERT_TEMPLATE(int8_t, int32_t) +CONVERT_TEMPLATE(int8_t, int64_t) +CONVERT_TEMPLATE(int8_t, float) +CONVERT_TEMPLATE(int8_t, double) +CONVERT_TEMPLATE(int16_t, uint8_t) +CONVERT_TEMPLATE(int16_t, int8_t) +CONVERT_TEMPLATE(int16_t, int16_t) +CONVERT_TEMPLATE(int16_t, int32_t) +CONVERT_TEMPLATE(int16_t, int64_t) +CONVERT_TEMPLATE(int16_t, float) +CONVERT_TEMPLATE(int16_t, double) +CONVERT_TEMPLATE(int32_t, uint8_t) +CONVERT_TEMPLATE(int32_t, int8_t) +CONVERT_TEMPLATE(int32_t, int16_t) +CONVERT_TEMPLATE(int32_t, int32_t) +CONVERT_TEMPLATE(int32_t, int64_t) +CONVERT_TEMPLATE(int32_t, float) +CONVERT_TEMPLATE(int32_t, double) +CONVERT_TEMPLATE(int64_t, uint8_t) +CONVERT_TEMPLATE(int64_t, int8_t) +CONVERT_TEMPLATE(int64_t, int16_t) +CONVERT_TEMPLATE(int64_t, int32_t) +CONVERT_TEMPLATE(int64_t, int64_t) +CONVERT_TEMPLATE(int64_t, float) +CONVERT_TEMPLATE(int64_t, double) +CONVERT_TEMPLATE(float, uint8_t) +CONVERT_TEMPLATE(float, int8_t) +CONVERT_TEMPLATE(float, int16_t) +CONVERT_TEMPLATE(float, int32_t) +CONVERT_TEMPLATE(float, int64_t) +CONVERT_TEMPLATE(float, float) +CONVERT_TEMPLATE(float, double) +CONVERT_TEMPLATE(double, uint8_t) +CONVERT_TEMPLATE(double, int8_t) +CONVERT_TEMPLATE(double, int16_t) +CONVERT_TEMPLATE(double, int32_t) +CONVERT_TEMPLATE(double, int64_t) +CONVERT_TEMPLATE(double, float) +CONVERT_TEMPLATE(double, double) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +CONVERT_TEMPLATE(float16_t, uint8_t) +CONVERT_TEMPLATE(float16_t, int8_t) +CONVERT_TEMPLATE(float16_t, int16_t) +CONVERT_TEMPLATE(float16_t, int32_t) +CONVERT_TEMPLATE(float16_t, int64_t) +CONVERT_TEMPLATE(float16_t, float16_t) +CONVERT_TEMPLATE(float16_t, float) +CONVERT_TEMPLATE(float16_t, double) +CONVERT_TEMPLATE(uint8_t, float16_t) +CONVERT_TEMPLATE(int8_t, float16_t) +CONVERT_TEMPLATE(int16_t, float16_t) +CONVERT_TEMPLATE(int32_t, float16_t) +CONVERT_TEMPLATE(int64_t, float16_t) +CONVERT_TEMPLATE(float, float16_t) +CONVERT_TEMPLATE(double, float16_t) +#endif +#ifdef __ARM_FEATURE_BF16 +CONVERT_TEMPLATE(bfloat16_t, uint8_t) +CONVERT_TEMPLATE(bfloat16_t, int8_t) +CONVERT_TEMPLATE(bfloat16_t, int16_t) +CONVERT_TEMPLATE(bfloat16_t, int32_t) +CONVERT_TEMPLATE(bfloat16_t, int64_t) +CONVERT_TEMPLATE(bfloat16_t, bfloat16_t) +CONVERT_TEMPLATE(bfloat16_t, float) +CONVERT_TEMPLATE(bfloat16_t, double) +CONVERT_TEMPLATE(uint8_t, bfloat16_t) +CONVERT_TEMPLATE(int8_t, bfloat16_t) +CONVERT_TEMPLATE(int16_t, bfloat16_t) +CONVERT_TEMPLATE(int32_t, bfloat16_t) +CONVERT_TEMPLATE(int64_t, bfloat16_t) +CONVERT_TEMPLATE(float, bfloat16_t) +CONVERT_TEMPLATE(double, bfloat16_t) +#endif + +#endif + template struct VecConvert< float, diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h index c6c34222c5c..da4599ab49f 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -540,42 +540,6 @@ inline Vectorized Vectorized::le( return (*this <= other) & Vectorized(1.0f); } -template <> -inline void convert(const float* src, int32_t* dst, int64_t n) { - int64_t i; -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (i = 0; i <= (n - Vectorized::size()); - i += Vectorized::size()) { - vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); - } -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (; i < n; i++) { - dst[i] = static_cast(src[i]); - } -} - -template <> -inline void convert(const int32_t* src, float* dst, int64_t n) { - int64_t i; -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (i = 0; i <= (n - Vectorized::size()); - i += Vectorized::size()) { - vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); - } -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (; i < n; i++) { - dst[i] = static_cast(src[i]); - } -} - template <> Vectorized inline fmadd( const Vectorized& a, diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h index ab4a5a89cba..e6f1fb88d04 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -569,46 +569,6 @@ inline Vectorized Vectorized::le( return (*this <= other) & Vectorized(1); } -// These are global functions, so the defaults in vec_base.h should -// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available. -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -inline void convert(const float16_t* src, int16_t* dst, int64_t n) { - int64_t i; -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (i = 0; i <= (n - Vectorized::size()); - i += Vectorized::size()) { - vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); - } -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (; i < n; i++) { - dst[i] = static_cast(src[i]); - } -} - -template <> -inline void convert(const int16_t* src, float16_t* dst, int64_t n) { - int64_t i; -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (i = 0; i <= (n - Vectorized::size()); - i += Vectorized::size()) { - vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); - } -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (; i < n; i++) { - dst[i] = static_cast(src[i]); - } -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - template <> Vectorized inline fmadd( const Vectorized& a,