Enable x86 CPU vectorization on windows [submodule sleef] (#118980)

Enable VEC on Windows OS.
1. Fix some type defination gap between Windows and Linux.
2. Fix some operator not support on Windows, such as [], /.
3. Enable static sleef library build on Windows.
4. Disable unsupported function overloading on MSVC.
5. Upgrade submodule sleef lib, which fixed build issue on Windows.
6. Fixed bazel build issues.
7. Fix test app not link to sleef on Windows.

Note: If rebuild fail after pulled this PR, please sync `sleef` submodule by run:
```cmd
git submodule sync
git submodule update --init --recursive
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118980
Approved by: https://github.com/jgong5, https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
Xu Han 2024-03-31 03:07:32 +00:00 committed by PyTorch MergeBot
parent 2b1ba0ceae
commit 56451cd49d
19 changed files with 195 additions and 94 deletions

View File

@ -419,16 +419,8 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
endif()
if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
# Preserve values for the main build
set(__aten_sleef_build_shared_libs ${BUILD_SHARED_LIBS})
set(__aten_sleef_build_tests ${BUILD_TESTS})
# Unset our restrictive C++ flags here and reset them later.
# Remove this once we use proper target_compile_options.
set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(CMAKE_CXX_FLAGS)
if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
if(NOT MSVC)
# Bump up optimization level for sleef to -O1, since at -O0 the compiler
# excessively spills intermediate vector registers to the stack
# and makes things run impossibly slowly
@ -438,13 +430,14 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
else()
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
endif()
endif()
if(NOT USE_SYSTEM_SLEEF)
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
set(BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(SLEEF_BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
@ -465,12 +458,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()
list(APPEND ATen_CPU_DEPENDENCY_LIBS sleef)
if(NOT MSVC)
set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})
# Set these back. TODO: Use SLEEF_ to pass these instead
set(BUILD_SHARED_LIBS ${__aten_sleef_build_shared_libs} CACHE BOOL "Build shared libs" FORCE)
set(BUILD_TESTS ${__aten_sleef_build_tests} CACHE BOOL "Build tests" FORCE)
endif()
endif()
if(USE_CUDA AND NOT USE_ROCM)

View File

@ -72,7 +72,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -97,7 +97,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
@ -109,9 +110,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm256_i32gather_ps(base_addr, vindex, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
@ -125,7 +127,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Only works for inputs in the range: [-2^51, 2^51]
@ -305,6 +307,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
return flip8(v);
}
#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // (defined(CPU_CAPABILITY_AVX2)
}} // namepsace at::vec::CPU_CAPABILITY

View File

@ -7,7 +7,8 @@
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -18,7 +19,18 @@ namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#ifndef SLEEF_CONST
#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
#define SLEEF_CONST const
#else
#define SLEEF_CONST
#endif
#define SLEEF_CONST_OLD SLEEF_CONST
#else
#define SLEEF_CONST_OLD
#endif
// bfloat16 conversion
static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
@ -292,7 +304,8 @@ public:
}
return b;
}
Vectorized<T> map(const __m256 (*const vop)(__m256)) const {
Vectorized<T> map(SLEEF_CONST __m256 (*SLEEF_CONST_OLD vop)(__m256)) const {
__m256 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
const auto o1 = vop(lo);
@ -1053,7 +1066,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#else // defined(CPU_CAPABILITY_AVX2)
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
@ -1106,9 +1119,9 @@ inline Vectorized<Half> convert_float_half(const Vectorized<float>& a, const Vec
CONVERT_NON_VECTORIZED_INIT(Half, half);
#endif
#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // defined(CPU_CAPABILITY_AVX2)
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data)); \
@ -1127,7 +1140,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_VECTORIZED_INIT(Half, fp16);
#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#else // defined(CPU_CAPABILITY_AVX2)
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
__at_align__ float values[Vectorized<float>::size()]; \

View File

@ -8,7 +8,8 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -16,7 +17,7 @@ namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<c10::complex<double>> {
private:
@ -145,7 +146,7 @@ public:
auto abs = abs_();
auto zero = _mm256_setzero_pd();
auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
auto div = values / abs;
auto div = _mm256_div_pd(values, abs);
return _mm256_blendv_pd(div, zero, mask);
}
__m256d real_() const {

View File

@ -7,7 +7,8 @@
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -15,7 +16,7 @@ namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<c10::complex<float>> {
private:
@ -180,7 +181,7 @@ public:
auto abs = abs_();
auto zero = _mm256_setzero_ps();
auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
auto div = values / abs;
auto div = _mm256_div_ps(values, abs);
return _mm256_blendv_ps(div, zero, mask);
}
__m256 real_() const {

View File

@ -6,7 +6,8 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -15,7 +16,7 @@ namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<double> {
private:

View File

@ -6,7 +6,8 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -14,7 +15,7 @@ namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<float> {
private:
@ -226,14 +227,14 @@ public:
static __m256 vec_factorial_5 =
_mm256_set1_ps(0.00828929059f); // 1/factorial(5)
static __m256 vec_exp_log2ef =
(__m256)_mm256_set1_epi32(0x3fb8aa3b); // log2(e)
_mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e)
static __m256 vec_half = _mm256_set1_ps(0.5f);
static __m256 vec_one = _mm256_set1_ps(1.f);
static __m256 vec_zero = _mm256_set1_ps(0.f);
static __m256 vec_two = _mm256_set1_ps(2.f);
static __m256 vec_ln2f = (__m256)_mm256_set1_epi32(0x3f317218); // ln(2)
static __m256 vec_ln_flt_min = (__m256)_mm256_set1_epi32(0xc2aeac50);
static __m256 vec_ln_flt_max = (__m256)_mm256_set1_epi32(0x42b17218);
static __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2)
static __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50));
static __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218));
static __m256i vec_127 = _mm256_set1_epi32(0x0000007f);
static int n_mantissa_bits = 23;
@ -266,7 +267,7 @@ public:
auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number);
auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127);
vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
auto vec_two_pow_n = (__m256)vec_two_pow_n_i;
auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i);
vec_two_pow_n =
_mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask);

View File

@ -41,11 +41,17 @@
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#ifdef _MSC_VER
__declspec(align(64)) struct Vectorizedqi {
protected:
__m256i vals;
#else
struct Vectorizedqi {
protected:
__m256i vals __attribute__((aligned(64)));
#endif
public:
Vectorizedqi() {}
@ -133,7 +139,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
}
template <typename T>
inline void __attribute__((always_inline)) QuantizeAvx2(
__FORCE_INLINE void QuantizeAvx2(
const float* src,
T* dst,
int len,
@ -1331,5 +1337,5 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
return a.maximum(b);
}
#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // if defined(CPU_CAPABILITY_AVX2)
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -57,7 +57,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -82,7 +82,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
@ -94,9 +95,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm512_i32gather_ps(vindex, base_addr, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
@ -114,7 +116,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
@ -272,6 +274,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
return flip8(v);
}
#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#endif // defined(CPU_CAPABILITY_AVX512)
}}}

View File

@ -7,7 +7,8 @@
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -16,7 +17,18 @@ namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#ifndef SLEEF_CONST
#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
#define SLEEF_CONST const
#else
#define SLEEF_CONST
#endif
#define SLEEF_CONST_OLD SLEEF_CONST
#else
#define SLEEF_CONST_OLD
#endif
// bfloat16 conversion
static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
@ -367,7 +379,8 @@ public:
}
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wignored-qualifiers"
Vectorized<T> map(const __m512 (*const vop)(__m512)) const {
Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
const auto o1 = vop(lo);
@ -1576,7 +1589,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#else //defined(CPU_CAPABILITY_AVX512)
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
@ -1606,9 +1619,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(Half, half);
#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#endif // defined(CPU_CAPABILITY_AVX512)
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
@ -1627,7 +1640,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_VECTORIZED_INIT(Half, fp16);
#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#else // defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
__at_align__ float values[Vectorized<float>::size()]; \

View File

@ -7,7 +7,8 @@
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -16,7 +17,7 @@ namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<c10::complex<double>> {
private:
@ -203,7 +204,7 @@ public:
auto abs = abs_();
auto zero = _mm512_setzero_pd();
auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
auto div = values / abs;
auto div = _mm512_div_pd(values, abs);
return _mm512_mask_blend_pd(mask, div, zero);
}
__m512d real_() const {

View File

@ -7,7 +7,8 @@
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -16,7 +17,7 @@ namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<c10::complex<float>> {
private:
@ -708,7 +709,7 @@ public:
auto abs = abs_();
auto zero = _mm512_setzero_ps();
auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ);
auto div = values / abs;
auto div = _mm512_div_ps(values, abs);
return _mm512_mask_blend_ps(mask, div, zero);
}
__m512 real_() const {

View File

@ -6,7 +6,8 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
#if (defined(CPU_CAPABILITY_AVX512))
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -15,7 +16,7 @@ namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<double> {
private:

View File

@ -6,7 +6,8 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
@ -15,7 +16,7 @@ namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<float> {
private:
@ -246,14 +247,14 @@ public:
static __m512 vec_factorial_5 =
_mm512_set1_ps(0.00828929059f); // 1/factorial(5)
static __m512 vec_exp_log2ef =
(__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e)
static __m512 vec_half = _mm512_set1_ps(0.5f);
static __m512 vec_one = _mm512_set1_ps(1.f);
static __m512 vec_zero = _mm512_set1_ps(0.f);
static __m512 vec_two = _mm512_set1_ps(2.f);
static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
static __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2)
static __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
static __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
static int n_mantissa_bits = 23;
@ -288,7 +289,7 @@ public:
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number);
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127);
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
auto vec_two_pow_n = (__m512)vec_two_pow_n_i;
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i);
vec_two_pow_n =
_mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero);

View File

@ -42,11 +42,17 @@ namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)
#ifdef _MSC_VER
__declspec(align(64)) struct Vectorizedqi {
protected:
__m512i vals;
#else
struct Vectorizedqi {
protected:
__m512i vals __attribute__((aligned(64)));
#endif
public:
Vectorizedqi() {}
@ -136,7 +142,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
}
template <typename T>
inline void __attribute__((always_inline)) QuantizeAvx512(
__FORCE_INLINE void QuantizeAvx512(
const float* src,
T* dst,
int len,
@ -525,10 +531,17 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
Vectorized<float> scale,
Vectorized<float> zero_point,
Vectorized<float> scale_neg_zp_premul) const {
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
__m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
__m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
#else
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
#endif
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
@ -549,10 +562,17 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
__m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
__m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
#else
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
#endif
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
@ -598,20 +618,34 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
}
int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
__m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
__m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
#else
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
#endif
__m512i int32_val0 = cvtepi8_epi32(int_val0);
__m512i int32_val1 = cvtepi8_epi32(int_val1);
__m512i int32_val2 = cvtepi8_epi32(int_val2);
__m512i int32_val3 = cvtepi8_epi32(int_val3);
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]);
__m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]);
__m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]);
__m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]);
#else
__m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
__m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
__m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
__m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
#endif
__m512i int32_b0 = cvtepi8_epi32(int_b0);
__m512i int32_b1 = cvtepi8_epi32(int_b1);
@ -721,10 +755,17 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
Vectorized<float> scale,
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
__m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
__m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
#else
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
#endif
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
@ -746,10 +787,17 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
__m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
__m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
#else
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
#endif
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
@ -796,20 +844,34 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
}
int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
__m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
__m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
#else
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
#endif
__m512i int32_val0 = cvtepu8_epi32(int_val0);
__m512i int32_val1 = cvtepu8_epi32(int_val1);
__m512i int32_val2 = cvtepu8_epi32(int_val2);
__m512i int32_val3 = cvtepu8_epi32(int_val3);
#if defined(_MSC_VER) && !defined(__clang__)
__m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]);
__m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]);
__m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]);
__m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]);
#else
__m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
__m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
__m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
__m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
#endif
__m512i int32_b0 = cvtepu8_epi32(int_b0);
__m512i int32_b1 = cvtepu8_epi32(int_b1);

View File

@ -36,6 +36,12 @@
#include <c10/util/irange.h>
#include <c10/util/Load.h>
#if defined(__GNUC__)
#define __FORCE_INLINE __attribute__((always_inline)) inline
#elif defined(_MSC_VER)
#define __FORCE_INLINE __forceinline
#endif
// These macros helped us unify vec_base.h
#ifdef CPU_CAPABILITY_AVX512
#if defined(__GNUC__)

View File

@ -1837,7 +1837,7 @@ if(BUILD_TEST)
endif()
else()
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main)
target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library sleef gtest_main)
endif()
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<INSTALL_INTERFACE:include>)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)

2
third_party/sleef vendored

@ -1 +1 @@
Subproject commit e0a003ee838b75d11763aa9c3ef17bf71a725bff
Subproject commit 60e76d2bce17d278b439d9da17177c8f957a9e9b

View File

@ -38,6 +38,7 @@ SLEEF_PUBLIC_HEADERS = [
SLEEF_PRIVATE_INCLUDES = [
"-Iexternal/sleef/src/arch",
"-Iexternal/sleef/src/common",
"-Iexternal/sleef/src/libm",
]
SLEEF_PUBLIC_INCLUDES = [
@ -201,8 +202,6 @@ cc_library(
srcs = [
"src/libm/rempitab.c",
"src/libm/sleefdp.c",
"src/libm/sleefld.c",
"src/libm/sleefqp.c",
"src/libm/sleefsp.c",
],
hdrs = SLEEF_PUBLIC_HEADERS,