Add AVX512 support in ATen & remove AVX support (#61903)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61903

### Remaining Tasks

- [ ] Collate results of benchmarks on two Intel Xeon machines (with & without CUDA, to check if CPU throttling causes issues with GPUs) - make graphs, including Roofline model plots (Intel Advisor can't make them with libgomp, though, but with Intel OpenMP).

### Summary

1. This draft PR produces binaries with with 3 types of ATen kernels - default, AVX2, AVX512 . Using the environment variable `ATEN_AVX512_256=TRUE`  also results in 3 types of kernels, but the compiler can use 32 ymm registers for AVX2, instead of the default 16. ATen kernels for `CPU_CAPABILITY_AVX` have been removed.

2. `nansum` is not using AVX512 kernel right now, as it has poorer accuracy for Float16, than does AVX2 or DEFAULT, whose respective accuracies aren't very good either (#59415).
It was more convenient to disable AVX512 dispatch for all dtypes of `nansum` for now.

3. On Windows , ATen Quantized AVX512 kernels are not being used, as quantization tests are flaky. If `--continue-through-failure` is used, then `test_compare_model_outputs_functional_static` fails. But if this test is skipped, `test_compare_model_outputs_conv_static` fails. If both these tests are skipped, then a third one fails. These are hard to debug right now due to not having access to a Windows machine with AVX512 support, so it was more convenient to disable AVX512 dispatch of all ATen Quantized kernels on Windows for now.

4. One test is currently being skipped -
[test_lstm` in `quantization.bc](https://github.com/pytorch/pytorch/issues/59098) - It fails only on Cascade Lake machines, irrespective of the `ATEN_CPU_CAPABILITY` used, because FBGEMM uses `AVX512_VNNI` on machines that support it. The value of `reduce_range` should be used as `False` on such machines.

The list of the changes is at https://gist.github.com/imaginary-person/4b4fda660534f0493bf9573d511a878d.

Credits to ezyang for proposing `AVX512_256` - these use AVX2 intrinsics but benefit from 32 registers, instead of the 16 ymm registers that AVX2 uses.
Credits to limo1996 for the initial proposal, and for optimizing `hsub_pd` & `hadd_pd`, which didn't have direct AVX512 equivalents, and are being used in some kernels. He also refactored `vec/functional.h` to remove duplicated code.
Credits to quickwritereader for helping fix 4 failing complex multiplication & division tests.

### Testing
1. `vec_test_all_types` was modified to test basic AVX512 support, as tests already existed for AVX2.
Only one test had to be modified, as it was hardcoded for AVX2.
2.  `pytorch_linux_bionic_py3_8_gcc9_coverage_test1` & `pytorch_linux_bionic_py3_8_gcc9_coverage_test2` are now using `linux.2xlarge` instances, as they support AVX512. They were used for testing AVX512 kernels, as AVX512 kernels are being used by default in both of the CI checks. Windows CI checks had already been using machines with AVX512 support.

### Would the downclocking caused by AVX512 pose an issue?

I think it's important to note that AVX2 causes downclocking as well, and the additional downclocking caused by AVX512 may not hamper performance on some Skylake machines & beyond, because of the double vector-size. I think that [this post with verifiable references is a must-read](https://community.intel.com/t5/Software-Tuning-Performance/Unexpected-power-vs-cores-profile-for-MKL-kernels-on-modern-Xeon/m-p/1133869/highlight/true#M6450). Also, AVX512 would _probably not_ hurt performance on a high-end machine, [but measurements are recommended](https://lemire.me/blog/2018/09/07/avx-512-when-and-how-to-use-these-new-instructions/). In case it does, `ATEN_AVX512_256=TRUE` can be used for building PyTorch, as AVX2 can then use 32 ymm registers instead of the default 16. [FBGEMM uses `AVX512_256` only on Xeon D processors](https://github.com/pytorch/FBGEMM/pull/209), which are said to have poor AVX512 performance.

This [official data](https://www.intel.com/content/dam/www/public/us/en/documents/specification-updates/xeon-scalable-spec-update.pdf) is for the Intel Skylake family, and the first link helps understand its significance. Cascade Lake & Ice Lake SP Xeon processors are said to be even better when it comes to AVX512 performance.

Here is the corresponding data for [Cascade Lake](https://cdrdv2.intel.com/v1/dl/getContent/338848) -

![CASCADE LAKE AVX2](https://user-images.githubusercontent.com/76181208/120666172-ffec3f80-c451-11eb-8ea1-8933ccc12a1b.PNG)
![CASCADE LAKE AVX512](https://user-images.githubusercontent.com/76181208/120666190-04b0f380-c452-11eb-9faa-38d233c874c8.PNG)

The corresponding data isn't publicly available for Intel Xeon SP 3rd gen (Ice Lake SP), but [Intel mentioned that the 3rd gen has frequency improvements pertaining to AVX512](https://newsroom.intel.com/wp-content/uploads/sites/11/2021/04/3rd-Gen-Intel-Xeon-Scalable-Platform-Press-Presentation-281884.pdf). Ice Lake SP machines also have 48 KB L1D caches, so that's another reason for AVX512 performance to be better on them.

### Is PyTorch always faster with AVX512?

No, but then PyTorch is not always faster with AVX2 either. Please refer to #60202. The benefit from vectorization is apparent with with small tensors that fit in caches or in kernels that are more compute heavy. For instance, AVX512 or AVX2 would yield no benefit for adding two 64 MB tensors, but adding two 1 MB tensors would do well with AVX2, and even more so with AVX512.

It seems that memory-bound computations, such as adding two 64 MB tensors can be slow with vectorization (depending upon the number of threads used), as the effects of downclocking can then be observed.

Original pull request: https://github.com/pytorch/pytorch/pull/56992

Reviewed By: soulitzer

Differential Revision: D29266289

Pulled By: ezyang

fbshipit-source-id: 2d5e8d1c2307252f22423bbc14f136c67c3e6184
This commit is contained in:
imaginary-person 2021-07-22 08:49:55 -07:00 committed by Facebook GitHub Bot
parent 59d6e07ada
commit 9e53c823b8
63 changed files with 6772 additions and 971 deletions

View File

@ -132,7 +132,9 @@ fi
if [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX-* || $TEST_CONFIG == 'nogpu_NO_AVX' ]]; then if [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX-* || $TEST_CONFIG == 'nogpu_NO_AVX' ]]; then
export ATEN_CPU_CAPABILITY=default export ATEN_CPU_CAPABILITY=default
elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* || $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* || $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then
export ATEN_CPU_CAPABILITY=avx export ATEN_CPU_CAPABILITY=default
elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX512-* || $TEST_CONFIG == 'nogpu_NO_AVX512' ]]; then
export ATEN_CPU_CAPABILITY=avx2
fi fi
if [ -n "$IN_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then if [ -n "$IN_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then

View File

@ -1,9 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_cc//cc:defs.bzl", "cc_library")
CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX", "AVX2"] CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]
CAPABILITY_COMPILER_FLAGS = { CAPABILITY_COMPILER_FLAGS = {
"AVX2": ["-mavx2", "-mfma"], "AVX2": ["-mavx2", "-mfma"],
"AVX": ["-mavx"],
"DEFAULT": [], "DEFAULT": [],
} }

View File

@ -50,7 +50,7 @@ if(NOT BUILD_LITE_INTERPRETER)
endif() endif()
EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS}) EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h") file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h")
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp")
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")

View File

@ -108,12 +108,12 @@ std::string used_cpu_capability() {
case native::CPUCapability::DEFAULT: case native::CPUCapability::DEFAULT:
ss << "NO AVX"; ss << "NO AVX";
break; break;
case native::CPUCapability::AVX:
ss << "AVX";
break;
case native::CPUCapability::AVX2: case native::CPUCapability::AVX2:
ss << "AVX2"; ss << "AVX2";
break; break;
case native::CPUCapability::AVX512:
ss << "AVX512";
break;
#endif #endif
default: default:
break; break;

View File

@ -1,6 +1,5 @@
#include <ATen/cpu/FlushDenormal.h> #include <ATen/cpu/FlushDenormal.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/intrinsics.h>
#include <cpuinfo.h> #include <cpuinfo.h>
namespace at { namespace cpu { namespace at { namespace cpu {

View File

@ -1 +1,6 @@
#include <ATen/cpu/vec/vec256/functional.h> #pragma once
#include <ATen/cpu/vec/functional_base.h>
#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/functional_bfloat16.h>
#endif

View File

@ -3,7 +3,7 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/vec256.h> #include <ATen/cpu/vec/vec.h>
namespace at { namespace vec { namespace at { namespace vec {

View File

@ -3,7 +3,7 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/functional_base.h> #include <ATen/cpu/vec/vec.h>
namespace at { namespace vec { namespace at { namespace vec {
@ -15,26 +15,26 @@ template <> struct VecScalarType<BFloat16> { using type = float; };
template <typename scalar_t> template <typename scalar_t>
using vec_scalar_t = typename VecScalarType<scalar_t>::type; using vec_scalar_t = typename VecScalarType<scalar_t>::type;
// Note that we already have specializes member of Vectorized<scalar_t> for BFloat16 // Note that we already have specialized member of Vectorized<scalar_t> for BFloat16
// so the following function would run smoothly: // so the following functions would run smoothly:
// using Vec = Vectorized<BFloat16>; // using Vec = Vectorized<BFloat16>;
// Vec one = Vec(BFloat16(1)); // Vec one = Vec(BFloat16(1));
// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
// //
// Why we still need to specializes "funtional"? // Then why we still need to specialize "funtional"?
// If we do specialization at Vectorized<> level, the above example would need 3 pairs of // If we do specialization at Vectorized<> level, the above example would need 3 pairs of
// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". // conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/".
// If we do specialization at vec::map<>() level, we have only 1 pair of conversion // If we do specialization at vec::map<>() level, we have only 1 pair of conversion
// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. // of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only.
// //
// The following BFloat16 functionalities will only do data type conversion for input // The following BFloat16 functionality will only do data type conversion for input
// and output vector (reduce functionalities will only convert the final scalar back to bf16). // and output vector (reduce functionality will only convert the final scalar back to bf16).
// Compared to Vectorized<> specialization, // Compared to Vectorized<> specialization,
// 1. better performance since we have less data type conversion; // 1. better performance since we have less data type conversion;
// 2. less rounding error since immediate results are kept in fp32; // 2. less rounding error since immediate results are kept in fp32;
// 3. accumulation done on data type of fp32. // 3. accumulation done on data type of fp32.
// //
// If you plan to extend this file, make sure add unit test at // If you plan to extend this file, please ensure adding unit tests at
// aten/src/ATen/test/vec_test_all_types.cpp // aten/src/ATen/test/vec_test_all_types.cpp
// //
template <typename scalar_t = BFloat16, typename Op> template <typename scalar_t = BFloat16, typename Op>

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* Clang-compatible compiler, targeting x86/x86-64 */ /* GCC or clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h> #include <x86intrin.h>
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */ /* Clang-compatible compiler, targeting arm neon */
@ -14,9 +14,6 @@
#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) #define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) #define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif #endif
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */ /* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h> #include <arm_neon.h>

View File

@ -1 +1,5 @@
#if defined(CPU_CAPABILITY_AVX512)
#include <ATen/cpu/vec/vec512/vec512.h>
#else
#include <ATen/cpu/vec/vec256/vec256.h> #include <ATen/cpu/vec/vec256/vec256.h>
#endif

View File

@ -1,6 +0,0 @@
#pragma once
#include <ATen/cpu/vec/vec256/functional_base.h>
#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/vec256/functional_bfloat16.h>
#endif

View File

@ -3,9 +3,9 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) #if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/vec256/vec256_float.h> #include <ATen/cpu/vec/vec256/vec256_float.h>
#include <ATen/cpu/vec/vec256/vec256_float_neon.h> #include <ATen/cpu/vec/vec256/vec256_float_neon.h>
@ -68,9 +68,9 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
} }
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<> template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) { inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
@ -82,29 +82,6 @@ inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return _mm256_castps_pd(src); return _mm256_castps_pd(src);
} }
#if defined(CPU_CAPABILITY_AVX2)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#define DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) \
template<> \
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
return _mm256_castp ## float_ch ## _si256(src); \
} \
template<> \
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
return _mm256_castsi256_p ## float_ch (src); \
}
DEFINE_FLOAT_INT_CAST(int64_t, double, d)
DEFINE_FLOAT_INT_CAST(int32_t, double, d)
DEFINE_FLOAT_INT_CAST(int16_t, double, d)
DEFINE_FLOAT_INT_CAST(int64_t, float, s)
DEFINE_FLOAT_INT_CAST(int32_t, float, s)
DEFINE_FLOAT_INT_CAST(int16_t, float, s)
#undef DEFINE_FLOAT_INT_CAST
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<int64_t scale = 1> template<int64_t scale = 1>
@ -243,8 +220,6 @@ inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>&
_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart
} }
#endif // defined(CPU_CAPABILITY_AVX2) #endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
}}} }}}

View File

@ -3,8 +3,8 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h> #include <sleef.h>
#endif #endif
@ -100,7 +100,7 @@ public:
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
} }
static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) { static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) {
__at_align32__ int16_t tmp_values[size()]; __at_align__ int16_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
return loadu(tmp_values); return loadu(tmp_values);
} }
@ -108,14 +108,14 @@ public:
if (count == size()) { if (count == size()) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) { } else if (count > 0) {
__at_align32__ int16_t tmp_values[size()]; __at_align__ int16_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
} }
} }
template <int64_t mask> template <int64_t mask>
static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
__at_align32__ int16_t tmp_values[size()]; __at_align__ int16_t tmp_values[size()];
a.store(tmp_values); a.store(tmp_values);
if (mask & 0x01) if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi16(b.values, 0); tmp_values[0] = _mm256_extract_epi16(b.values, 0);
@ -280,7 +280,7 @@ public:
Vectorized<BFloat16> erfinv() const { Vectorized<BFloat16> erfinv() const {
__m256 lo, hi; __m256 lo, hi;
cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(values, lo, hi);
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) { for (int64_t i = 0; i < size() / 2; i++) {
@ -318,7 +318,7 @@ public:
Vectorized<BFloat16> i0() const { Vectorized<BFloat16> i0() const {
__m256 lo, hi; __m256 lo, hi;
cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(values, lo, hi);
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) { for (int64_t i = 0; i < size() / 2; i++) {
@ -333,7 +333,7 @@ public:
__m256 lo, hi; __m256 lo, hi;
cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(values, lo, hi);
constexpr auto sz = size(); constexpr auto sz = size();
__at_align32__ float tmp1[sz / 2], tmp2[sz / 2]; __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
@ -350,10 +350,10 @@ public:
__m256 xlo, xhi; __m256 xlo, xhi;
cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(x.values, xlo, xhi); cvtbf16_fp32(x.values, xlo, xhi);
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo); _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi); _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) { for (int64_t i = 0; i < size() / 2; ++i) {
@ -370,10 +370,10 @@ public:
__m256 xlo, xhi; __m256 xlo, xhi;
cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(x.values, xlo, xhi); cvtbf16_fp32(x.values, xlo, xhi);
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo); _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi); _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) { for (int64_t i = 0; i < size() / 2; ++i) {
@ -717,12 +717,13 @@ inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, c
return cvtfp32_bf16(__m256(a), __m256(b)); return cvtfp32_bf16(__m256(a), __m256(b));
} }
#else //defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) { inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
constexpr int64_t K = Vectorized<BFloat16>::size(); constexpr int64_t K = Vectorized<BFloat16>::size();
__at_align32__ float arr[K]; __at_align__ float arr[K];
__at_align32__ BFloat16 arr2[K]; __at_align__ BFloat16 arr2[K];
a.store(arr2); a.store(arr2);
convert(arr2, arr, K); convert(arr2, arr, K);
return std::make_tuple( return std::make_tuple(
@ -732,15 +733,15 @@ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(c
inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) { inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
constexpr int64_t K = Vectorized<BFloat16>::size(); constexpr int64_t K = Vectorized<BFloat16>::size();
__at_align32__ float arr[K]; __at_align__ float arr[K];
__at_align32__ BFloat16 arr2[K]; __at_align__ BFloat16 arr2[K];
a.store(arr); a.store(arr);
b.store(arr + Vectorized<float>::size()); b.store(arr + Vectorized<float>::size());
convert(arr, arr2, K); convert(arr, arr2, K);
return Vectorized<BFloat16>::loadu(arr2); return Vectorized<BFloat16>::loadu(arr2);
} }
#endif #endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) { void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
@ -759,7 +760,7 @@ void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vec
} }
#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) { void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
__at_align32__ float values[Vectorized<float>::size()]; __at_align__ float values[Vectorized<float>::size()];
for (int k = 0; k < Vectorized<float>::size(); ++k) { for (int k = 0; k < Vectorized<float>::size(); ++k) {
values[k] = data[k]; values[k] = data[k];
} }

View File

@ -4,9 +4,10 @@
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <c10/util/complex.h> #include <c10/util/complex.h>
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h> #include <sleef.h>
#endif #endif
@ -15,7 +16,7 @@ namespace vec {
// See Note [Acceptable use of anonymous namespace in header] // See Note [Acceptable use of anonymous namespace in header]
namespace { namespace {
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
template <> class Vectorized<c10::complex<double>> { template <> class Vectorized<c10::complex<double>> {
private: private:
@ -81,7 +82,7 @@ public:
if (count == size()) if (count == size())
return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr)); return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align32__ double tmp_values[2*size()]; __at_align__ double tmp_values[2*size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -106,7 +107,7 @@ public:
const c10::complex<double>& operator[](int idx) const = delete; const c10::complex<double>& operator[](int idx) const = delete;
c10::complex<double>& operator[](int idx) = delete; c10::complex<double>& operator[](int idx) = delete;
Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const { Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
__at_align32__ c10::complex<double> tmp[size()]; __at_align__ c10::complex<double> tmp[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -288,8 +289,8 @@ public:
return sqrt().reciprocal(); return sqrt().reciprocal();
} }
Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const { Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
__at_align32__ c10::complex<double> x_tmp[size()]; __at_align__ c10::complex<double> x_tmp[size()];
__at_align32__ c10::complex<double> y_tmp[size()]; __at_align__ c10::complex<double> y_tmp[size()];
store(x_tmp); store(x_tmp);
exp.store(y_tmp); exp.store(y_tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {

View File

@ -4,9 +4,9 @@
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <c10/util/complex.h> #include <c10/util/complex.h>
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h> #include <sleef.h>
#endif #endif
@ -15,7 +15,7 @@ namespace vec {
// See Note [Acceptable use of anonymous namespace in header] // See Note [Acceptable use of anonymous namespace in header]
namespace { namespace {
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
template <> class Vectorized<c10::complex<float>> { template <> class Vectorized<c10::complex<float>> {
private: private:
@ -117,7 +117,7 @@ public:
if (count == size()) if (count == size())
return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr)); return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
__at_align32__ float tmp_values[2*size()]; __at_align__ float tmp_values[2*size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -142,7 +142,7 @@ public:
const c10::complex<float>& operator[](int idx) const = delete; const c10::complex<float>& operator[](int idx) const = delete;
c10::complex<float>& operator[](int idx) = delete; c10::complex<float>& operator[](int idx) = delete;
Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const { Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
__at_align32__ c10::complex<float> tmp[size()]; __at_align__ c10::complex<float> tmp[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -323,8 +323,8 @@ public:
return sqrt().reciprocal(); return sqrt().reciprocal();
} }
Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const { Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const {
__at_align32__ c10::complex<float> x_tmp[size()]; __at_align__ c10::complex<float> x_tmp[size()];
__at_align32__ c10::complex<float> y_tmp[size()]; __at_align__ c10::complex<float> y_tmp[size()];
store(x_tmp); store(x_tmp);
exp.store(y_tmp); exp.store(y_tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {

View File

@ -3,9 +3,9 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h> #include <sleef.h>
#endif #endif
@ -14,7 +14,8 @@ namespace vec {
// See Note [Acceptable use of anonymous namespace in header] // See Note [Acceptable use of anonymous namespace in header]
namespace { namespace {
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
template <> class Vectorized<double> { template <> class Vectorized<double> {
private: private:
@ -67,7 +68,7 @@ public:
return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr)); return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align32__ double tmp_values[size()]; __at_align__ double tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -100,7 +101,7 @@ public:
return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
} }
Vectorized<double> map(double (*const f)(double)) const { Vectorized<double> map(double (*const f)(double)) const {
__at_align32__ double tmp[size()]; __at_align__ double tmp[size()];
store(tmp); store(tmp);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -175,8 +176,8 @@ public:
return map(calc_i0e); return map(calc_i0e);
} }
Vectorized<double> igamma(const Vectorized<double> &x) const { Vectorized<double> igamma(const Vectorized<double> &x) const {
__at_align32__ double tmp[size()]; __at_align__ double tmp[size()];
__at_align32__ double tmp_x[size()]; __at_align__ double tmp_x[size()];
store(tmp); store(tmp);
x.store(tmp_x); x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -185,8 +186,8 @@ public:
return loadu(tmp); return loadu(tmp);
} }
Vectorized<double> igammac(const Vectorized<double> &x) const { Vectorized<double> igammac(const Vectorized<double> &x) const {
__at_align32__ double tmp[size()]; __at_align__ double tmp[size()];
__at_align32__ double tmp_x[size()]; __at_align__ double tmp_x[size()];
store(tmp); store(tmp);
x.store(tmp_x); x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {

View File

@ -3,9 +3,9 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h> #include <sleef.h>
#endif #endif
@ -14,7 +14,7 @@ namespace vec {
// See Note [Acceptable use of anonymous namespace in header] // See Note [Acceptable use of anonymous namespace in header]
namespace { namespace {
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
template <> class Vectorized<float> { template <> class Vectorized<float> {
private: private:
@ -76,7 +76,7 @@ public:
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) { static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size()) if (count == size())
return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr)); return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
__at_align32__ float tmp_values[size()]; __at_align__ float tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -107,7 +107,7 @@ public:
return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
} }
Vectorized<float> map(float (*const f)(float)) const { Vectorized<float> map(float (*const f)(float)) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
store(tmp); store(tmp);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -213,8 +213,8 @@ public:
return map(calc_i0e); return map(calc_i0e);
} }
Vectorized<float> igamma(const Vectorized<float> &x) const { Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_x[size()]; __at_align__ float tmp_x[size()];
store(tmp); store(tmp);
x.store(tmp_x); x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -223,8 +223,8 @@ public:
return loadu(tmp); return loadu(tmp);
} }
Vectorized<float> igammac(const Vectorized<float> &x) const { Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_x[size()]; __at_align__ float tmp_x[size()];
store(tmp); store(tmp);
x.store(tmp_x); x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -412,12 +412,11 @@ inline void convert(const float* src, float* dst, int64_t n) {
} }
} }
#ifdef CPU_CAPABILITY_AVX2
template <> template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) { Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return _mm256_fmadd_ps(a, b, c); return _mm256_fmadd_ps(a, b, c);
} }
#endif
#endif #endif

View File

@ -3,8 +3,8 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
// Sleef offers vectorized versions of some transcedentals // Sleef offers vectorized versions of some transcedentals
// such as sin, cos, tan etc.. // such as sin, cos, tan etc..
// However for now opting for STL, since we are not building // However for now opting for STL, since we are not building
@ -220,7 +220,7 @@ public:
return res; return res;
} }
else { else {
__at_align32__ float tmp_values[size()]; __at_align__ float tmp_values[size()];
for (auto i = 0; i < size(); ++i) { for (auto i = 0; i < size(); ++i) {
tmp_values[i] = 0.0; tmp_values[i] = 0.0;
} }
@ -261,19 +261,19 @@ public:
// Once we specialize that implementation for ARM // Once we specialize that implementation for ARM
// this should be removed. TODO (kimishpatel) // this should be removed. TODO (kimishpatel)
float operator[](int idx) const { float operator[](int idx) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
store(tmp); store(tmp);
return tmp[idx]; return tmp[idx];
} }
float operator[](int idx) { float operator[](int idx) {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
store(tmp); store(tmp);
return tmp[idx]; return tmp[idx];
} }
// For boolean version where we want to if any 1/all zero // For boolean version where we want to if any 1/all zero
// etc. can be done faster in a different way. // etc. can be done faster in a different way.
int zero_mask() const { int zero_mask() const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
store(tmp); store(tmp);
int mask = 0; int mask = 0;
for (int i = 0; i < size(); ++ i) { for (int i = 0; i < size(); ++ i) {
@ -284,8 +284,8 @@ public:
return mask; return mask;
} }
Vectorized<float> isnan() const { Vectorized<float> isnan() const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float res[size()]; __at_align__ float res[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
if (_isnan(tmp[i])) { if (_isnan(tmp[i])) {
@ -297,7 +297,7 @@ public:
return loadu(res); return loadu(res);
}; };
Vectorized<float> map(float (*const f)(float)) const { Vectorized<float> map(float (*const f)(float)) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
store(tmp); store(tmp);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -332,8 +332,8 @@ public:
return map(std::atan); return map(std::atan);
} }
Vectorized<float> atan2(const Vectorized<float> &exp) const { Vectorized<float> atan2(const Vectorized<float> &exp) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_exp[size()]; __at_align__ float tmp_exp[size()];
store(tmp); store(tmp);
exp.store(tmp_exp); exp.store(tmp_exp);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -342,8 +342,8 @@ public:
return loadu(tmp); return loadu(tmp);
} }
Vectorized<float> copysign(const Vectorized<float> &sign) const { Vectorized<float> copysign(const Vectorized<float> &sign) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_sign[size()]; __at_align__ float tmp_sign[size()];
store(tmp); store(tmp);
sign.store(tmp_sign); sign.store(tmp_sign);
for (size_type i = 0; i < size(); i++) { for (size_type i = 0; i < size(); i++) {
@ -367,8 +367,8 @@ public:
return map(std::expm1); return map(std::expm1);
} }
Vectorized<float> fmod(const Vectorized<float>& q) const { Vectorized<float> fmod(const Vectorized<float>& q) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_q[size()]; __at_align__ float tmp_q[size()];
store(tmp); store(tmp);
q.store(tmp_q); q.store(tmp_q);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -377,8 +377,8 @@ public:
return loadu(tmp); return loadu(tmp);
} }
Vectorized<float> hypot(const Vectorized<float> &b) const { Vectorized<float> hypot(const Vectorized<float> &b) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_b[size()]; __at_align__ float tmp_b[size()];
store(tmp); store(tmp);
b.store(tmp_b); b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -393,8 +393,8 @@ public:
return map(calc_i0e); return map(calc_i0e);
} }
Vectorized<float> igamma(const Vectorized<float> &x) const { Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_x[size()]; __at_align__ float tmp_x[size()];
store(tmp); store(tmp);
x.store(tmp_x); x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -403,8 +403,8 @@ public:
return loadu(tmp); return loadu(tmp);
} }
Vectorized<float> igammac(const Vectorized<float> &x) const { Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_x[size()]; __at_align__ float tmp_x[size()];
store(tmp); store(tmp);
x.store(tmp_x); x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -425,8 +425,8 @@ public:
return map(std::log2); return map(std::log2);
} }
Vectorized<float> nextafter(const Vectorized<float> &b) const { Vectorized<float> nextafter(const Vectorized<float> &b) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_b[size()]; __at_align__ float tmp_b[size()];
store(tmp); store(tmp);
b.store(tmp_b); b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
@ -490,8 +490,8 @@ public:
return this->sqrt().reciprocal(); return this->sqrt().reciprocal();
} }
Vectorized<float> pow(const Vectorized<float> &exp) const { Vectorized<float> pow(const Vectorized<float> &exp) const {
__at_align32__ float tmp[size()]; __at_align__ float tmp[size()];
__at_align32__ float tmp_exp[size()]; __at_align__ float tmp_exp[size()];
store(tmp); store(tmp);
exp.store(tmp_exp); exp.store(tmp_exp);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {

View File

@ -3,8 +3,8 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
namespace at { namespace at {
@ -55,7 +55,7 @@ public:
} }
template <int64_t mask> template <int64_t mask>
static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) { static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) {
__at_align32__ int64_t tmp_values[size()]; __at_align__ int64_t tmp_values[size()];
a.store(tmp_values); a.store(tmp_values);
if (mask & 0x01) if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi64(b.values, 0); tmp_values[0] = _mm256_extract_epi64(b.values, 0);
@ -93,7 +93,7 @@ public:
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
} }
static Vectorized<int64_t> loadu(const void* ptr, int64_t count) { static Vectorized<int64_t> loadu(const void* ptr, int64_t count) {
__at_align32__ int64_t tmp_values[size()]; __at_align__ int64_t tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -109,7 +109,7 @@ public:
// https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) { } else if (count > 0) {
__at_align32__ int64_t tmp_values[size()]; __at_align__ int64_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
} }
@ -216,7 +216,7 @@ public:
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
} }
static Vectorized<int32_t> loadu(const void* ptr, int32_t count) { static Vectorized<int32_t> loadu(const void* ptr, int32_t count) {
__at_align32__ int32_t tmp_values[size()]; __at_align__ int32_t tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -232,7 +232,7 @@ public:
// https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) { } else if (count > 0) {
__at_align32__ int32_t tmp_values[size()]; __at_align__ int32_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
} }
@ -346,7 +346,7 @@ public:
} }
template <int64_t mask> template <int64_t mask>
static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) { static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) {
__at_align32__ int16_t tmp_values[size()]; __at_align__ int16_t tmp_values[size()];
a.store(tmp_values); a.store(tmp_values);
if (mask & 0x01) if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi16(b.values, 0); tmp_values[0] = _mm256_extract_epi16(b.values, 0);
@ -436,7 +436,7 @@ public:
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
} }
static Vectorized<int16_t> loadu(const void* ptr, int16_t count) { static Vectorized<int16_t> loadu(const void* ptr, int16_t count) {
__at_align32__ int16_t tmp_values[size()]; __at_align__ int16_t tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -452,7 +452,7 @@ public:
// https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) { } else if (count > 0) {
__at_align32__ int16_t tmp_values[size()]; __at_align__ int16_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
} }
@ -527,7 +527,7 @@ public:
} }
template <int64_t mask> template <int64_t mask>
static Vectorized<int8_t> blend(Vectorized<int8_t> a, Vectorized<int8_t> b) { static Vectorized<int8_t> blend(Vectorized<int8_t> a, Vectorized<int8_t> b) {
__at_align32__ int8_t tmp_values[size()]; __at_align__ int8_t tmp_values[size()];
a.store(tmp_values); a.store(tmp_values);
if (mask & 0x01) if (mask & 0x01)
tmp_values[0] = _mm256_extract_epi8(b.values, 0); tmp_values[0] = _mm256_extract_epi8(b.values, 0);
@ -685,7 +685,7 @@ public:
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
} }
static Vectorized<int8_t> loadu(const void* ptr, int8_t count) { static Vectorized<int8_t> loadu(const void* ptr, int8_t count) {
__at_align32__ int8_t tmp_values[size()]; __at_align__ int8_t tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction. // instructions while a loop would be compiled to one instruction.
@ -701,7 +701,7 @@ public:
// https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
} else if (count > 0) { } else if (count > 0) {
__at_align32__ int8_t tmp_values[size()]; __at_align__ int8_t tmp_values[size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); std::memcpy(ptr, tmp_values, count * sizeof(int8_t));
} }

View File

@ -3,8 +3,8 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER! // DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX] // See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/native/quantized/affine_quantizer_base.h> #include <ATen/native/quantized/affine_quantizer_base.h>
#include <c10/util/qint32.h> #include <c10/util/qint32.h>
#include <c10/util/qint8.h> #include <c10/util/qint8.h>
@ -39,7 +39,7 @@ namespace at {
namespace vec { namespace vec {
namespace { namespace {
#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
struct Vectorizedqi { struct Vectorizedqi {
protected: protected:
@ -53,7 +53,6 @@ struct Vectorizedqi {
} }
}; };
#if defined(CPU_CAPABILITY_AVX2)
template <typename T> template <typename T>
__m256i pack_saturate_and_clamp( __m256i pack_saturate_and_clamp(
__m256i first, __m256i first,
@ -94,7 +93,6 @@ __m256i pack_saturate_and_clamp<uint8_t>(
_mm256_set1_epi8(min_val), _mm256_set1_epi8(min_val),
_mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val)));
} }
#endif
template <typename T> template <typename T>
inline void __attribute__((always_inline)) QuantizeAvx2( inline void __attribute__((always_inline)) QuantizeAvx2(
@ -103,7 +101,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
int len, int len,
float inverse_scale, float inverse_scale,
int64_t zero_point) { int64_t zero_point) {
#if defined(CPU_CAPABILITY_AVX2)
constexpr int VLEN = 8; constexpr int VLEN = 8;
constexpr auto min_val = std::numeric_limits<typename T::underlying>::min(); constexpr auto min_val = std::numeric_limits<typename T::underlying>::min();
constexpr auto max_val = std::numeric_limits<typename T::underlying>::max(); constexpr auto max_val = std::numeric_limits<typename T::underlying>::max();
@ -212,10 +209,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
std::min(std::max(transformed, float(min_val)), float(max_val)); std::min(std::max(transformed, float(min_val)), float(max_val));
dst[i] = clipped; dst[i] = clipped;
} }
#else
at::native::quantize_vec<T>(
1.0f / inverse_scale, zero_point, src, reinterpret_cast<T*>(dst), len);
#endif
} }
template<> template<>
@ -266,11 +259,7 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
Vectorized<float> zero_point, Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const { Vectorized<float> scale_zp_premul) const {
__m256 float_vals = _mm256_cvtepi32_ps(vals); __m256 float_vals = _mm256_cvtepi32_ps(vals);
#if defined(CPU_CAPABILITY_AVX2)
return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)}; return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)};
#else
return {scale * (Vectorized<float>(float_vals) - zero_point)};
#endif
} }
static Vectorized<c10::qint32> quantize( static Vectorized<c10::qint32> quantize(
@ -286,39 +275,11 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
} }
Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const { Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_max_epi32(vals, b.vals); return _mm256_max_epi32(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(int_vals.data()), vals);
std::array<int32_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(b_vals.data()), b.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::max<int32_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const { Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epi32(vals, b.vals); return _mm256_min_epi32(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int32_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<int32_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const { Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
@ -328,65 +289,24 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
Vectorized<c10::qint32> relu6( Vectorized<c10::qint32> relu6(
Vectorized<c10::qint32> zero_point, Vectorized<c10::qint32> zero_point,
Vectorized<c10::qint32> q_six) { Vectorized<c10::qint32> q_six) {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epi32( return _mm256_min_epi32(
_mm256_max_epi32(vals, zero_point.vals), q_six.vals); _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int32_t, size()> zero_point_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
std::array<int32_t,size()> q_six_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<int32_t>(
std::max<int32_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const { int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
#ifdef CPU_CAPABILITY_AVX2
return {_mm256_sub_epi32(vals, b)}; return {_mm256_sub_epi32(vals, b)};
#else
std::array<int32_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int32_t, size()> b_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = int_vals[i] - b_vals[i];
}
return {_mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals))};
#endif
} }
static Vectorized<c10::qint32> requantize_from_int( static Vectorized<c10::qint32> requantize_from_int(
const int_vec_return_type& inp, const int_vec_return_type& inp,
float multiplier, float multiplier,
int32_t zero_point) { int32_t zero_point) {
#ifdef CPU_CAPABILITY_AVX2
__m256 multiplier_v = _mm256_set1_ps(multiplier); __m256 multiplier_v = _mm256_set1_ps(multiplier);
__m256i zero_point_v = _mm256_set1_epi32(zero_point); __m256i zero_point_v = _mm256_set1_epi32(zero_point);
__m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v); __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v);
__m256i rounded = _mm256_cvtps_epi32(scaled); __m256i rounded = _mm256_cvtps_epi32(scaled);
return _mm256_add_epi32(rounded, zero_point_v); return _mm256_add_epi32(rounded, zero_point_v);
#else
std::array<int32_t,size()> inp_vals;
inp[0].store(inp_vals.data());
std::array<int32_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] =
nearbyint(static_cast<float>(inp_vals[i]) * multiplier) +
zero_point;
}
return loadu(result_vals.data());
#endif
} }
void dump() const { void dump() const {
@ -411,43 +331,16 @@ template <>
Vectorized<c10::qint32> inline operator*( Vectorized<c10::qint32> inline operator*(
const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& a,
const Vectorized<c10::qint32>& b) { const Vectorized<c10::qint32>& b) {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_mullo_epi32(a, b); return _mm256_mullo_epi32(a, b);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, std::decay_t<decltype(a)>::size()> a_vals;
std::array<int32_t, std::decay_t<decltype(b)>::size()> b_vals;
a.store(a_vals.data());
b.store(b_vals.data());
std::array<int32_t, std::decay_t<decltype(a)>::size()> result_vals;
for (size_t i = 0; i < std::decay_t<decltype(a)>::size(); ++i) {
result_vals[i] = a_vals[i] * b_vals[i];
}
return Vectorized<c10::qint32>::loadu(result_vals.data());
#endif
} }
template <> template <>
Vectorized<c10::qint32> inline operator+( Vectorized<c10::qint32> inline operator+(
const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& a,
const Vectorized<c10::qint32>& b) { const Vectorized<c10::qint32>& b) {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_add_epi32(a, b); return _mm256_add_epi32(a, b);
#else
// Pray the compiler can autovectorize this
std::array<int32_t, std::decay_t<decltype(a)>::size()> a_vals;
std::array<int32_t, std::decay_t<decltype(b)>::size()> b_vals;
a.store(a_vals.data());
b.store(b_vals.data());
std::array<int32_t, std::decay_t<decltype(a)>::size()> result_vals;
for (size_t i = 0; i < std::decay_t<decltype(a)>::size(); ++i) {
result_vals[i] = a_vals[i] + b_vals[i];
}
return Vectorized<c10::qint32>::loadu(result_vals.data());
#endif
} }
#ifdef CPU_CAPABILITY_AVX2
/* /*
* Convert values from int32 back to int8/uint8 * Convert values from int32 back to int8/uint8
*/ */
@ -493,7 +386,6 @@ __m256i RequantizeAvx2(
xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
return xyzw_clamped_v; return xyzw_clamped_v;
} }
#endif
template<> template<>
struct Vectorized<c10::qint8> : public Vectorizedqi { struct Vectorized<c10::qint8> : public Vectorizedqi {
@ -544,21 +436,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
private: private:
__m256i cvtepi8_epi32(__m128i epi8_vals) const { __m256i cvtepi8_epi32(__m128i epi8_vals) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_cvtepi8_epi32(epi8_vals); return _mm256_cvtepi8_epi32(epi8_vals);
#else // CPU_CAPABILITY_AVX2
__m128i result_data[2];
__m128i unpacked1 = _mm_unpacklo_epi8(epi8_vals, epi8_vals);
__m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, unpacked1);
__m128i shifted1 = _mm_srli_si128(epi8_vals, 4);
__m128i shifted2 = _mm_srai_epi32(unpacked2, 24);
result_data[0] = shifted2;
__m128i unpacked3 = _mm_unpacklo_epi8(shifted1, shifted1);
__m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, unpacked3);
__m128i shifted3 = _mm_srai_epi32(unpacked4, 24);
result_data[1] = shifted3;
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data));
#endif
} }
public: public:
@ -576,7 +454,6 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
__m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
__m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));
#if defined(CPU_CAPABILITY_AVX2)
auto val0 = auto val0 =
vec::fmadd(scale, Vectorized<float>(float_val0), scale_neg_zp_premul); vec::fmadd(scale, Vectorized<float>(float_val0), scale_neg_zp_premul);
auto val1 = auto val1 =
@ -585,12 +462,6 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
vec::fmadd(scale, Vectorized<float>(float_val2), scale_neg_zp_premul); vec::fmadd(scale, Vectorized<float>(float_val2), scale_neg_zp_premul);
auto val3 = auto val3 =
vec::fmadd(scale, Vectorized<float>(float_val3), scale_neg_zp_premul); vec::fmadd(scale, Vectorized<float>(float_val3), scale_neg_zp_premul);
#else
auto val0 = scale * (Vectorized<float>(float_val0) - zero_point);
auto val1 = scale * (Vectorized<float>(float_val1) - zero_point);
auto val2 = scale * (Vectorized<float>(float_val2) - zero_point);
auto val3 = scale * (Vectorized<float>(float_val3) - zero_point);
#endif
return {val0, val1, val2, val3}; return {val0, val1, val2, val3};
} }
@ -607,39 +478,11 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
} }
Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const { Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_max_epi8(vals, b.vals); return _mm256_max_epi8(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<int8_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int8_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<int8_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::max<int8_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const { Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epi8(vals, b.vals); return _mm256_min_epi8(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<int8_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int8_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<int8_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<int8_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const { Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
@ -649,29 +492,11 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
Vectorized<c10::qint8> relu6( Vectorized<c10::qint8> relu6(
Vectorized<c10::qint8> zero_point, Vectorized<c10::qint8> zero_point,
Vectorized<c10::qint8> q_six) { Vectorized<c10::qint8> q_six) {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epi8( return _mm256_min_epi8(
_mm256_max_epi8(vals, zero_point.vals), q_six.vals); _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
#else
// Pray the compiler can autovectorize this
std::array<int8_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<int8_t, size()> zero_point_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
std::array<int8_t, size()> q_six_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
std::array<int8_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<int8_t>(
std::max<int8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const { int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
#ifdef CPU_CAPABILITY_AVX2
__m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
__m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
__m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
@ -701,55 +526,15 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
Vectorized<c10::qint32>(res_1), Vectorized<c10::qint32>(res_1),
Vectorized<c10::qint32>(res_2), Vectorized<c10::qint32>(res_2),
Vectorized<c10::qint32>(res_3)}; Vectorized<c10::qint32>(res_3)};
#else
// Pray the compiler can autovectorize this
std::array<int8_t, size()> int_vals;
store(int_vals.data());
std::array<int8_t, size()> b_vals;
b.store(b_vals.data());
constexpr int elem_per_int_vec = size() / int_num_vecs();
int32_t rv[int_num_vecs()][elem_per_int_vec];
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
rv[i][j] = static_cast<int32_t>(int_vals[i * elem_per_int_vec + j]) -
static_cast<int32_t>(b_vals[i * elem_per_int_vec + j]);
}
}
return {Vectorized<c10::qint32>::loadu(rv[0]),
Vectorized<c10::qint32>::loadu(rv[1]),
Vectorized<c10::qint32>::loadu(rv[2]),
Vectorized<c10::qint32>::loadu(rv[3])};
#endif
} }
static Vectorized<c10::qint8> requantize_from_int( static Vectorized<c10::qint8> requantize_from_int(
const int_vec_return_type& inp, const int_vec_return_type& inp,
float multiplier, float multiplier,
int32_t zero_point) { int32_t zero_point) {
#ifdef CPU_CAPABILITY_AVX2
__m256 multiplier_v = _mm256_set1_ps(multiplier); __m256 multiplier_v = _mm256_set1_ps(multiplier);
__m256i zero_point_v = _mm256_set1_epi32(zero_point); __m256i zero_point_v = _mm256_set1_epi32(zero_point);
return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v); return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
#else
// Pray the compiler can autovectorize this
constexpr int elem_per_int_vec = size() / int_num_vecs();
constexpr auto min_val = std::numeric_limits<value_type>::min();
constexpr auto max_val = std::numeric_limits<value_type>::max();
int32_t rv[int_num_vecs()][elem_per_int_vec];
for (size_t i = 0; i < int_num_vecs(); ++i) {
inp[i].store(rv[i]);
}
std::array<int8_t, size()> result_vals;
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
int32_t rounded =
nearbyint(static_cast<float>(rv[i][j]) * multiplier) + zero_point;
result_vals[i * elem_per_int_vec + j] =
std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
}
}
return loadu(result_vals.data());
#endif
} }
void dump() const { void dump() const {
@ -817,20 +602,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
private: private:
__m256i cvtepu8_epi32(__m128i epu8_vals) const { __m256i cvtepu8_epi32(__m128i epu8_vals) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_cvtepu8_epi32(epu8_vals); return _mm256_cvtepu8_epi32(epu8_vals);
#else // CPU_CAPABILITY_AVX2
__m128i result_data[2];
__m128i zeros = _mm_setzero_si128();
__m128i unpacked1 = _mm_unpacklo_epi8(epu8_vals, zeros);
__m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, zeros);
result_data[0] = unpacked2;
__m128i shifted = _mm_srli_si128(epu8_vals, 4);
__m128i unpacked3 = _mm_unpacklo_epi8(shifted, zeros);
__m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, zeros);
result_data[1] = unpacked4;
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data));
#endif
} }
public: public:
@ -848,7 +620,6 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
__m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
__m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));
#if defined(CPU_CAPABILITY_AVX2)
auto val0 = auto val0 =
vec::fmadd(scale, Vectorized<float>(float_val0), scale_zp_premul); vec::fmadd(scale, Vectorized<float>(float_val0), scale_zp_premul);
auto val1 = auto val1 =
@ -857,12 +628,6 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
vec::fmadd(scale, Vectorized<float>(float_val2), scale_zp_premul); vec::fmadd(scale, Vectorized<float>(float_val2), scale_zp_premul);
auto val3 = auto val3 =
vec::fmadd(scale, Vectorized<float>(float_val3), scale_zp_premul); vec::fmadd(scale, Vectorized<float>(float_val3), scale_zp_premul);
#else
auto val0 = scale * (Vectorized<float>(float_val0) - zero_point);
auto val1 = scale * (Vectorized<float>(float_val1) - zero_point);
auto val2 = scale * (Vectorized<float>(float_val2) - zero_point);
auto val3 = scale * (Vectorized<float>(float_val3) - zero_point);
#endif
return {val0, val1, val2, val3}; return {val0, val1, val2, val3};
} }
@ -879,39 +644,11 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
} }
Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const { Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_max_epu8(vals, b.vals); return _mm256_max_epu8(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<uint8_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<uint8_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<uint8_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::max<uint8_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const { Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epu8(vals, b.vals); return _mm256_min_epu8(vals, b.vals);
#else
// Pray the compiler can autovectorize this
std::array<uint8_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<uint8_t, size()> b_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&b_vals), b.vals);
std::array<uint8_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<uint8_t>(int_vals[i], b_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const { Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
@ -921,29 +658,11 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
Vectorized<c10::quint8> relu6( Vectorized<c10::quint8> relu6(
Vectorized<c10::quint8> zero_point, Vectorized<c10::quint8> zero_point,
Vectorized<c10::quint8> q_six) { Vectorized<c10::quint8> q_six) {
#ifdef CPU_CAPABILITY_AVX2
return _mm256_min_epu8( return _mm256_min_epu8(
_mm256_max_epu8(vals, zero_point.vals), q_six.vals); _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
#else
// Pray the compiler can autovectorize this
std::array<uint8_t, size()> int_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
std::array<uint8_t, size()> zero_point_vals;
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
std::array<uint8_t, size()> q_six_vals;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
std::array<uint8_t, size()> result_vals;
for (size_t i = 0; i < size(); ++i) {
result_vals[i] = std::min<uint8_t>(
std::max<uint8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
}
return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
#endif
} }
int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const { int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
#ifdef CPU_CAPABILITY_AVX2
__m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
__m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
__m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
@ -972,55 +691,15 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
Vectorized<c10::qint32>(res_1), Vectorized<c10::qint32>(res_1),
Vectorized<c10::qint32>(res_2), Vectorized<c10::qint32>(res_2),
Vectorized<c10::qint32>(res_3)}; Vectorized<c10::qint32>(res_3)};
#else
// Pray the compiler can autovectorize this
std::array<uint8_t, size()> int_vals;
std::array<uint8_t, size()> b_vals;
store(int_vals.data());
b.store(b_vals.data());
static constexpr int elem_per_int_vec = size() / int_num_vecs();
int32_t rv[int_num_vecs()][elem_per_int_vec];
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
rv[i][j] = static_cast<int32_t>(int_vals[i * elem_per_int_vec + j]) -
static_cast<int32_t>(b_vals[i * elem_per_int_vec + j]);
}
}
return {Vectorized<c10::qint32>::loadu(rv[0]),
Vectorized<c10::qint32>::loadu(rv[1]),
Vectorized<c10::qint32>::loadu(rv[2]),
Vectorized<c10::qint32>::loadu(rv[3])};
#endif
} }
static Vectorized<c10::quint8> requantize_from_int( static Vectorized<c10::quint8> requantize_from_int(
const int_vec_return_type& inp, const int_vec_return_type& inp,
float multiplier, float multiplier,
int32_t zero_point) { int32_t zero_point) {
#ifdef CPU_CAPABILITY_AVX2
__m256 multiplier_v = _mm256_set1_ps(multiplier); __m256 multiplier_v = _mm256_set1_ps(multiplier);
__m256i zero_point_v = _mm256_set1_epi32(zero_point); __m256i zero_point_v = _mm256_set1_epi32(zero_point);
return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v); return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
#else
// Pray the compiler can autovectorize this
constexpr int elem_per_int_vec = size() / int_num_vecs();
constexpr auto min_val = std::numeric_limits<value_type>::min();
constexpr auto max_val = std::numeric_limits<value_type>::max();
int32_t rv[int_num_vecs()][elem_per_int_vec];
for (size_t i = 0; i < int_num_vecs(); ++i) {
inp[i].store(rv[i]);
}
std::array<uint8_t, size()> result_vals;
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
int32_t rounded =
nearbyint(static_cast<float>(rv[i][j]) * multiplier) + zero_point;
result_vals[i * elem_per_int_vec + j] =
std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
}
}
return loadu(result_vals.data());
#endif
} }
void dump() const { void dump() const {
@ -1497,6 +1176,5 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
return a.maximum(b); return a.maximum(b);
} }
#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) #endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
}}} }}}

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h> #include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h> #include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h>

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/complex.h> #include <c10/util/complex.h>
@ -141,7 +141,7 @@ class Vectorized<ComplexDbl> {
vec_vsx_ld(offset16, reinterpret_cast<const double*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const double*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return { return {
@ -153,7 +153,7 @@ class Vectorized<ComplexDbl> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(tmp_values)); vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(tmp_values));
vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(tmp_values)); vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(tmp_values));
std::memcpy( std::memcpy(
@ -165,7 +165,7 @@ class Vectorized<ComplexDbl> {
ComplexDbl& operator[](int idx) = delete; ComplexDbl& operator[](int idx) = delete;
Vectorized<ComplexDbl> map(ComplexDbl (*const f)(ComplexDbl)) const { Vectorized<ComplexDbl> map(ComplexDbl (*const f)(ComplexDbl)) const {
__at_align32__ ComplexDbl tmp[size()]; __at_align__ ComplexDbl tmp[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -174,7 +174,7 @@ class Vectorized<ComplexDbl> {
} }
Vectorized<ComplexDbl> map(ComplexDbl (*const f)(const ComplexDbl&)) const { Vectorized<ComplexDbl> map(ComplexDbl (*const f)(const ComplexDbl&)) const {
__at_align32__ ComplexDbl tmp[size()]; __at_align__ ComplexDbl tmp[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -455,8 +455,8 @@ class Vectorized<ComplexDbl> {
} }
Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const { Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
__at_align32__ ComplexDbl x_tmp[size()]; __at_align__ ComplexDbl x_tmp[size()];
__at_align32__ ComplexDbl y_tmp[size()]; __at_align__ ComplexDbl y_tmp[size()];
store(x_tmp); store(x_tmp);
exp.store(y_tmp); exp.store(y_tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/complex.h> #include <c10/util/complex.h>
@ -196,7 +196,7 @@ class Vectorized<ComplexFlt> {
vec_vsx_ld(offset16, reinterpret_cast<const float*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const float*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return { return {
@ -209,7 +209,7 @@ class Vectorized<ComplexFlt> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(tmp_values)); vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(tmp_values));
vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(tmp_values)); vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(tmp_values));
std::memcpy( std::memcpy(
@ -221,7 +221,7 @@ class Vectorized<ComplexFlt> {
ComplexFlt& operator[](int idx) = delete; ComplexFlt& operator[](int idx) = delete;
Vectorized<ComplexFlt> map(ComplexFlt (*const f)(ComplexFlt)) const { Vectorized<ComplexFlt> map(ComplexFlt (*const f)(ComplexFlt)) const {
__at_align32__ ComplexFlt tmp[size()]; __at_align__ ComplexFlt tmp[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -230,7 +230,7 @@ class Vectorized<ComplexFlt> {
} }
Vectorized<ComplexFlt> map(ComplexFlt (*const f)(const ComplexFlt&)) const { Vectorized<ComplexFlt> map(ComplexFlt (*const f)(const ComplexFlt&)) const {
__at_align32__ ComplexFlt tmp[size()]; __at_align__ ComplexFlt tmp[size()];
store(tmp); store(tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]); tmp[i] = f(tmp[i]);
@ -434,8 +434,8 @@ class Vectorized<ComplexFlt> {
} }
Vectorized<ComplexFlt> pow(const Vectorized<ComplexFlt>& exp) const { Vectorized<ComplexFlt> pow(const Vectorized<ComplexFlt>& exp) const {
__at_align32__ ComplexFlt x_tmp[size()]; __at_align__ ComplexFlt x_tmp[size()];
__at_align32__ ComplexFlt y_tmp[size()]; __at_align__ ComplexFlt y_tmp[size()];
store(x_tmp); store(x_tmp);
exp.store(y_tmp); exp.store(y_tmp);
for (int i = 0; i < size(); i++) { for (int i = 0; i < size(); i++) {

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <sleef.h> #include <sleef.h>
@ -169,7 +169,7 @@ class Vectorized<double> {
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@ -179,7 +179,7 @@ class Vectorized<double> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <sleef.h> #include <sleef.h>
namespace at { namespace at {
@ -180,7 +180,7 @@ class Vectorized<float> {
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@ -190,7 +190,7 @@ class Vectorized<float> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at { namespace at {
namespace vec { namespace vec {
@ -269,7 +269,7 @@ class Vectorized<int16_t> {
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@ -279,7 +279,7 @@ class Vectorized<int16_t> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type));

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at { namespace at {
namespace vec { namespace vec {
@ -199,7 +199,7 @@ class Vectorized<int32_t> {
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@ -209,7 +209,7 @@ class Vectorized<int32_t> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at { namespace at {
namespace vec { namespace vec {
@ -148,7 +148,7 @@ class Vectorized<int64_t> {
(vint64)vec_vsx_ld(offset16, dptr)}; (vint64)vec_vsx_ld(offset16, dptr)};
} }
__at_align32__ double tmp_values[size()]; __at_align__ double tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return { return {
@ -161,7 +161,7 @@ class Vectorized<int64_t> {
vec_vsx_st((vfloat64)_vec0, offset0, dptr); vec_vsx_st((vfloat64)_vec0, offset0, dptr);
vec_vsx_st((vfloat64)_vec1, offset16, dptr); vec_vsx_st((vfloat64)_vec1, offset16, dptr);
} else if (count > 0) { } else if (count > 0) {
__at_align32__ double tmp_values[size()]; __at_align__ double tmp_values[size()];
vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); vec_vsx_st((vfloat64)_vec0, offset0, tmp_values);
vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); vec_vsx_st((vfloat64)_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/qint32.h> #include <c10/util/qint32.h>
#include <array> #include <array>
@ -81,7 +81,7 @@ struct Vectorized<c10::qint32> {
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
@ -91,7 +91,7 @@ struct Vectorized<c10::qint32> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/qint8.h> #include <c10/util/qint8.h>
#include <array> #include <array>
@ -91,7 +91,7 @@ struct Vectorized<c10::qint8> {
vec_vsx_ld(offset0, reinterpret_cast<const vint8*>(ptr)), vec_vsx_ld(offset0, reinterpret_cast<const vint8*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const vint8*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const vint8*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
} }
@ -100,7 +100,7 @@ struct Vectorized<c10::qint8> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_base.h> #include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/quint8.h> #include <c10/util/quint8.h>
#include <array> #include <array>
@ -92,7 +92,7 @@ struct Vectorized<c10::quint8> {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)), vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
} }
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
} }
@ -101,7 +101,7 @@ struct Vectorized<c10::quint8> {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) { } else if (count > 0) {
__at_align32__ value_type tmp_values[size()]; __at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy( std::memcpy(

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char;
using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short;

View File

@ -0,0 +1,195 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec512/vec512_float.h>
#include <ATen/cpu/vec/vec512/vec512_bfloat16.h>
#include <ATen/cpu/vec/vec512/vec512_double.h>
#include <ATen/cpu/vec/vec512/vec512_int.h>
#include <ATen/cpu/vec/vec512/vec512_qint.h>
#include <ATen/cpu/vec/vec512/vec512_complex_float.h>
#include <ATen/cpu/vec/vec512/vec512_complex_double.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
stream << val.val_;
return stream;
}
C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
stream << static_cast<int>(val.val_);
return stream;
}
C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
stream << static_cast<unsigned int>(val.val_);
return stream;
}
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
T buf[Vectorized<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != Vectorized<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
return _mm512_castpd_ps(src);
}
template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return _mm512_castps_pd(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
return _mm512_i64gather_pd(vindex, base_addr, scale);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm512_i32gather_ps(vindex, base_addr, scale);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
const Vectorized<int64_t>& vindex, const Vectorized<double>& mask) {
auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF));
auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ);
return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex, const Vectorized<float>& mask) {
auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF));
auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
return _mm512_cvtpd_epi64(src);
}
template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return _mm512_cvttps_epi32(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, a1, a3, a3, a4, a5, a6, a7}
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
__m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0);
__m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4);
return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
//
// return:
// {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
__m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4,
19, 3, 18, 2, 17, 1, 16, 0);
__m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12,
27, 11, 26, 10, 25, 9, 24, 8);
return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
_mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
// output:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
// The members of indices have been written in binary format for better understandability
__m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
__m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);
return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// output:
// return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
__m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16,
14, 12, 10, 8, 6, 4, 2, 0);
__m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17,
15, 13, 11, 9, 7, 5, 3, 1);
return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
_mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
}
#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
}}}

View File

@ -0,0 +1,879 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}
static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
__m256i lo = _mm512_extracti32x8_epi32(a, 0);
__m256i hi = _mm512_extracti32x8_epi32(a, 1);
cvtbf16_fp32(lo, o1);
cvtbf16_fp32(hi, o2);
}
static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
__m512i lo = _mm512_castps_si512(a);
__m512i hi = _mm512_castps_si512(b);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
// uint32_t lsb = (input >> 16) & 1;
auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
// uint32_t rounding_bias = 0x7fff + lsb;
t_lo = _mm512_add_epi32(t_lo, vec_bias);
t_hi = _mm512_add_epi32(t_hi, vec_bias);
// input += rounding_bias;
t_lo = _mm512_add_epi32(t_lo, lo);
t_hi = _mm512_add_epi32(t_hi, hi);
// input = input >> 16;
t_lo = _mm512_srli_epi32(t_lo, 16);
t_hi = _mm512_srli_epi32(t_hi, 16);
// Check NaN before converting back to bf16
t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);
t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
return _mm512_permutexvar_epi64(idx, t_lo);
}
static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
__m512i lo = _mm512_castps_si512(a);
__m512i hi = _mm512_castps_si512(b);
lo = _mm512_srli_epi32(lo, 16);
hi = _mm512_srli_epi32(hi, 16);
auto out = _mm512_packus_epi32(lo, hi);
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
return _mm512_permutexvar_epi64(idx, out);
}
template <> class Vectorized<BFloat16> {
private:
__m512i values;
public:
using value_type = uint16_t;
using size_type = int;
static constexpr size_type size() {
return 32;
}
Vectorized() {}
Vectorized(__m512i v) : values(v) {}
Vectorized(BFloat16 val) {
value_type uw = val.x;
values = _mm512_set1_epi16(uw);
}
Vectorized(BFloat16 val1, BFloat16 val2, BFloat16 val3, BFloat16 val4,
BFloat16 val5, BFloat16 val6, BFloat16 val7, BFloat16 val8,
BFloat16 val9, BFloat16 val10, BFloat16 val11, BFloat16 val12,
BFloat16 val13, BFloat16 val14, BFloat16 val15, BFloat16 val16,
BFloat16 val17, BFloat16 val18, BFloat16 val19, BFloat16 val20,
BFloat16 val21, BFloat16 val22, BFloat16 val23, BFloat16 val24,
BFloat16 val25, BFloat16 val26, BFloat16 val27, BFloat16 val28,
BFloat16 val29, BFloat16 val30, BFloat16 val31, BFloat16 val32) {
values = _mm512_set_epi16(
val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
}
operator __m512i() const {
return values;
}
BFloat16& operator[](int idx) = delete;
const BFloat16& operator[](int idx) const = delete;
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
}
static Vectorized<BFloat16> loadu(const void* ptr) {
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
}
static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) {
__at_align__ int16_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
return loadu(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
} else if (count > 0) {
__at_align__ int16_t tmp_values[size()];
_mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
}
}
template <int64_t mask>
static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
__at_align__ int16_t tmp_values[size()];
a.store(tmp_values);
if (mask & 0x01)
tmp_values[0] = b.values[31];
if (mask & 0x02)
tmp_values[1] = b.values[30];
if (mask & 0x04)
tmp_values[2] = b.values[29];
if (mask & 0x08)
tmp_values[3] = b.values[28];
if (mask & 0x10)
tmp_values[4] = b.values[27];
if (mask & 0x20)
tmp_values[5] = b.values[26];
if (mask & 0x40)
tmp_values[6] = b.values[25];
if (mask & 0x80)
tmp_values[7] = b.values[24];
if (mask & 0x100)
tmp_values[8] = b.values[23];
if (mask & 0x200)
tmp_values[9] = b.values[22];
if (mask & 0x400)
tmp_values[10] = b.values[21];
if (mask & 0x800)
tmp_values[11] = b.values[20];
if (mask & 0x1000)
tmp_values[12] = b.values[19];
if (mask & 0x2000)
tmp_values[13] = b.values[18];
if (mask & 0x4000)
tmp_values[14] = b.values[17];
if (mask & 0x8000)
tmp_values[15] = b.values[16];
if (mask & 0x10000)
tmp_values[16] = b.values[15];
if (mask & 0x20000)
tmp_values[17] = b.values[14];
if (mask & 0x40000)
tmp_values[18] = b.values[13];
if (mask & 0x80000)
tmp_values[19] = b.values[12];
if (mask & 0x100000)
tmp_values[20] = b.values[11];
if (mask & 0x200000)
tmp_values[21] = b.values[10];
if (mask & 0x400000)
tmp_values[22] = b.values[9];
if (mask & 0x800000)
tmp_values[23] = b.values[8];
if (mask & 0x1000000)
tmp_values[24] = b.values[7];
if (mask & 0x2000000)
tmp_values[25] = b.values[6];
if (mask & 0x4000000)
tmp_values[26] = b.values[5];
if (mask & 0x8000000)
tmp_values[27] = b.values[4];
if (mask & 0x10000000)
tmp_values[28] = b.values[3];
if (mask & 0x20000000)
tmp_values[29] = b.values[2];
if (mask & 0x40000000)
tmp_values[30] = b.values[1];
if (mask & 0x80000000)
tmp_values[31] = b.values[0];
return loadu(tmp_values);
}
static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask) {
auto all_ones = _mm512_set1_epi16(0xFFFF);
auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_epi16(mask_, a.values, b.values);
}
template<typename step_t>
static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step = static_cast<step_t>(1)) {
return Vectorized<BFloat16>(
base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
}
static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& b, int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
case 8:
return blend<255>(a, b);
case 9:
return blend<511>(a, b);
case 10:
return blend<1023>(a, b);
case 11:
return blend<2047>(a, b);
case 12:
return blend<4095>(a, b);
case 13:
return blend<8191>(a, b);
case 14:
return blend<16383>(a, b);
case 15:
return blend<32767>(a, b);
case 16:
return blend<65535>(a, b);
case 17:
return blend<131071>(a, b);
case 18:
return blend<262143>(a, b);
case 19:
return blend<524287>(a, b);
case 20:
return blend<1048575>(a, b);
case 21:
return blend<2097151>(a, b);
case 22:
return blend<4194303>(a, b);
case 23:
return blend<8388607>(a, b);
case 24:
return blend<16777215>(a, b);
case 25:
return blend<33554431>(a, b);
case 26:
return blend<67108863>(a, b);
case 27:
return blend<134217727>(a, b);
case 28:
return blend<268435455>(a, b);
case 29:
return blend<536870911>(a, b);
case 30:
return blend<1073741823>(a, b);
case 31:
return blend<2147483647>(a, b);
}
return b;
}
Vectorized<BFloat16> map(const __m512 (*const vop)(__m512)) const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
const auto o1 = vop(lo);
const auto o2 = vop(hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> abs() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
const auto mask = _mm512_set1_ps(-0.f);
const auto o1 = _mm512_andnot_ps(mask, lo);
const auto o2 = _mm512_andnot_ps(mask, hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> angle() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto angle_lambda = [](__m512 values) {
const auto zero_vec = _mm512_set1_ps(0.f);
const auto nan_vec = _mm512_set1_ps(NAN);
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
not_nan_mask, 0xFFFFFFFF);
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec),
zero_vec, _CMP_EQ_OQ);
const auto pi = _mm512_set1_ps(c10::pi<float>);
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
return angle;
};
auto o1 = angle_lambda(lo);
auto o2 = angle_lambda(hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> real() const {
return *this;
}
Vectorized<BFloat16> imag() const {
return _mm512_set1_epi16(0);
}
Vectorized<BFloat16> conj() const {
return *this;
}
Vectorized<BFloat16> acos() const {
return map(Sleef_acosf16_u10);
}
Vectorized<BFloat16> asin() const {
return map(Sleef_asinf16_u10);
}
Vectorized<BFloat16> atan() const {
return map(Sleef_atanf16_u10);
}
Vectorized<BFloat16> atan2(const Vectorized<BFloat16> &b) const {
__m512 lo, hi;
__m512 b1, b2;
cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(b.values, b1, b2);
auto o1 = Sleef_atan2f16_u10(lo, b1);
auto o2 = Sleef_atan2f16_u10(hi, b2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> copysign(const Vectorized<BFloat16> &sign) const {
// copy sign bit (0x8000) from sign and remaining bits from values
__m512i mask_value = _mm512_set1_epi32(~0x80008000);
__m512i mask_signbit = _mm512_set1_epi32(0x80008000);
return Vectorized<BFloat16>(
_mm512_or_si512(
_mm512_and_si512(values, mask_value),
_mm512_and_si512(sign, mask_signbit)));
}
Vectorized<BFloat16> erf() const {
return map(Sleef_erff16_u10);
}
Vectorized<BFloat16> erfc() const {
return map(Sleef_erfcf16_u15);
}
Vectorized<BFloat16> erfinv() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) {
tmp1[i] = calc_erfinv(tmp1[i]);
tmp2[i] = calc_erfinv(tmp2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> exp() const {
return map(Sleef_expf16_u10);
}
Vectorized<BFloat16> expm1() const {
return map(Sleef_expm1f16_u10);
}
Vectorized<BFloat16> fmod(const Vectorized<BFloat16> & q) const {
__m512 x_lo, x_hi;
cvtbf16_fp32(values, x_lo, x_hi);
__m512 q_lo, q_hi;
cvtbf16_fp32(q.values, q_lo, q_hi);
auto o1 = Sleef_fmodf16(x_lo, q_lo);
auto o2 = Sleef_fmodf16(x_hi, q_hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> hypot(const Vectorized<BFloat16> &b) const {
__m512 lo, hi;
__m512 b1, b2;
cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(b.values, b1, b2);
auto o1 = Sleef_hypotf16_u05(lo, b1);
auto o2 = Sleef_hypotf16_u05(hi, b2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> i0() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) {
tmp1[i] = calc_i0(tmp1[i]);
tmp2[i] = calc_i0(tmp2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> i0e() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
constexpr auto sz = size();
__at_align__ float tmp1[sz / 2], tmp2[sz / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (auto i = decltype(sz){0}; i < sz / 2; i++) {
tmp1[i] = calc_i0e(tmp1[i]);
tmp2[i] = calc_i0e(tmp2[i]);
}
const auto o1 = _mm512_loadu_ps(tmp1);
const auto o2 = _mm512_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> igamma(const Vectorized<BFloat16> &x) const {
__m512 lo, hi;
__m512 xlo, xhi;
cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(x.values, xlo, xhi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) {
tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> igammac(const Vectorized<BFloat16> &x) const {
__m512 lo, hi;
__m512 xlo, xhi;
cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(x.values, xlo, xhi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) {
tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> log() const {
return map(Sleef_logf16_u10);
}
Vectorized<BFloat16> log2() const {
return map(Sleef_log2f16_u10);
}
Vectorized<BFloat16> log10() const {
return map(Sleef_log10f16_u10);
}
Vectorized<BFloat16> log1p() const {
return map(Sleef_log1pf16_u10);
}
Vectorized<BFloat16> frac() const;
Vectorized<BFloat16> sin() const {
return map(Sleef_sinf16_u10);
}
Vectorized<BFloat16> sinh() const {
return map(Sleef_sinhf16_u10);
}
Vectorized<BFloat16> cos() const {
return map(Sleef_cosf16_u10);
}
Vectorized<BFloat16> cosh() const {
return map(Sleef_coshf16_u10);
}
Vectorized<BFloat16> ceil() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto o1 = _mm512_ceil_ps(lo);
auto o2 = _mm512_ceil_ps(hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> floor() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto o1 = _mm512_floor_ps(lo);
auto o2 = _mm512_floor_ps(hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> neg() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto mask = _mm512_set1_ps(-0.f);
auto o1 = _mm512_xor_ps(mask, lo);
auto o2 = _mm512_xor_ps(mask, hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> round() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> tan() const {
return map(Sleef_tanf16_u10);
}
Vectorized<BFloat16> tanh() const {
return map(Sleef_tanhf16_u10);
}
Vectorized<BFloat16> trunc() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> lgamma() const {
return map(Sleef_lgammaf16_u10);
}
Vectorized<BFloat16> sqrt() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto o1 = _mm512_sqrt_ps(lo);
auto o2 = _mm512_sqrt_ps(hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> reciprocal() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto ones = _mm512_set1_ps(1);
auto o1 = _mm512_div_ps(ones, lo);
auto o2 = _mm512_div_ps(ones, hi);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> rsqrt() const {
__m512 lo, hi;
cvtbf16_fp32(values, lo, hi);
auto ones = _mm512_set1_ps(1);
auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> pow(const Vectorized<BFloat16> &b) const {
__m512 lo, hi;
__m512 b1, b2;
cvtbf16_fp32(values, lo, hi);
cvtbf16_fp32(b.values, b1, b2);
auto o1 = Sleef_powf16_u10(lo, b1);
auto o2 = Sleef_powf16_u10(hi, b2);
return cvtfp32_bf16(o1, o2);
}
Vectorized<BFloat16> inline operator>(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> inline operator<(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> inline operator>=(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> inline operator<=(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> inline operator==(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> inline operator!=(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};
template<typename Op>
Vectorized<BFloat16> static inline bfloat16_binary_op_as_fp32(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& b, Op op) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
auto o1 = op(a_lo, b_lo);
auto o2 = op(a_hi, b_hi);
return cvtfp32_bf16(o1, o2);
}
template<typename Op>
Vectorized<BFloat16> static inline bfloat16_compare_as_fp32(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& b, Op op) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
auto o1 = op(a_lo, b_lo);
auto o2 = op(a_hi, b_hi);
return merge_compare_result(o1, o2);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
}
Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
}
Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
}
Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
}
Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return _mm512_and_si512(a, b);
}
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return _mm512_or_si512(a, b);
}
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return _mm512_xor_si512(a, b);
}
Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
return (*this == other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
return (*this != other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
return (*this > other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
return (*this >= other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
return (*this < other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
return (*this <= other) & Vectorized<BFloat16>(1.0f);
}
// frac. Implement this here so we can use subtraction
Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
auto max_lo = _mm512_max_ps(a_lo, b_lo);
auto max_hi = _mm512_max_ps(a_hi, b_hi);
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask));
auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask));
// Exploit the fact that all-ones is a NaN.
auto o1 = _mm512_or_ps(max_lo, nan_lo);
auto o2 = _mm512_or_ps(max_hi, nan_hi);
return cvtfp32_bf16(o1, o2);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
__m512i zero_vec = _mm512_set1_epi32(0);
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
auto min_lo = _mm512_min_ps(a_lo, b_lo);
auto min_hi = _mm512_min_ps(a_hi, b_hi);
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
0xFFFFFFFF));
auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
auto o1 = _mm512_or_ps(min_lo, nan_lo);
auto o2 = _mm512_or_ps(min_hi, nan_hi);
return cvtfp32_bf16(o1, o2);
}
template <>
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
__m512 a_lo, a_hi;
__m512 min_lo, min_hi;
__m512 max_lo, max_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(min), min_lo, min_hi);
cvtbf16_fp32(__m512i(max), max_lo, max_hi);
auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
return cvtfp32_bf16(o1, o2);
}
template <>
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
__m512 a_lo, a_hi;
__m512 max_lo, max_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(max), max_lo, max_hi);
auto o1 = _mm512_min_ps(max_lo, a_lo);
auto o2 = _mm512_min_ps(max_hi, a_hi);
return cvtfp32_bf16(o1, o2);
}
template <>
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
__m512 a_lo, a_hi;
__m512 min_lo, min_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(min), min_lo, min_hi);
auto o1 = _mm512_max_ps(min_lo, a_lo);
auto o2 = _mm512_max_ps(min_hi, a_hi);
return cvtfp32_bf16(o1, o2);
}
template <>
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
int64_t i;
#pragma unroll
for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
_mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
}
#pragma unroll
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
__m512 c_lo, c_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
cvtbf16_fp32(__m512i(c), c_lo, c_hi);
auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
return cvtfp32_bf16(o1, o2);
}
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
__m512 o1, o2;
cvtbf16_fp32(__m512i(a), o1, o2);
return std::make_tuple(o1, o2);
}
inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
return cvtfp32_bf16(__m512(a), __m512(b));
}
#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
constexpr int64_t K = Vectorized<BFloat16>::size();
__at_align__ float arr[K];
__at_align__ BFloat16 arr2[K];
a.store(arr2);
convert(arr2, arr, K);
return std::make_tuple(
Vectorized<float>::loadu(arr),
Vectorized<float>::loadu(arr + Vectorized<float>::size()));
}
inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
constexpr int64_t K = Vectorized<BFloat16>::size();
__at_align__ float arr[K];
__at_align__ BFloat16 arr2[K];
a.store(arr);
b.store(arr + Vectorized<float>::size());
convert(arr, arr2, K);
return Vectorized<BFloat16>::loadu(arr2);
}
#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data));
__m512 out_values;
cvtbf16_fp32(values, out_values);
out = out_values;
}
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
auto vec = Vectorized<c10::BFloat16>::loadu(data);
__m512 out1_values, out2_values;
cvtbf16_fp32(vec, out1_values, out2_values);
out1 = out1_values;
out2 = out2_values;
}
#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
for (int k = 0; k < Vectorized<float>::size(); ++k) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
}
void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
load_fp32_from_bf16(data, out1);
data += Vectorized<float>::size();
load_fp32_from_bf16(data, out2);
}
#endif
}}}

View File

@ -0,0 +1,526 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <c10/util/complex.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
template <> class Vectorized<c10::complex<double>> {
private:
__m512d values;
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
public:
using value_type = c10::complex<double>;
using size_type = int;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
Vectorized(__m512d v) : values(v) {}
Vectorized(c10::complex<double> val) {
double real_value = val.real();
double imag_value = val.imag();
values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value,
real_value, imag_value, real_value, imag_value);
}
Vectorized(c10::complex<double> val1, c10::complex<double> val2,
c10::complex<double> val3, c10::complex<double> val4) {
values = _mm512_setr_pd(val1.real(), val1.imag(),
val2.real(), val2.imag(),
val3.real(), val3.imag(),
val4.real(), val4.imag());
}
operator __m512d() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
// NOLINTNEXTLINE(clang-diagnostic-warning)
switch (mask) {
case 0:
return a;
case 1:
return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011
case 2:
return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100
case 3:
return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111
case 4:
return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000
case 5:
return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011
case 6:
return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100
case 7:
return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111
case 8:
return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000
case 9:
return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011
case 10:
return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100
case 11:
return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111
case 12:
return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000
case 13:
return _mm512_mask_blend_pd(0xF3, a.values, b.values); //b0000 1101 = b1111 0011
case 14:
return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100
case 15:
return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111
}
return b;
}
static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b,
const Vectorized<c10::complex<double>>& mask) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values);
auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_pd(mmask, a.values, b.values);
}
template<typename step_t>
static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0.,
step_t step = static_cast<step_t>(1)) {
return Vectorized<c10::complex<double>>(base,
base + c10::complex<double>(1)*step,
base + c10::complex<double>(2)*step,
base + c10::complex<double>(3)*step);
}
static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align__ double tmp_values[2*size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (auto i = 0; i < 2*size(); ++i) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const double*>(ptr),
count * sizeof(c10::complex<double>));
return _mm512_load_pd(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
double tmp_values[2*size()];
_mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
}
}
const c10::complex<double>& operator[](int idx) const = delete;
c10::complex<double>& operator[](int idx) = delete;
Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
__at_align__ c10::complex<double> tmp[size()];
store(tmp);
for (int i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
// AVX512 doesn't have horizontal add & horizontal sub instructions.
// TODO: hadd_pd() & hsub_pd() may have scope for improvement.
static inline __m512d hadd_pd(__m512d a, __m512d b) {
__m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
__m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
static inline __m512d hsub_pd(__m512d a, __m512d b) {
__m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
__m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
__m512d abs_2_() const {
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
}
__m512d abs_() const {
return _mm512_sqrt_pd(abs_2_()); // abs abs
}
Vectorized<c10::complex<double>> abs() const {
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm512_and_pd(abs_(), real_mask); // abs 0
}
__m512d angle_() const {
//angle = atan2(b/a)
auto b_a = _mm512_permute_pd(values, 0x55); // b a
return Sleef_atan2d8_u10(values, b_a); // 90-angle angle
}
Vectorized<c10::complex<double>> angle() const {
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle
return _mm512_and_pd(angle, real_mask); // angle 0
}
Vectorized<c10::complex<double>> sgn() const {
auto abs = abs_();
auto zero = _mm512_setzero_pd();
auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
auto mask_vec = _mm512_mask_set1_epi64(_mm512_castpd_si512(zero), mask,
0xFFFFFFFFFFFFFFFF);
auto abs_val = Vectorized(abs);
auto div = values / abs_val.values; // x / abs(x)
return blendv(div, zero, _mm512_castsi512_pd(mask_vec));
}
__m512d real_() const {
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm512_and_pd(values, real_mask);
}
Vectorized<c10::complex<double>> real() const {
return real_();
}
__m512d imag_() const {
const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
return _mm512_and_pd(values, imag_mask);
}
Vectorized<c10::complex<double>> imag() const {
return _mm512_permute_pd(imag_(), 0x55); //b a
}
__m512d conj_() const {
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
return _mm512_xor_pd(values, sign_mask); // a -b
}
Vectorized<c10::complex<double>> conj() const {
return conj_();
}
Vectorized<c10::complex<double>> log() const {
// Most trigonomic ops use the log() op to improve complex number performance.
return map(std::log);
}
Vectorized<c10::complex<double>> log2() const {
const __m512d log2_ = _mm512_set1_pd(std::log(2));
return _mm512_div_pd(log(), log2_);
}
Vectorized<c10::complex<double>> log10() const {
const __m512d log10_ = _mm512_set1_pd(std::log(10));
return _mm512_div_pd(log(), log10_);
}
Vectorized<c10::complex<double>> log1p() const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> asin() const {
// asin(x)
// = -i*ln(iz + sqrt(1 -z^2))
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
const __m512d one = _mm512_set1_pd(1);
auto conj = conj_();
auto b_a = _mm512_permute_pd(conj, 0x55); //-b a
auto ab = _mm512_mul_pd(conj, b_a); //-ab -ab
auto im = _mm512_add_pd(ab, ab); //-2ab -2ab
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b b*b-a*a
re = _mm512_sub_pd(one, re);
auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); //sqrt(re + i*im)
auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); //ln(iz + sqrt())
return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj(); //-i*ln()
}
Vectorized<c10::complex<double>> acos() const {
// acos(x) = pi/2 - asin(x)
constexpr auto pi_2d = c10::pi<double> / 2;
const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0);
return _mm512_sub_pd(pi_2, asin());
}
Vectorized<c10::complex<double>> atan() const;
Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>> &b) const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> erf() const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> erfc() const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> exp() const {
//exp(a + bi)
// = exp(a)*(cos(b) + sin(b)i)
auto exp = Sleef_expd8_u10(values); //exp(a) exp(b)
exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) exp(a)
auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55),
sin_cos.x); //cos(b) sin(b)
return _mm512_mul_pd(exp, cos_sin);
}
Vectorized<c10::complex<double>> expm1() const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);
}
Vectorized<c10::complex<double>> sinh() const {
return map(std::sinh);
}
Vectorized<c10::complex<double>> cos() const {
return map(std::cos);
}
Vectorized<c10::complex<double>> cosh() const {
return map(std::cosh);
}
Vectorized<c10::complex<double>> ceil() const {
return _mm512_ceil_pd(values);
}
Vectorized<c10::complex<double>> floor() const {
return _mm512_floor_pd(values);
}
Vectorized<c10::complex<double>> hypot(const Vectorized<c10::complex<double>> &b) const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> igamma(const Vectorized<c10::complex<double>> &x) const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> igammac(const Vectorized<c10::complex<double>> &x) const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> neg() const {
auto zero = _mm512_setzero_pd();
return _mm512_sub_pd(zero, values);
}
Vectorized<c10::complex<double>> nextafter(const Vectorized<c10::complex<double>> &b) const {
AT_ERROR("not supported for complex numbers");
}
Vectorized<c10::complex<double>> round() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<double>> tan() const {
return map(std::tan);
}
Vectorized<c10::complex<double>> tanh() const {
return map(std::tanh);
}
Vectorized<c10::complex<double>> trunc() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<double>> sqrt() const {
return map(std::sqrt);
}
Vectorized<c10::complex<double>> reciprocal() const;
Vectorized<c10::complex<double>> rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
__at_align__ c10::complex<double> x_tmp[size()];
__at_align__ c10::complex<double> y_tmp[size()];
store(x_tmp);
exp.store(y_tmp);
for (int i = 0; i < size(); i++) {
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
}
return loadu(x_tmp);
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
Vectorized<c10::complex<double>> lt(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> le(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> gt(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> ge(const Vectorized<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
};
template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
return _mm512_add_pd(a, b);
}
template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
return _mm512_sub_pd(a, b);
}
template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
auto ac_bd = _mm512_mul_pd(a, b); //ac bd
auto d_c = _mm512_permute_pd(b, 0x55); //d c
d_c = _mm512_xor_pd(sign_mask, d_c); //d -c
auto ad_bc = _mm512_mul_pd(a, d_c); //ad -bc
auto ret = Vectorized<c10::complex<double>>::hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc
return ret;
}
template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
//re + im*i = (a + bi) / (c + di)
//re = (ac + bd)/abs_2()
//im = (bc - ad)/abs_2()
const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
auto ac_bd = _mm512_mul_pd(a, b); //ac bd
auto d_c = _mm512_permute_pd(b, 0x55); //d c
d_c = _mm512_xor_pd(sign_mask, d_c); //-d c
auto ad_bc = _mm512_mul_pd(a, d_c); //-ad bc
auto re_im = Vectorized<c10::complex<double>>::hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad
return _mm512_div_pd(re_im, b.abs_2_());
}
// reciprocal. Implement this here so we can use multiplication.
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
//re + im*i = (a + bi) / (c + di)
//re = (ac + bd)/abs_2() = c/abs_2()
//im = (bc - ad)/abs_2() = d/abs_2()
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
return _mm512_div_pd(c_d, abs_2_());
}
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
// atan(x) = i/2 * ln((i + z)/(i - z))
const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
auto sum = Vectorized(_mm512_add_pd(i, values)); // a 1+b
auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b
auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
return i_half*ln; // i/2*ln()
}
template <>
Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
auto zero_vec = _mm512_set1_epi64(0);
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ);
auto max = _mm512_mask_blend_pd(mask, a, b);
// Exploit the fact that all-ones is a NaN.
auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF);
return _mm512_or_pd(max, _mm512_castsi512_pd(isnan));
}
template <>
Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
auto zero_vec = _mm512_set1_epi64(0);
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ);
auto min = _mm512_mask_blend_pd(mask, a, b);
// Exploit the fact that all-ones is a NaN.
auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF);
return _mm512_or_pd(min, _mm512_castsi512_pd(isnan));
}
template <>
Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
return _mm512_and_pd(a, b);
}
template <>
Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
return _mm512_or_pd(a, b);
}
template <>
Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
return _mm512_xor_pd(a, b);
}
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
return (*this == other) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
}
Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
return (*this != other) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
}
#endif
}}}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,454 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
template <> class Vectorized<double> {
private:
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
public:
// values needs to be public for compilation with clang
// as vec512.h uses it
__m512d values;
using value_type = double;
using size_type = int;
static constexpr size_type size() {
return 8;
}
Vectorized() {}
Vectorized(__m512d v) : values(v) {}
Vectorized(double val) {
values = _mm512_set1_pd(val);
}
Vectorized(double val1, double val2, double val3, double val4,
double val5, double val6, double val7, double val8) {
values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8);
}
operator __m512d() const {
return values;
}
template <int64_t mask>
static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_mask_blend_pd(mask, a.values, b.values);
}
static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
const Vectorized<double>& mask) {
auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_pd(mmask, a.values, b.values);
}
template<typename step_t>
static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step,
base + 7 * step);
}
static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
}
return b;
}
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align__ double tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (auto i = 0; i < size(); ++i) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const double*>(ptr),
count * sizeof(double));
return _mm512_load_pd(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
double tmp_values[size()];
_mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(double));
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
__mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ);
return static_cast<int32_t>(cmp);
}
Vectorized<double> isnan() const {
auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> map(double (*const f)(double)) const {
__at_align__ double tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<double> abs() const {
auto mask = _mm512_set1_pd(-0.f);
return _mm512_andnot_pd(mask, values);
}
Vectorized<double> angle() const {
const auto zero_vec = _mm512_castsi512_pd(zero_vector);
const auto nan_vec = _mm512_set1_pd(NAN);
const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ);
const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask,
0xFFFFFFFFFFFFFFFF);
const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan),
zero_vec, _CMP_EQ_OQ);
const auto pi = _mm512_set1_pd(c10::pi<double>);
const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi);
angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec);
return angle;
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return _mm512_set1_pd(0);
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> acos() const {
return Vectorized<double>(Sleef_acosd8_u10(values));
}
Vectorized<double> asin() const {
return Vectorized<double>(Sleef_asind8_u10(values));
}
Vectorized<double> atan() const {
return Vectorized<double>(Sleef_atand8_u10(values));
}
Vectorized<double> atan2(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_atan2d8_u10(values, b));
}
Vectorized<double> copysign(const Vectorized<double> &sign) const {
return Vectorized<double>(Sleef_copysignd8(values, sign));
}
Vectorized<double> erf() const {
return Vectorized<double>(Sleef_erfd8_u10(values));
}
Vectorized<double> erfc() const {
return Vectorized<double>(Sleef_erfcd8_u15(values));
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> exp() const {
return Vectorized<double>(Sleef_expd8_u10(values));
}
Vectorized<double> expm1() const {
return Vectorized<double>(Sleef_expm1d8_u10(values));
}
Vectorized<double> fmod(const Vectorized<double>& q) const {
return Vectorized<double>(Sleef_fmodd8(values, q));
}
Vectorized<double> hypot(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_hypotd8_u05(values, b));
}
Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> igamma(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> igammac(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> log() const {
return Vectorized<double>(Sleef_logd8_u10(values));
}
Vectorized<double> log2() const {
return Vectorized<double>(Sleef_log2d8_u10(values));
}
Vectorized<double> log10() const {
return Vectorized<double>(Sleef_log10d8_u10(values));
}
Vectorized<double> log1p() const {
return Vectorized<double>(Sleef_log1pd8_u10(values));
}
Vectorized<double> sin() const {
return Vectorized<double>(Sleef_sind8_u10(values));
}
Vectorized<double> sinh() const {
return Vectorized<double>(Sleef_sinhd8_u10(values));
}
Vectorized<double> cos() const {
return Vectorized<double>(Sleef_cosd8_u10(values));
}
Vectorized<double> cosh() const {
return Vectorized<double>(Sleef_coshd8_u10(values));
}
Vectorized<double> ceil() const {
return _mm512_ceil_pd(values);
}
Vectorized<double> floor() const {
return _mm512_floor_pd(values);
}
Vectorized<double> frac() const;
Vectorized<double> neg() const {
return _mm512_xor_pd(_mm512_set1_pd(-0.), values);
}
Vectorized<double> nextafter(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_nextafterd8(values, b));
}
Vectorized<double> round() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<double> tan() const {
return Vectorized<double>(Sleef_tand8_u10(values));
}
Vectorized<double> tanh() const {
return Vectorized<double>(Sleef_tanhd8_u10(values));
}
Vectorized<double> trunc() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<double> lgamma() const {
return Vectorized<double>(Sleef_lgammad8_u10(values));
}
Vectorized<double> sqrt() const {
return _mm512_sqrt_pd(values);
}
Vectorized<double> reciprocal() const {
return _mm512_div_pd(_mm512_set1_pd(1), values);
}
Vectorized<double> rsqrt() const {
return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values));
}
Vectorized<double> pow(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_powd8_u10(values, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<double> operator==(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator!=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator<(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator<=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator>(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator>=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> eq(const Vectorized<double>& other) const;
Vectorized<double> ne(const Vectorized<double>& other) const;
Vectorized<double> lt(const Vectorized<double>& other) const;
Vectorized<double> le(const Vectorized<double>& other) const;
Vectorized<double> gt(const Vectorized<double>& other) const;
Vectorized<double> ge(const Vectorized<double>& other) const;
};
template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_add_pd(a, b);
}
template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_sub_pd(a, b);
}
template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_mul_pd(a, b);
}
template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_div_pd(a, b);
}
// frac. Implement this here so we can use subtraction.
Vectorized<double> Vectorized<double>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
auto zero_vec = _mm512_set1_epi64(0);
Vectorized<double> max = _mm512_max_pd(a, b);
auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_pd(max, isnan);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
auto zero_vec = _mm512_set1_epi64(0);
Vectorized<double> min = _mm512_min_pd(a, b);
auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_pd(min, isnan);
}
template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
return _mm512_min_pd(max, _mm512_max_pd(min, a));
}
template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
return _mm512_max_pd(min, a);
}
template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
return _mm512_min_pd(max, a);
}
template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_and_pd(a, b);
}
template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_or_pd(a, b);
}
template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_xor_pd(a, b);
}
Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
return (*this == other) & Vectorized<double>(1.0);
}
Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
return (*this != other) & Vectorized<double>(1.0);
}
Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
return (*this > other) & Vectorized<double>(1.0);
}
Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
return (*this >= other) & Vectorized<double>(1.0);
}
Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
return (*this < other) & Vectorized<double>(1.0);
}
Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
return (*this <= other) & Vectorized<double>(1.0);
}
template <>
inline void convert(const double* src, double* dst, int64_t n) {
int64_t i;
#pragma unroll
for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
_mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i));
}
#pragma unroll
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
return _mm512_fmadd_pd(a, b, c);
}
#endif
}}}

View File

@ -0,0 +1,469 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [Acceptable use of anonymous namespace in header]
namespace {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
template <> class Vectorized<float> {
private:
static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0};
public:
__m512 values;
using value_type = float;
using size_type = int;
static constexpr size_type size() {
return 16;
}
Vectorized() {}
Vectorized(__m512 v) : values(v) {}
Vectorized(float val) {
values = _mm512_set1_ps(val);
}
Vectorized(float val1, float val2, float val3, float val4,
float val5, float val6, float val7, float val8,
float val9, float val10, float val11, float val12,
float val13, float val14, float val15, float val16) {
values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8,
val9, val10, val11, val12, val13, val14, val15, val16);
}
operator __m512() const {
return values;
}
template <int64_t mask>
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_mask_blend_ps(mask, a.values, b.values);
}
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
const Vectorized<float>& mask) {
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_ps(mmask, a.values, b.values);
}
template<typename step_t>
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
return Vectorized<float>(
base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
}
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
case 8:
return blend<255>(a, b);
case 9:
return blend<511>(a, b);
case 10:
return blend<1023>(a, b);
case 11:
return blend<2047>(a, b);
case 12:
return blend<4095>(a, b);
case 13:
return blend<8191>(a, b);
case 14:
return blend<16383>(a, b);
case 15:
return blend<32767>(a, b);
}
return b;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm512_loadu_ps(reinterpret_cast<const float*>(ptr));
__at_align__ float tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (auto i = 0; i < size(); ++i) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
return _mm512_loadu_ps(tmp_values);
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
_mm512_storeu_ps(reinterpret_cast<float*>(ptr), values);
} else if (count > 0) {
float tmp_values[size()];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
__mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ);
return static_cast<int32_t>(cmp);
}
Vectorized<float> isnan() const {
auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<float> abs() const {
auto mask = _mm512_set1_ps(-0.f);
return _mm512_andnot_ps(mask, values);
}
Vectorized<float> angle() const {
__m512 zero_vec = _mm512_set1_ps(0.f);
const auto nan_vec = _mm512_set1_ps(NAN);
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
not_nan_mask, 0xFFFFFFFF);
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec),
zero_vec, _CMP_EQ_OQ);
const auto pi = _mm512_set1_ps(c10::pi<double>);
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
return angle;
}
Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
return _mm512_set1_ps(0);
}
Vectorized<float> conj() const {
return *this;
}
Vectorized<float> acos() const {
return Vectorized<float>(Sleef_acosf16_u10(values));
}
Vectorized<float> asin() const {
return Vectorized<float>(Sleef_asinf16_u10(values));
}
Vectorized<float> atan() const {
return Vectorized<float>(Sleef_atanf16_u10(values));
}
Vectorized<float> atan2(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_atan2f16_u10(values, b));
}
Vectorized<float> copysign(const Vectorized<float> &sign) const {
return Vectorized<float>(Sleef_copysignf16(values, sign));
}
Vectorized<float> erf() const {
return Vectorized<float>(Sleef_erff16_u10(values));
}
Vectorized<float> erfc() const {
return Vectorized<float>(Sleef_erfcf16_u15(values));
}
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return Vectorized<float>(Sleef_expf16_u10(values));
}
Vectorized<float> expm1() const {
return Vectorized<float>(Sleef_expm1f16_u10(values));
}
Vectorized<float> fmod(const Vectorized<float>& q) const {
return Vectorized<float>(Sleef_fmodf16(values, q));
}
Vectorized<float> log() const {
return Vectorized<float>(Sleef_logf16_u10(values));
}
Vectorized<float> log2() const {
return Vectorized<float>(Sleef_log2f16_u10(values));
}
Vectorized<float> log10() const {
return Vectorized<float>(Sleef_log10f16_u10(values));
}
Vectorized<float> log1p() const {
return Vectorized<float>(Sleef_log1pf16_u10(values));
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return Vectorized<float>(Sleef_sinf16_u10(values));
}
Vectorized<float> sinh() const {
return Vectorized<float>(Sleef_sinhf16_u10(values));
}
Vectorized<float> cos() const {
return Vectorized<float>(Sleef_cosf16_u10(values));
}
Vectorized<float> cosh() const {
return Vectorized<float>(Sleef_coshf16_u10(values));
}
Vectorized<float> ceil() const {
return _mm512_ceil_ps(values);
}
Vectorized<float> floor() const {
return _mm512_floor_ps(values);
}
Vectorized<float> hypot(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_hypotf16_u05(values, b));
}
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> neg() const {
return _mm512_xor_ps(_mm512_set1_ps(-0.f), values);
}
Vectorized<float> nextafter(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_nextafterf16(values, b));
}
Vectorized<float> round() const {
return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<float> tan() const {
return Vectorized<float>(Sleef_tanf16_u10(values));
}
Vectorized<float> tanh() const {
return Vectorized<float>(Sleef_tanhf16_u10(values));
}
Vectorized<float> trunc() const {
return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<float> lgamma() const {
return Vectorized<float>(Sleef_lgammaf16_u10(values));
}
Vectorized<float> sqrt() const {
return _mm512_sqrt_ps(values);
}
Vectorized<float> reciprocal() const {
return _mm512_div_ps(_mm512_set1_ps(1), values);
}
Vectorized<float> rsqrt() const {
return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values));
}
Vectorized<float> pow(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_powf16_u10(values, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<float> operator==(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
};
template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_add_ps(a, b);
}
template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_sub_ps(a, b);
}
template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_mul_ps(a, b);
}
template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_div_ps(a, b);
}
// frac. Implement this here so we can use subtraction
Vectorized<float> Vectorized<float>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
auto zero_vec = _mm512_set1_epi32(0);
auto max = _mm512_max_ps(a, b);
auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_ps(max, isnan);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
auto zero_vec = _mm512_set1_epi32(0);
auto min = _mm512_min_ps(a, b);
auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_ps(min, isnan);
}
template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return _mm512_min_ps(max, _mm512_max_ps(min, a));
}
template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return _mm512_min_ps(max, a);
}
template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return _mm512_max_ps(min, a);
}
template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_and_ps(a, b);
}
template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_or_ps(a, b);
}
template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_xor_ps(a, b);
}
Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}
Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}
Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}
Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}
Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}
Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, float* dst, int64_t n) {
int64_t i;
#pragma unroll
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
_mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i));
}
#pragma unroll
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return _mm512_fmadd_ps(a, b, c);
}
#endif
}}}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -20,7 +20,7 @@
#include <type_traits> #include <type_traits>
#include <bitset> #include <bitset>
#include <ATen/cpu/vec/vec256/intrinsics.h> #include <ATen/cpu/vec/intrinsics.h>
#include <ATen/native/Math.h> #include <ATen/native/Math.h>
#include <ATen/NumericUtils.h> #include <ATen/NumericUtils.h>
#include <c10/util/C++17.h> #include <c10/util/C++17.h>
@ -32,13 +32,28 @@
#include <c10/util/TypeCast.h> #include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
// These macros helped us unify vec_base.h
#ifdef CPU_CAPABILITY_AVX512
#if defined(__GNUC__) #if defined(__GNUC__)
#define __at_align32__ __attribute__((aligned(32))) #define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32) #elif defined(_WIN32)
#define __at_align32__ __declspec(align(32)) #define __at_align__ __declspec(align(64))
#else #else
#define __at_align32__ #define __at_align__
#endif #endif
#define VECTOR_WIDTH 64
#define int_vector __m512i
#else // CPU_CAPABILITY_AVX512
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 32
#define int_vector __m256i
#endif // CPU_CAPABILITY_AVX512
namespace at { namespace at {
namespace vec { namespace vec {
@ -70,11 +85,11 @@ using int_same_size_t = typename int_of_size<sizeof(T)>::type;
// NOTE: If you specialize on a type, you must define all operations! // NOTE: If you specialize on a type, you must define all operations!
// emulates vectorized types // emulates Vectorized types
template <class T> template <class T>
struct Vectorized { struct Vectorized {
private: private:
__at_align32__ T values[32 / sizeof(T)]; __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
public: public:
using value_type = T; using value_type = T;
using size_type = int; using size_type = int;
@ -111,7 +126,7 @@ public:
// identifier is odr-used or not, and in any case it's hard to tell if // identifier is odr-used or not, and in any case it's hard to tell if
// a variable is odr-used or not. So best to just cut the problem at the root. // a variable is odr-used or not. So best to just cut the problem at the root.
static constexpr size_type size() { static constexpr size_type size() {
return 32 / sizeof(T); return VECTOR_WIDTH / sizeof(T);
} }
Vectorized() : values{0} {} Vectorized() : values{0} {}
Vectorized(T val) { Vectorized(T val) {
@ -134,60 +149,60 @@ public:
template <int64_t mask_> template <int64_t mask_>
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) { static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
int64_t mask = mask_; int64_t mask = mask_;
Vectorized vec; Vectorized vector;
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
if (mask & 0x01) { if (mask & 0x01) {
vec[i] = b[i]; vector[i] = b[i];
} else { } else {
vec[i] = a[i]; vector[i] = a[i];
} }
mask = mask >> 1; mask = mask >> 1;
} }
return vec; return vector;
} }
static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b, static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
const Vectorized<T>& mask) { const Vectorized<T>& mask) {
Vectorized vec; Vectorized vector;
int_same_size_t<T> buffer[size()]; int_same_size_t<T> buffer[size()];
mask.store(buffer); mask.store(buffer);
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
if (buffer[i] & 0x01) if (buffer[i] & 0x01)
{ {
vec[i] = b[i]; vector[i] = b[i];
} else { } else {
vec[i] = a[i]; vector[i] = a[i];
} }
} }
return vec; return vector;
} }
template<typename step_t> // step sometimes requires a higher precision type (e.g., T=int, step_t=double) template<typename step_t> // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) { static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
Vectorized vec; Vectorized vector;
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
vec.values[i] = base + i * step; vector.values[i] = base + i * step;
} }
return vec; return vector;
} }
static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) { static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
Vectorized vec; Vectorized vector;
for (int64_t i = 0; i < size(); i++) { for (int64_t i = 0; i < size(); i++) {
if (i < count) { if (i < count) {
vec[i] = b[i]; vector[i] = b[i];
} else { } else {
vec[i] = a[i]; vector[i] = a[i];
} }
} }
return vec; return vector;
} }
static Vectorized<T> loadu(const void* ptr) { static Vectorized<T> loadu(const void* ptr) {
Vectorized vec; Vectorized vector;
std::memcpy(vec.values, ptr, 32); std::memcpy(vector.values, ptr, VECTOR_WIDTH);
return vec; return vector;
} }
static Vectorized<T> loadu(const void* ptr, int64_t count) { static Vectorized<T> loadu(const void* ptr, int64_t count) {
Vectorized vec; Vectorized vector;
std::memcpy(vec.values, ptr, count * sizeof(T)); std::memcpy(vector.values, ptr, count * sizeof(T));
return vec; return vector;
} }
void store(void* ptr, int count = size()) const { void store(void* ptr, int count = size()) const {
std::memcpy(ptr, values, count * sizeof(T)); std::memcpy(ptr, values, count * sizeof(T));
@ -203,15 +218,15 @@ public:
return mask; return mask;
} }
Vectorized<T> isnan() const { Vectorized<T> isnan() const {
Vectorized<T> vec; Vectorized<T> vector;
for (int64_t i = 0; i != size(); i++) { for (int64_t i = 0; i != size(); i++) {
if (_isnan(values[i])) { if (_isnan(values[i])) {
std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T)); std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
} else { } else {
std::memset(static_cast<void*>(vec.values + i), 0, sizeof(T)); std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
} }
} }
return vec; return vector;
} }
Vectorized<T> map(T (*const f)(T)) const { Vectorized<T> map(T (*const f)(T)) const {
Vectorized<T> ret; Vectorized<T> ret;
@ -488,15 +503,15 @@ private:
template <typename Op> template <typename Op>
inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const { inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
// All bits are set to 1 if the pred is true, otherwise 0. // All bits are set to 1 if the pred is true, otherwise 0.
Vectorized<T> vec; Vectorized<T> vector;
for (int64_t i = 0; i != size(); i++) { for (int64_t i = 0; i != size(); i++) {
if (op(values[i], other.values[i])) { if (op(values[i], other.values[i])) {
std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T)); std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
} else { } else {
std::memset(static_cast<void*>(vec.values + i), 0, sizeof(T)); std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
} }
} }
return vec; return vector;
} }
public: public:
@ -511,11 +526,11 @@ private:
template <typename Op> template <typename Op>
inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const { inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
// 1 if the pred is true, otherwise 0. // 1 if the pred is true, otherwise 0.
Vectorized<T> vec; Vectorized<T> vector;
for (int i = 0; i != size(); ++ i) { for (int i = 0; i != size(); ++ i) {
vec[i] = bool(op(values[i], other.values[i])); vector[i] = bool(op(values[i], other.values[i]));
} }
return vec; return vector;
} }
public: public:
@ -668,41 +683,62 @@ Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_
struct Vectorizedi; struct Vectorizedi;
#ifdef CPU_CAPABILITY_AVX2 #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
template <class T, typename Op> template <class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) { static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
__m256i buffer; int_vector buffer;
__m256i a_buffer = _mm256_loadu_si256(reinterpret_cast<const __m256i*>((const T*)a)); #if defined(CPU_CAPABILITY_AVX2)
__m256i b_buffer = _mm256_loadu_si256(reinterpret_cast<const __m256i*>((const T*)b)); int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
#elif defined(CPU_CAPABILITY_AVX512)
int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
#endif
buffer = op(a_buffer, b_buffer); buffer = op(a_buffer, b_buffer);
__at_align32__ T results[Vectorized<T>::size()]; __at_align__ T results[Vectorized<T>::size()];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(results), buffer);
#if defined(CPU_CAPABILITY_AVX2)
_mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
#elif defined(CPU_CAPABILITY_AVX512)
_mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
#endif
return Vectorized<T>::loadu(results); return Vectorized<T>::loadu(results);
} }
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) { inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
// We enclose _mm256_and_si256 with lambda because it is always_inline // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_and_si256(a, b); }); #if defined(CPU_CAPABILITY_AVX2)
return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
#endif
} }
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) { inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
// We enclose _mm256_or_si256 with lambda because it is always_inline // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_or_si256(a, b); }); #if defined(CPU_CAPABILITY_AVX2)
return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
#endif
} }
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) { inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
// We enclose _mm256_xor_si256 with lambda because it is always_inline // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }); #if defined(CPU_CAPABILITY_AVX2)
return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
#endif
} }
#else #else
template<class T, typename Op> template<class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) { static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
static constexpr uint32_t element_no = 32 / sizeof(intmax_t); static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
__at_align32__ intmax_t buffer[element_no]; __at_align__ intmax_t buffer[element_no];
const intmax_t *a_ptr = reinterpret_cast<const intmax_t*>((const T*) a); const intmax_t *a_ptr = reinterpret_cast<const intmax_t*>((const T*) a);
const intmax_t *b_ptr = reinterpret_cast<const intmax_t*>((const T*) b); const intmax_t *b_ptr = reinterpret_cast<const intmax_t*>((const T*) b);
for (uint32_t i = 0U; i < element_no; ++ i) { for (uint32_t i = 0U; i < element_no; ++ i) {
@ -724,12 +760,12 @@ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
return bitwise_binary_op(a, b, std::bit_xor<intmax_t>()); return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
} }
#endif #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) { inline Vectorized<T> operator~(const Vectorized<T>& a) {
Vectorized<T> ones; // All bits are 1 Vectorized<T> ones; // All bits are 1
memset((T*) ones, 0xFF, 32); memset((T*) ones, 0xFF, VECTOR_WIDTH);
return a ^ ones; return a ^ ones;
} }
@ -802,7 +838,9 @@ inline mask_gather(const Vectorized<T>& src, T const* base_addr,
} }
// Cast a given vector to another type without changing the bits representation. // Cast a given vector to another type without changing the bits representation.
// So a Vec<double> of 256 bits containing all ones can be cast to a // So a Vectorized<double> of 512 bits containing all ones can be cast to a
// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
// A Vec<double> of 256 bits containing all ones can be cast to a
// Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s). // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
namespace { namespace {
// There is a struct here because we don't have static_if and I can't // There is a struct here because we don't have static_if and I can't
@ -840,10 +878,16 @@ inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectoriz
return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer)); return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
} }
// E.g., inputs: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3} // Example inputs for AVX512:
// b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7} // a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7} // b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7} // returns:
// Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
// Example inputs for AVX2: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
// b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
// returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
template <typename T> template <typename T>
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>> inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) { deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
@ -866,8 +910,14 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
} }
// inverse operation of deinterleave2 // inverse operation of deinterleave2
// E.g., inputs: a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7} // Example inputs for AVX512:
// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7} // a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
// returns, for AVX512:
// Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// Example inputs for AVX2 : a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
// returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3} // returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
// Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7} // Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
template <typename T> template <typename T>

View File

@ -35,21 +35,8 @@
#include <mkl.h> #include <mkl.h>
#endif #endif
// [Note SSE-AVX transitions]
// There is a bug in Glibc2.23
// https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall
// when using AVX/AVX2 code resolves this.
#if defined(CPU_CAPABILITY_AVX) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23
#define DL_RUNTIME_BUG(op, type_) \
using value_t = typename c10::scalar_value_type<type_>::type;\
volatile value_t x = (value_t)(1); \
x = std::op(x); \
_mm256_zeroall();
#define DL_RUNTIME_BUG_BFLOAT16() _mm256_zeroall();
#else
#define DL_RUNTIME_BUG(op, type_) #define DL_RUNTIME_BUG(op, type_)
#define DL_RUNTIME_BUG_BFLOAT16() #define DL_RUNTIME_BUG_BFLOAT16()
#endif
namespace at { namespace at {
namespace vml { namespace vml {
@ -117,36 +104,36 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
}); \ }); \
} }
IMPLEMENT_VML_BUG(abs) IMPLEMENT_VML(abs)
IMPLEMENT_VML_BUG(acos) IMPLEMENT_VML(acos)
IMPLEMENT_VML_BUG(asin) IMPLEMENT_VML(asin)
IMPLEMENT_VML_BUG(atan) IMPLEMENT_VML(atan)
IMPLEMENT_VML_BUG(ceil) IMPLEMENT_VML(ceil)
IMPLEMENT_VML_BUG(cos) IMPLEMENT_VML(cos)
// IMPLEMENT_VML_BUG(cosh) // IMPLEMENT_VML_BUG(cosh)
IMPLEMENT_VML_BUG(erf) IMPLEMENT_VML(erf)
IMPLEMENT_VML_BUG(erfc) IMPLEMENT_VML(erfc)
IMPLEMENT_VML(erfinv) IMPLEMENT_VML(erfinv)
IMPLEMENT_VML_BUG(exp) IMPLEMENT_VML(exp)
IMPLEMENT_VML_BUG(expm1) IMPLEMENT_VML(expm1)
IMPLEMENT_VML_BUG(floor) IMPLEMENT_VML(floor)
IMPLEMENT_VML(i0) IMPLEMENT_VML(i0)
IMPLEMENT_VML(i0e) IMPLEMENT_VML(i0e)
IMPLEMENT_VML(reciprocal) IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML_BUG(log) IMPLEMENT_VML(log)
IMPLEMENT_VML_BUG(log10) IMPLEMENT_VML(log10)
IMPLEMENT_VML_BUG(log1p) IMPLEMENT_VML(log1p)
IMPLEMENT_VML_BUG(log2) IMPLEMENT_VML(log2)
IMPLEMENT_VML(neg) IMPLEMENT_VML(neg)
IMPLEMENT_VML_BUG(sin) IMPLEMENT_VML(sin)
// IMPLEMENT_VML_BUG(sinh) // IMPLEMENT_VML_BUG(sinh)
IMPLEMENT_VML_BUG(sqrt) IMPLEMENT_VML(sqrt)
IMPLEMENT_VML_BUG(round) IMPLEMENT_VML(round)
IMPLEMENT_VML(rsqrt) IMPLEMENT_VML(rsqrt)
IMPLEMENT_VML_BUG(tan) IMPLEMENT_VML(tan)
IMPLEMENT_VML_BUG(tanh) IMPLEMENT_VML(tanh)
IMPLEMENT_VML_BUG(trunc) IMPLEMENT_VML(trunc)
IMPLEMENT_VML_BUG(lgamma) IMPLEMENT_VML(lgamma)
#if AT_MKL_ENABLED() && !defined(__APPLE__) #if AT_MKL_ENABLED() && !defined(__APPLE__)

View File

@ -952,91 +952,109 @@ void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel); REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl); REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl); REGISTER_AVX512_DISPATCH(eig_stub, &eig_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl); REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel); REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel); REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel); REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel); REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel); REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel); REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
REGISTER_AVX_DISPATCH(lu_stub, &lu_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel); REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel); REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel); REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
REGISTER_AVX_DISPATCH(lu_solve_stub, &lu_solve_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
}} // namespace at::native }} // namespace at::native

View File

@ -16,12 +16,12 @@ static CPUCapability compute_cpu_capability() {
return CPUCapability::VSX; return CPUCapability::VSX;
} }
#else #else
if (strcmp(envar, "avx512") == 0) {
return CPUCapability::AVX512;
}
if (strcmp(envar, "avx2") == 0) { if (strcmp(envar, "avx2") == 0) {
return CPUCapability::AVX2; return CPUCapability::AVX2;
} }
if (strcmp(envar, "avx") == 0) {
return CPUCapability::AVX;
}
#endif #endif
if (strcmp(envar, "default") == 0) { if (strcmp(envar, "default") == 0) {
return CPUCapability::DEFAULT; return CPUCapability::DEFAULT;
@ -31,12 +31,13 @@ static CPUCapability compute_cpu_capability() {
#if !defined(__powerpc__) && !defined(__s390x__) #if !defined(__powerpc__) && !defined(__s390x__)
if (cpuinfo_initialize()) { if (cpuinfo_initialize()) {
if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && \
cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) {
return CPUCapability::AVX512;
}
if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) { if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
return CPUCapability::AVX2; return CPUCapability::AVX2;
} }
if (cpuinfo_has_x86_avx()) {
return CPUCapability::AVX;
}
} }
#endif #endif
#ifdef HAVE_VSX_CPU_DEFINITION #ifdef HAVE_VSX_CPU_DEFINITION
@ -54,8 +55,8 @@ CPUCapability get_cpu_capability() {
void* DispatchStubImpl::get_call_ptr( void* DispatchStubImpl::get_call_ptr(
DeviceType device_type DeviceType device_type
, void *DEFAULT , void *DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX , void *AVX512
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2 , void *AVX2
@ -72,8 +73,8 @@ void* DispatchStubImpl::get_call_ptr(
if (!fptr) { if (!fptr) {
fptr = choose_cpu_impl( fptr = choose_cpu_impl(
DEFAULT DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
, AVX , AVX512
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
, AVX2 , AVX2
@ -102,8 +103,8 @@ void* DispatchStubImpl::get_call_ptr(
void* DispatchStubImpl::choose_cpu_impl( void* DispatchStubImpl::choose_cpu_impl(
void *DEFAULT void *DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX , void *AVX512
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2 , void *AVX2
@ -114,18 +115,26 @@ void* DispatchStubImpl::choose_cpu_impl(
) { ) {
auto capability = static_cast<int>(get_cpu_capability()); auto capability = static_cast<int>(get_cpu_capability());
(void)capability; (void)capability;
#ifdef HAVE_AVX512_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX512)) {
// Quantization kernels have also been disabled on Windows
// for AVX512 because some of their tests are flaky on Windows.
// Ideally, we should have AVX512 kernels for all kernels.
if (C10_UNLIKELY(!AVX512)) {
// dispatch to AVX2, since the AVX512 kernel is missing
TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
return AVX2;
} else {
return AVX512;
}
}
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX2)) { if (capability >= static_cast<int>(CPUCapability::AVX2)) {
TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
return AVX2; return AVX2;
} }
#endif #endif
#ifdef HAVE_AVX_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX)) {
TORCH_INTERNAL_ASSERT(AVX, "DispatchStub: missing AVX kernel");
return AVX;
}
#endif
#ifdef HAVE_VSX_CPU_DEFINITION #ifdef HAVE_VSX_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::VSX)) { if (capability >= static_cast<int>(CPUCapability::VSX)) {
TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel"); TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel");

View File

@ -9,8 +9,8 @@
// Implements instruction set specific function dispatch. // Implements instruction set specific function dispatch.
// //
// Kernels that may make use of specialized instruction sets (e.g. AVX) are // Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx). A // compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime, // DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by // the fastest available kernel is chosen based on the features reported by
// cpuinfo. // cpuinfo.
@ -50,8 +50,8 @@ enum class CPUCapability {
#ifdef HAVE_VSX_CPU_DEFINITION #ifdef HAVE_VSX_CPU_DEFINITION
VSX = 1, VSX = 1,
#else #else
AVX = 1, AVX2 = 1,
AVX2 = 2, AVX512 = 2,
#endif #endif
NUM_OPTIONS NUM_OPTIONS
}; };
@ -71,8 +71,8 @@ struct TORCH_API DispatchStubImpl {
void* get_call_ptr( void* get_call_ptr(
DeviceType device_type DeviceType device_type
, void *DEFAULT , void *DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX , void *AVX512
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2 , void *AVX2
@ -89,8 +89,8 @@ struct TORCH_API DispatchStubImpl {
*/ */
void* choose_cpu_impl( void* choose_cpu_impl(
void *DEFAULT void *DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX , void *AVX512
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2 , void *AVX2
@ -126,8 +126,8 @@ private:
return reinterpret_cast<FnPtr>( return reinterpret_cast<FnPtr>(
impl.get_call_ptr(device_type impl.get_call_ptr(device_type
, reinterpret_cast<void*>(DEFAULT) , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
, reinterpret_cast<void*>(AVX) , reinterpret_cast<void*>(AVX512)
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
, reinterpret_cast<void*>(AVX2) , reinterpret_cast<void*>(AVX2)
@ -155,8 +155,8 @@ public:
} }
static FnPtr DEFAULT; static FnPtr DEFAULT;
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
static FnPtr AVX; static FnPtr AVX512;
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
static FnPtr AVX2; static FnPtr AVX2;
@ -203,10 +203,10 @@ struct RegisterHIPDispatch {
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \ #define REGISTER_ARCH_DISPATCH(name, arch, fn) \
template <> decltype(fn) DispatchStub<decltype(fn), struct name>::arch = fn; template <> decltype(fn) DispatchStub<decltype(fn), struct name>::arch = fn;
#ifdef HAVE_AVX_CPU_DEFINITION #ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX, fn) #define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else #else
#define REGISTER_AVX_DISPATCH(name, fn) #define REGISTER_AVX512_DISPATCH(name, fn)
#endif #endif
#ifdef HAVE_AVX2_CPU_DEFINITION #ifdef HAVE_AVX2_CPU_DEFINITION
@ -223,8 +223,8 @@ struct RegisterHIPDispatch {
#define REGISTER_NO_CPU_DISPATCH(name, fn_type) \ #define REGISTER_NO_CPU_DISPATCH(name, fn_type) \
REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast<fn_type>(nullptr)) \ REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast<fn_type>(nullptr)) \
REGISTER_AVX_DISPATCH(name, static_cast<fn_type>(nullptr)) \ REGISTER_AVX512_DISPATCH(name, static_cast<fn_type>(nullptr)) \
REGISTER_AVX2_DISPATCH(name, static_cast<fn_type>(nullptr)) \ REGISTER_AVX2_DISPATCH(name, static_cast<fn_type>(nullptr)) \
REGISTER_VSX_DISPATCH(name, static_cast<fn_type>(nullptr)) REGISTER_VSX_DISPATCH(name, static_cast<fn_type>(nullptr))
#define REGISTER_CUDA_DISPATCH(name, fn) \ #define REGISTER_CUDA_DISPATCH(name, fn) \
@ -244,6 +244,8 @@ struct RegisterHIPDispatch {
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn) // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY) #elif defined(CPU_CAPABILITY)
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define REGISTER_NO_AVX512_DISPATCH(name, fn_type) \
REGISTER_AVX512_DISPATCH(name, static_cast<fn_type>(nullptr))
#endif #endif

View File

@ -275,10 +275,10 @@ REGISTER_ARCH_DISPATCH(
DEFAULT, DEFAULT,
&_segment_reduce_cpu_kernel); &_segment_reduce_cpu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
// Currently some computation is being duplicated across forward and backward. // Currently some computation is being duplicated across forward and backward.
@ -319,7 +319,7 @@ REGISTER_ARCH_DISPATCH(
DEFAULT, DEFAULT,
&_segment_reduce_cpu_backward_kernel); &_segment_reduce_cpu_backward_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_AVX_DISPATCH( REGISTER_AVX512_DISPATCH(
_segment_reduce_backward_stub, _segment_reduce_backward_stub,
&_segment_reduce_cpu_backward_kernel); &_segment_reduce_cpu_backward_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -4,7 +4,7 @@ The most important things to know:
compiled multiple times for different instruction sets.** Yes, compiled multiple times for different instruction sets.** Yes,
this folder is named `cpu`, but that doesn't mean put any old this folder is named `cpu`, but that doesn't mean put any old
CPU kernel it. Only put CPU kernels which need to be compiled CPU kernel it. Only put CPU kernels which need to be compiled
multiple times to take advantage of AVX/SSE instructions, but multiple times to take advantage of AVX512/AVX2/SSE instructions, but
only on processors that support them. only on processors that support them.
**Ensure that all implementations in this folder are put in an **Ensure that all implementations in this folder are put in an
@ -52,14 +52,14 @@ All of the `*.cpp` files in this folder will be compiled under all compiler
flags specified by `CPU_CAPABILITY_FLAGS` in `aten/src/ATen/CMakeLists.txt`. flags specified by `CPU_CAPABILITY_FLAGS` in `aten/src/ATen/CMakeLists.txt`.
The purpose of this is to allow the compilation with various compiler The purpose of this is to allow the compilation with various compiler
flags to enable features such as AVX instructions, while using runtime flags to enable features such as AVX2 or AVX512 instructions, while using
dispatch, which makes sure only valid instructions will be used on any runtime dispatch, which makes sure only valid instructions will be used on any
given platform. given platform.
Vectorized.h provides a generic implementation of a vec type that allows vec.h provides a generic implementation of vec type that allows
the programmer to write code packing various primitives (such as floats) the programmer to write code packing various primitives (such as floats)
within 256bit registers. vec defines various operators such as + and * within 256bit & 512bits registers. vec defines various operators such as
and provides functions to allow operations such as max, min, etc. + and * and provides functions to allow operations such as max, min, etc.
As an example `ReduceOpsKernel.cpp` implements a generic `kernel_` that reduces As an example `ReduceOpsKernel.cpp` implements a generic `kernel_` that reduces
an entire array using a given associative binary operation such as +. an entire array using a given associative binary operation such as +.
@ -74,5 +74,5 @@ generic code, which will be compiled under multipled compilation settings.
`../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains `../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains
a generic definition of `sumImplAll`. This function allows the user to reduce a generic definition of `sumImplAll`. This function allows the user to reduce
over a dimension or all dimensions. The appropiate capability is chosen at over a dimension or all dimensions. The appropiate capability is chosen at
runtime using cpuinfo. If the current platform has AVX, `sumImpl` will be set runtime using cpuinfo. If the current platform has AVX2, `sumImpl` will be set
to `sumImplAll<CPUCapability::AVX>`. to `sumImplAll<CPUCapability::AVX2>`.

View File

@ -32,7 +32,8 @@ static inline bool is_outer_reduction(const int64_t* strides) {
} }
template <typename func_t, typename vec_func_t> template <typename func_t, typename vec_func_t>
static inline void reduction128(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) { static inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
func_t op, vec_func_t vop, bool reduce) {
VEC_LOOP_HEADER(func_t, data) VEC_LOOP_HEADER(func_t, data)
const char* in1_ptr = data[1]; const char* in1_ptr = data[1];
Vec acc[4]; Vec acc[4];
@ -80,7 +81,7 @@ static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op,
int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
int64_t count = n / (4 * Vec::size()); int64_t count = n / (4 * Vec::size());
if (count > 0) { if (count > 0) {
reduction128(data, count, vector_stride, op, vop, /*reduce=*/true); vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
} }
char* ptrs[3] = { data[0], data[0], data[1] }; char* ptrs[3] = { data[0], data[0], data[1] };
int64_t strides[] = { 0, 0, sizeof(scalar_t) }; int64_t strides[] = { 0, 0, sizeof(scalar_t) };
@ -92,10 +93,14 @@ template <typename func_t, typename vec_func_t>
static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
VEC_LOOP_HEADER(func_t, data) VEC_LOOP_HEADER(func_t, data)
// reduce down each column of 4 * Vec::size() elements (128 bytes) // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
#if defined(CPU_CAPABILITY_AVX512)
int64_t outer_stride[2] = { 256, 256 };
#else
int64_t outer_stride[2] = { 128, 128 }; int64_t outer_stride[2] = { 128, 128 };
#endif
UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
reduction128(data, size0, inner_stride, op, vop, /*reduce=*/false); vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
}); });
// reduce down the remaining columns // reduce down the remaining columns

View File

@ -219,9 +219,15 @@ inline void _vec_softmax(
int64_t outer_stride = dim_size * dim_stride; int64_t outer_stride = dim_size * dim_stride;
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
int vectorized_step = Vec().size(); // Currently, we only support scalar_t with double or float32 int vectorized_step = Vec().size(); // Currently, we only support scalar_t with double or float32
TORCH_CHECK( #ifdef CPU_CAPABILITY_AVX512
TORCH_INTERNAL_ASSERT(
(vectorized_step == 16) || (vectorized_step == 8),
"vectorized_step must be 16 with dtype float or 8 with dtype double");
#else
TORCH_INTERNAL_ASSERT(
(vectorized_step == 8) || (vectorized_step == 4), (vectorized_step == 8) || (vectorized_step == 4),
"vectorized_step must be 8 with dtype float or 4 with dtype double"); "vectorized_step must be 8 with dtype float or 4 with dtype double");
#endif
parallel_for( parallel_for(
0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
int64_t idx = begin; int64_t idx = begin;

View File

@ -611,8 +611,15 @@ void nansum_kernel_impl(TensorIterator &iter) {
} // namespace (anonymous) } // namespace (anonymous)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(sum_stub, &sum_kernel_impl); REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
// nansum on Float16 has poor accuracy with AVX2, and more so with AVX512.
// So until it's fixed, it won't be dispatched with AVX512. GH issue 59415.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
#ifndef CPU_CAPABILITY_AVX512
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl); REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl);
#else
REGISTER_NO_AVX512_DISPATCH(nansum_stub, reduce_fn);
#endif
}} // namespace at::native }} // namespace at::native

View File

@ -715,8 +715,15 @@ REGISTER_DISPATCH(exponential_stub, &CPU_CAPABILITY::exponential_kernel);
REGISTER_DISPATCH(geometric_stub, &CPU_CAPABILITY::geometric_kernel); REGISTER_DISPATCH(geometric_stub, &CPU_CAPABILITY::geometric_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(log_normal_stub, &CPU_CAPABILITY::log_normal_kernel); REGISTER_DISPATCH(log_normal_stub, &CPU_CAPABILITY::log_normal_kernel);
#ifdef CPU_CAPABILITY_AVX512
// normal_stub isn't being dispatched to AVX512 because it exposes
// flakiness in test_sgd of test/test_optim.py
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(normal_stub, void(*)(Tensor&, const double, const double, c10::optional<Generator>));
#else
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(normal_stub, &CPU_CAPABILITY::normal_kernel); REGISTER_DISPATCH(normal_stub, &CPU_CAPABILITY::normal_kernel);
#endif
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(uniform_stub, &CPU_CAPABILITY::uniform_kernel); REGISTER_DISPATCH(uniform_stub, &CPU_CAPABILITY::uniform_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -32,26 +32,17 @@
#include <ATen/native/cpu/Intrinsics.h> #include <ATen/native/cpu/Intrinsics.h>
/* yes I know, the top of this file is quite ugly */ /* The original source of this file has been modified. */
#if defined(CPU_CAPABILITY_AVX2)
#if defined(__GNUC__) #if defined(__GNUC__)
# define ALIGN32_BEG __attribute__((aligned(32))) # define ALIGN32_BEG __attribute__((aligned(32)))
#elif defined(_WIN32) #elif defined(_WIN32)
# define ALIGN32_BEG __declspec(align(32)) # define ALIGN32_BEG __declspec(align(32))
#endif #endif
/* __m128 is ugly to write */ typedef __m256 v8sf; // vector of 8 float (avx2)
typedef __m256 v8sf; // vector of 8 float (avx) typedef __m256i v8si; // vector of 8 int (avx2)
typedef __m256i v8si; // vector of 8 int (avx)
typedef __m128i v4si; // vector of 8 int (avx)
#define _PI32AVX_CONST(Name, Val) \
static const ALIGN32_BEG int _pi32avx_##Name[4] = { Val, Val, Val, Val }
_PI32AVX_CONST(1, 1);
_PI32AVX_CONST(inv1, ~1);
_PI32AVX_CONST(2, 2);
_PI32AVX_CONST(4, 4);
/* declare some AVX constants -- why can't I figure a better way to do that? */ /* declare some AVX constants -- why can't I figure a better way to do that? */
#define _PS256_CONST(Name, Val) \ #define _PS256_CONST(Name, Val) \
@ -91,67 +82,6 @@ _PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
_PS256_CONST(cephes_log_q1, -2.12194440e-4); _PS256_CONST(cephes_log_q1, -2.12194440e-4);
_PS256_CONST(cephes_log_q2, 0.693359375); _PS256_CONST(cephes_log_q2, 0.693359375);
#ifndef CPU_CAPABILITY_AVX2
typedef union imm_xmm_union {
v8si imm;
v4si xmm[2];
} imm_xmm_union;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) { \
imm_xmm_union u __attribute__((aligned(32))); \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) { \
imm_xmm_union u __attribute__((aligned(32))); \
u.xmm[0]=xmm0_; u.xmm[1]=xmm1_; imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline v8si _mm256_##fn(v8si x, int a) \
{ \
/* use SSE2 instruction to perform the bitop AVX2 */ \
v4si x1, x2; \
v8si ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1,a); \
x2 = _mm_##fn(x2,a); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return(ret); \
}
#warning "Using SSE2 to perform AVX2 bitshift ops"
AVX2_BITOP_USING_SSE2(slli_epi32)
AVX2_BITOP_USING_SSE2(srli_epi32)
#define AVX2_INTOP_USING_SSE2(fn) \
static inline v8si _mm256_##fn(v8si x, v8si y) \
{ \
/* use SSE2 instructions to perform the AVX2 integer operation */ \
v4si x1, x2; \
v4si y1, y2; \
v8si ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1,y1); \
x2 = _mm_##fn(x2,y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return(ret); \
}
#warning "Using SSE2 to perform AVX2 integer ops"
AVX2_INTOP_USING_SSE2(and_si128)
AVX2_INTOP_USING_SSE2(andnot_si128)
AVX2_INTOP_USING_SSE2(cmpeq_epi32)
AVX2_INTOP_USING_SSE2(sub_epi32)
AVX2_INTOP_USING_SSE2(add_epi32)
#endif /* CPU_CAPABILITY_AVX2 */
/* natural logarithm computed for 8 simultaneous float /* natural logarithm computed for 8 simultaneous float
return NaN for x <= 0 return NaN for x <= 0
@ -326,11 +256,6 @@ inline v8sf sin256_ps(v8sf x) { // any x
v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
v8si imm0, imm2; v8si imm0, imm2;
#ifndef CPU_CAPABILITY_AVX2
v4si imm0_1, imm0_2;
v4si imm2_1, imm2_2;
#endif
sign_bit = x; sign_bit = x;
/* take the absolute value */ /* take the absolute value */
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
@ -346,7 +271,6 @@ inline v8sf sin256_ps(v8sf x) { // any x
If we don't have AVX, let's perform them using SSE2 directives If we don't have AVX, let's perform them using SSE2 directives
*/ */
#ifdef CPU_CAPABILITY_AVX2
/* store the integer part of y in mm0 */ /* store the integer part of y in mm0 */
imm2 = _mm256_cvttps_epi32(y); imm2 = _mm256_cvttps_epi32(y);
/* j=(j+1) & (~1) (see the cephes sources) */ /* j=(j+1) & (~1) (see the cephes sources) */
@ -366,35 +290,6 @@ inline v8sf sin256_ps(v8sf x) { // any x
*/ */
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0); imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
#else
/* we use SSE2 routines to perform the integer ops */
COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
y = _mm256_cvtepi32_ps(imm2);
imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
imm0_1 = _mm_slli_epi32(imm0_1, 29);
imm0_2 = _mm_slli_epi32(imm0_2, 29);
COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
v8sf swap_sign_bit = _mm256_castsi256_ps(imm0); v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
v8sf poly_mask = _mm256_castsi256_ps(imm2); v8sf poly_mask = _mm256_castsi256_ps(imm2);
@ -453,18 +348,12 @@ inline v8sf cos256_ps(v8sf x) { // any x
v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y; v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
v8si imm0, imm2; v8si imm0, imm2;
#ifndef CPU_CAPABILITY_AVX2
v4si imm0_1, imm0_2;
v4si imm2_1, imm2_2;
#endif
/* take the absolute value */ /* take the absolute value */
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
/* scale by 4/Pi */ /* scale by 4/Pi */
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
#ifdef CPU_CAPABILITY_AVX2
/* store the integer part of y in mm0 */ /* store the integer part of y in mm0 */
imm2 = _mm256_cvttps_epi32(y); imm2 = _mm256_cvttps_epi32(y);
/* j=(j+1) & (~1) (see the cephes sources) */ /* j=(j+1) & (~1) (see the cephes sources) */
@ -479,39 +368,6 @@ inline v8sf cos256_ps(v8sf x) { // any x
/* get the polynom selection mask */ /* get the polynom selection mask */
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0); imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
#else
/* we use SSE2 routines to perform the integer ops */
COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
y = _mm256_cvtepi32_ps(imm2);
imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2);
imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2);
imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4);
imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4);
imm0_1 = _mm_slli_epi32(imm0_1, 29);
imm0_2 = _mm_slli_epi32(imm0_2, 29);
COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
v8sf sign_bit = _mm256_castsi256_ps(imm0); v8sf sign_bit = _mm256_castsi256_ps(imm0);
v8sf poly_mask = _mm256_castsi256_ps(imm2); v8sf poly_mask = _mm256_castsi256_ps(imm2);
@ -571,12 +427,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y; v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
v8si imm0, imm2, imm4; v8si imm0, imm2, imm4;
#ifndef CPU_CAPABILITY_AVX2
v4si imm0_1, imm0_2;
v4si imm2_1, imm2_2;
v4si imm4_1, imm4_2;
#endif
sign_bit_sin = x; sign_bit_sin = x;
/* take the absolute value */ /* take the absolute value */
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
@ -586,7 +436,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
/* scale by 4/Pi */ /* scale by 4/Pi */
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
#ifdef CPU_CAPABILITY_AVX2
/* store the integer part of y in imm2 */ /* store the integer part of y in imm2 */
imm2 = _mm256_cvttps_epi32(y); imm2 = _mm256_cvttps_epi32(y);
@ -606,38 +455,7 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0); imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
//v8sf poly_mask = _mm256_castsi256_ps(imm2); //v8sf poly_mask = _mm256_castsi256_ps(imm2);
#else
/* we use SSE2 routines to perform the integer ops */
COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
y = _mm256_cvtepi32_ps(imm2);
imm4_1 = imm2_1;
imm4_2 = imm2_2;
imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
imm0_1 = _mm_slli_epi32(imm0_1, 29);
imm0_2 = _mm_slli_epi32(imm0_2, 29);
COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0); v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
v8sf poly_mask = _mm256_castsi256_ps(imm2); v8sf poly_mask = _mm256_castsi256_ps(imm2);
@ -653,22 +471,9 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
x = _mm256_add_ps(x, xmm2); x = _mm256_add_ps(x, xmm2);
x = _mm256_add_ps(x, xmm3); x = _mm256_add_ps(x, xmm3);
#ifdef CPU_CAPABILITY_AVX2
imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2); imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4); imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
imm4 = _mm256_slli_epi32(imm4, 29); imm4 = _mm256_slli_epi32(imm4, 29);
#else
imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2);
imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2);
imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4);
imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4);
imm4_1 = _mm_slli_epi32(imm4_1, 29);
imm4_2 = _mm_slli_epi32(imm4_2, 29);
COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
#endif
v8sf sign_bit_cos = _mm256_castsi256_ps(imm4); v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);
@ -713,3 +518,5 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
*s = _mm256_xor_ps(xmm1, sign_bit_sin); *s = _mm256_xor_ps(xmm1, sign_bit_sin);
*c = _mm256_xor_ps(xmm2, sign_bit_cos); *c = _mm256_xor_ps(xmm2, sign_bit_cos);
} }
#endif // CPU_CAPABILITY_AVX2

View File

@ -149,8 +149,8 @@ static void _fft_fill_with_conjugate_symmetry_cpu_(
// Register this one implementation for all cpu types instead of compiling multiple times // Register this one implementation for all cpu types instead of compiling multiple times
REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
// _out variants can be shared between PocketFFT and MKL // _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,

View File

@ -6,6 +6,7 @@
#include <ATen/native/UpSample.h> #include <ATen/native/UpSample.h>
#include <ATen/native/cpu/Loops.h> #include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/affine_quantizer.h> #include <ATen/native/quantized/affine_quantizer.h>
#include <ATen/native/quantized/fake_quant_affine.h>
#include <ATen/native/quantized/cpu/quantized_ops.h> #include <ATen/native/quantized/cpu/quantized_ops.h>
#include <c10/util/irange.h> #include <c10/util/irange.h>
@ -197,7 +198,28 @@ int64_t hsum(const uint8_t* A, int len) {
for (const auto k : c10::irange(8)) { for (const auto k : c10::irange(8)) {
row_sum += temp[k]; row_sum += temp[k];
} }
#endif // CPU_CAPABILITY_AVX2 #elif defined(CPU_CAPABILITY_AVX512)
__m512i sum_v = _mm512_setzero_si512();
__m512i one_epi16_v = _mm512_set1_epi16(1);
__m512i one_epi8_v = _mm512_set1_epi8(1);
// vectorized
for (; i < len / 64 * 64; i += 64) {
__m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
sum_v = _mm512_add_epi32(
sum_v,
_mm512_madd_epi16(
// first argument is unsigned, second is signed
_mm512_maddubs_epi16(src_v, one_epi8_v),
one_epi16_v)
);
}
alignas(64) int32_t temp[16];
_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
for (const auto k : c10::irange(16)) {
row_sum += temp[k];
}
#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
// scalar // scalar
for (; i < len; ++i) { for (; i < len; ++i) {
@ -233,7 +255,28 @@ int64_t hsum(const int8_t* A, int len) {
for (const auto k : c10::irange(8)) { for (const auto k : c10::irange(8)) {
row_sum += temp[k]; row_sum += temp[k];
} }
#endif // CPU_CAPABILITY_AVX2 #elif defined(CPU_CAPABILITY_AVX512)
__m512i sum_v = _mm512_setzero_si512();
__m512i one_epi16_v = _mm512_set1_epi16(1);
__m512i one_epi8_v = _mm512_set1_epi8(1);
// vectorized
for (; i < len / 64 * 64; i += 64) {
__m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
sum_v = _mm512_add_epi32(
sum_v,
_mm512_madd_epi16(
// first argument is unsigned, second is signed
_mm512_maddubs_epi16(one_epi8_v, src_v),
one_epi16_v)
);
}
alignas(64) int32_t temp[16];
_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
for (const auto k : c10::irange(16)) {
row_sum += temp[k];
}
#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
// scalar // scalar
for (; i < len; ++i) { for (; i < len; ++i) {
@ -255,7 +298,7 @@ int64_t hsum(const int32_t* A, int len) {
__m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
// widen // widen
__m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32); __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32);
__m128i src_hi_epi32 = _mm256_extractf128_si256(src_epi32, 1); __m128i src_hi_epi32 = _mm256_extracti128_si256(src_epi32, 1);
__m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32); __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32);
__m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32); __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32);
// add // add
@ -268,7 +311,27 @@ int64_t hsum(const int32_t* A, int len) {
for (const auto k : c10::irange(4)) { for (const auto k : c10::irange(4)) {
row_sum += temp[k]; row_sum += temp[k];
} }
#endif // CPU_CAPABILITY_AVX2 #elif defined(CPU_CAPABILITY_AVX512)
__m512i sum_epi64 = _mm512_setzero_si512();
// vectorized
for (; i < len / 16 * 16; i += 16) {
__m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
// widen
__m256i src_lo_epi32 = _mm512_castsi512_si256(src_epi32);
__m256i src_hi_epi32 = _mm512_extracti32x8_epi32(src_epi32, 1);
__m512i src_lo_epi64 = _mm512_cvtepi32_epi64(src_lo_epi32);
__m512i src_hi_epi64 = _mm512_cvtepi32_epi64(src_hi_epi32);
// add
sum_epi64 = _mm512_add_epi64(sum_epi64, src_lo_epi64);
sum_epi64 = _mm512_add_epi64(sum_epi64, src_hi_epi64);
}
alignas(64) int64_t temp[8];
_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_epi64);
for (const auto k : c10::irange(8)) {
row_sum += temp[k];
}
#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
// scalar // scalar
for (; i < len; ++i) { for (; i < len; ++i) {
@ -313,7 +376,36 @@ int64_t hsum_sq(const uint8_t* A, int len) {
} }
sum_v_epu32 = _mm256_setzero_si256(); sum_v_epu32 = _mm256_setzero_si256();
} }
#endif // CPU_CAPABILITY_AVX2 #elif defined(CPU_CAPABILITY_AVX512)
__m512i sum_v_epu32 = _mm512_setzero_si512();
alignas(64) int32_t temp[16];
int overflow_threshold = 262144; // 2147483647(max of int32)/(512*512)*8 = 262144
int loop = len / overflow_threshold + 1;
for(int j=0; j<=loop; j++){
for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
// (i31, ..., i0)
__m256i src_epu8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
__m512i src_epu16 = _mm512_cvtepu8_epi16(src_epu8);
// (i31 ^ 2, ..., i0 ^ 2)
__m512i sq_epu16 = _mm512_mullo_epi16(src_epu16, src_epu16);
// (i15 ^ 2, ..., i0 ^ 2)
__m256i sq_lo_epu16 = _mm512_castsi512_si256(sq_epu16);
// (i31 ^ 2, ..., i16 ^ 2)
__m256i sq_hi_epu16 = _mm512_extracti32x8_epi32(sq_epu16, 1);
// widen to epu32
__m512i sq_lo_epu32 = _mm512_cvtepu16_epi32(sq_lo_epu16);
__m512i sq_hi_epu32 = _mm512_cvtepu16_epi32(sq_hi_epu16);
// add to running sum
sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_lo_epu32);
sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_hi_epu32);
}
_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epu32);
for (const auto k : c10::irange(16)) {
row_sum += temp[k];
}
sum_v_epu32 = _mm512_setzero_si512();
}
#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
// scalar // scalar
for (; i < len; ++i) { for (; i < len; ++i) {
@ -361,7 +453,40 @@ int64_t hsum_sq(const int8_t* A, int len) {
} }
sum_v_epi32 = _mm256_setzero_si256(); sum_v_epi32 = _mm256_setzero_si256();
} }
#endif // CPU_CAPABILITY_AVX2 #elif defined(CPU_CAPABILITY_AVX512)
// vectorized
__m512i sum_v_epi32 = _mm512_setzero_si512();
alignas(64) int32_t temp[16];
int overflow_threshold = 1048576; //2147483647/(256*256)*8 = 1048576
int loop = len / overflow_threshold + 1;
for(int j=0; j<=loop; j++){
for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
// (i31, ..., i0)
__m256i src_epi8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
__m512i src_epi16 = _mm512_cvtepi8_epi16(src_epi8);
// (i31 ^ 2, ..., i0 ^ 2)
__m512i sq_epi16 = _mm512_mullo_epi16(src_epi16, src_epi16);
// (i15 ^ 2, ..., i0 ^ 2)
__m256i sq_lo_epi16 = _mm512_castsi512_si256(sq_epi16);
// (i31 ^ 2, ..., i16 ^ 2)
__m256i sq_hi_epi16 = _mm512_extracti32x8_epi32(sq_epi16, 1);
// widen to epi32
__m512i sq_lo_epi32 = _mm512_cvtepi16_epi32(sq_lo_epi16);
__m512i sq_hi_epi32 = _mm512_cvtepi16_epi32(sq_hi_epi16);
// add to running sum
sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_lo_epi32);
sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_hi_epi32);
}
_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epi32);
for (const auto k : c10::irange(16)) {
row_sum += temp[k];
}
sum_v_epi32 = _mm512_setzero_si512();
}
#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
// scalar // scalar
for (; i < len; ++i) { for (; i < len; ++i) {
@ -391,7 +516,21 @@ float hsum_sq(const int32_t* A, int len) {
for (const auto k : c10::irange(8)) { for (const auto k : c10::irange(8)) {
row_sum += static_cast<float>(temp[k]); row_sum += static_cast<float>(temp[k]);
} }
#endif // CPU_CAPABILITY_AVX2 #elif defined(CPU_CAPABILITY_AVX512)
__m512 sum_ps = _mm512_setzero_ps();
// vectorized
for (; i < len / 16 * 16; i += 16) {
__m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
__m512 src_ps = _mm512_cvtepi32_ps(src_epi32);
sum_ps = _mm512_add_ps(sum_ps, _mm512_mul_ps(src_ps, src_ps));
}
alignas(64) float temp[16];
_mm512_store_ps(temp, sum_ps);
for (const auto k : c10::irange(16)) {
row_sum += static_cast<float>(temp[k]);
}
#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
// scalar // scalar
for (; i < len; ++i) { for (; i < len; ++i) {
@ -1239,7 +1378,7 @@ void qmaxpool_2d_nhwc_kernel(
} }
template <typename T> template <typename T>
void do_avg_pool_nhwc_on_AVX2( void do_avg_pool_nhwc_on_AVX_n(
const typename T::underlying* i_p, const typename T::underlying* i_p,
typename T::underlying* o_p, typename T::underlying* o_p,
int& c_start, int& c_start,
@ -1256,17 +1395,25 @@ void do_avg_pool_nhwc_on_AVX2(
int hsize, int hsize,
int wsize, int wsize,
int csize) { int csize) {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
// buffer for channel accumulator, used to interchange channel-loop // buffer for channel accumulator, used to interchange channel-loop
// to inner-most, so that memory access of the input tensor data is // to inner-most, so that memory access of the input tensor data is
// continuous. // continuous.
#ifdef CPU_CAPABILITY_AVX2
constexpr int cb_size = 16; constexpr int cb_size = 16;
#else
constexpr int cb_size = 8;
#endif
constexpr int vec_width = Vectorized<T>::size() / 4; constexpr int vec_width = Vectorized<T>::size() / 4;
constexpr int cb_step = cb_size * vec_width; constexpr int cb_step = cb_size * vec_width;
Vectorized<int32_t> acc_buffer[cb_size]; Vectorized<int32_t> acc_buffer[cb_size];
Vectorized<float> acc_buffer_fp[cb_size]; Vectorized<float> acc_buffer_fp[cb_size];
#ifdef CPU_CAPABILITY_AVX2
if (vec_width == 8) { if (vec_width == 8) {
#else
if (vec_width == 16) {
#endif
for (int c = c_start; c < csize; c += cb_step) { for (int c = c_start; c < csize; c += cb_step) {
int cend = std::min(cb_size, (csize - c) / vec_width); int cend = std::min(cb_size, (csize - c) / vec_width);
// initialize loop // initialize loop
@ -1292,14 +1439,23 @@ void do_avg_pool_nhwc_on_AVX2(
// convert int32 accumulative to fp32 // convert int32 accumulative to fp32
vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width); vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width);
// first quantize using AVX using 32 lanes, then 8, finally falls // first quantize using AVX2 or AVX512 using 32 lanes, then 8, finally falls
// back to single // back to single
#ifdef CPU_CAPABILITY_AVX2
QuantizeAvx2<T>( QuantizeAvx2<T>(
(float*)acc_buffer_fp, (float*)acc_buffer_fp,
o_p + c, o_p + c,
cend * vec_width, cend * vec_width,
multiplier, multiplier,
output_zero_point); output_zero_point);
#else
QuantizeAvx512<T>(
(float*)acc_buffer_fp,
o_p + c,
cend * vec_width,
multiplier,
output_zero_point);
#endif
} }
c_start = csize / vec_width * vec_width; c_start = csize / vec_width * vec_width;
} }
@ -1307,7 +1463,7 @@ void do_avg_pool_nhwc_on_AVX2(
} }
template <typename T> template <typename T>
void do_avg_pool_on_AVX2( void do_avg_pool_on_AVX_n(
typename T::underlying* i_p, typename T::underlying* i_p,
typename T::underlying* o_p, typename T::underlying* o_p,
int64_t& c, int64_t& c,
@ -1326,9 +1482,13 @@ void do_avg_pool_on_AVX2(
int64_t stride_D, int64_t stride_D,
int64_t stride_H, int64_t stride_H,
int64_t stride_W) { int64_t stride_W) {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
constexpr auto vec_width = Vectorized<T>::size() / 4; constexpr int vec_width = Vectorized<T>::size() / 4;
#ifdef CPU_CAPABILITY_AVX2
if (vec_width == 8) { if (vec_width == 8) {
#else
if (vec_width == 16) {
#endif
for (; c + vec_width <= channel_size; c += vec_width) { for (; c + vec_width <= channel_size; c += vec_width) {
int64_t tcntr = 0; int64_t tcntr = 0;
@ -1416,10 +1576,10 @@ void _qadaptive_avg_pool_kernel(
istartH * istrideH + istartH * istrideH +
istartW * istrideW; istartW * istrideW;
// Note: If AVX is not available, `do_avg_pool_on_AVX2 is a noop. // Note: If AVX is not available, `do_avg_pool_on_AVX_n is a noop.
// In that case, the following loop takes over // In that case, the following loop takes over
// TODO: more vectorization with loop interleaving // TODO: more vectorization with loop interleaving
do_avg_pool_on_AVX2<scalar_t>( do_avg_pool_on_AVX_n<scalar_t>(
internal_i_p, internal_i_p,
o_p, o_p,
c, c,
@ -1438,7 +1598,6 @@ void _qadaptive_avg_pool_kernel(
istrideD, istrideD,
istrideH, istrideH,
istrideW); istrideW);
// 1) The following loop handles the remaining channels // 1) The following loop handles the remaining channels
// 2) It also handles the Non-AVX2 path // 2) It also handles the Non-AVX2 path
for (; c < sizeC; ++c) { for (; c < sizeC; ++c) {
@ -1610,7 +1769,7 @@ void _qavg_pool_nhwc_kernel(
// For int8 quantization, we implicitly use int32 as accumulation // For int8 quantization, we implicitly use int32 as accumulation
// Or else, it will go to the slow path // Or else, it will go to the slow path
// TODO: support 16bit, 32bit, and etc. // TODO: support 16bit, 32bit, and etc.
do_avg_pool_nhwc_on_AVX2<scalar_t>( do_avg_pool_nhwc_on_AVX_n<scalar_t>(
i_p, i_p,
o_p, o_p,
c_start, c_start,
@ -1744,7 +1903,7 @@ void qavg_pool3d_nhwc_kernel(
} }
template <typename T> template <typename T>
int64_t do_quantized_bilinear_on_AVX2( int64_t do_quantized_bilinear_on_AVX_n(
const typename T::underlying*& pos1, const typename T::underlying*& pos1,
typename T::underlying*& pos2, typename T::underlying*& pos2,
int64_t input_height, int64_t input_height,
@ -1762,9 +1921,13 @@ int64_t do_quantized_bilinear_on_AVX2(
const int64_t h1p, const int64_t h1p,
const int64_t w1p) { const int64_t w1p) {
int64_t c = 0; int64_t c = 0;
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
constexpr auto vec_width = Vectorized<T>::size() / 4; constexpr auto vec_width = Vectorized<T>::size() / 4;
#ifdef CPU_CAPABILITY_AVX2
if (vec_width == 8) { if (vec_width == 8) {
#else
if (vec_width == 16) {
#endif
for (; c + vec_width <= channels; c += vec_width) { for (; c + vec_width <= channels; c += vec_width) {
Vectorized<float> pos1_fp_v[4]; Vectorized<float> pos1_fp_v[4];
Vectorized<int32_t> pos1_int_v[4]; Vectorized<int32_t> pos1_int_v[4];
@ -1861,7 +2024,7 @@ void qupsample_bilinear2d_nhwc_kernel(
o_p + (h2 * output_width + w2) * channels; o_p + (h2 * output_width + w2) * channels;
// We have to isolate this function out because the VS does not // We have to isolate this function out because the VS does not
// expand the macro correctly. // expand the macro correctly.
c = do_quantized_bilinear_on_AVX2<scalar_t>( c = do_quantized_bilinear_on_AVX_n<scalar_t>(
pos1, pos1,
pos2, pos2,
input_height, input_height,
@ -1989,7 +2152,7 @@ void q_batch_norm_kernel(
reinterpret_cast<scalar_t::underlying*>(input.data_ptr()); reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr()); scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());
constexpr int kVLen = 8; constexpr int kVLen = Vectorized<float>::size();
const int64_t outer_size = N * HxW; const int64_t outer_size = N * HxW;
using Vec = Vectorized<scalar_t>; using Vec = Vectorized<scalar_t>;
// Hoisted variables // Hoisted variables
@ -2292,7 +2455,7 @@ void quantized_normalize_kernel(
float y_scale = Y->q_scale(); float y_scale = Y->q_scale();
float y_inv_scale = 1.0f / y_scale; float y_inv_scale = 1.0f / y_scale;
constexpr int kFloatVLen = 8; constexpr int kFloatVLen = fVec::size();
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs(); int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
int64_t kNumIntVecInLayer = N / kIntVLen; int64_t kNumIntVecInLayer = N / kIntVLen;
int64_t kNonVecRemInLayer = N % kIntVLen; int64_t kNonVecRemInLayer = N % kIntVLen;
@ -3095,6 +3258,114 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
} // namespace } // namespace
// Some quantization tests are flaky on Windows with AVX512. If --continue-through-error
// is used, only one fails. But if the failing test is skipped, another one fails.
// If the second test is also skipped, a third one fails.
// So, until Quantization support for Windows is fixed for AVX512,
// AVX2 kernels would be used instead. Ref: GH 56992.
#if defined(CPU_CAPABILITY_AVX512) && defined(_WIN32)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub,
dequantize_tensor_per_channel_affine_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
dequantize_tensor_per_tensor_affine_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
dequantize_tensor_per_channel_float_qparams_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_tensor_stub,
fake_quant_learnable_grad_tensor_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub,
fake_quant_per_channel_cachemask_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_stub,
fake_quant_tensor_cachemask_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
fake_quant_tensor_cachemask_tensor_qparams_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
qadaptive_avg_pool2d_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
qadaptive_avg_pool3d_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qadd_relu_stub, qbinary_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qadd_scalar_relu_stub, qadd_scalar_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qadd_scalar_stub, qadd_scalar_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qadd_stub, qbinary_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, qavg_pool2d_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, qavg_pool3d_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qbatch_norm_relu_stub, qbatch_norm_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qbatch_norm_stub, qbatch_norm_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qcat_nhwc_stub, qcat_nhwc_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qcat_relu_nhwc_stub, qcat_nhwc_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qclamp_stub, qclamp_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qclamp_min_stub, qclamp_minmax_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qclamp_max_stub, qclamp_minmax_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qelu_stub, qelu_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qhardsigmoid_stub, qhardsigmoid_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qhardswish_stub, qhardswish_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qmaxpool_2d_nhwc_stub, qmaxpool_2d_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub, qbinary_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qmul_stub, qbinary_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qrelu6_stub, qrelu_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub, qrelu_leaky_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qrelu_stub, qrelu_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub, qsigmoid_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qtanh_stub, qtanh_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qthreshold_stub, qthreshold_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qtopk_stub, qtopk_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_channel_stub,
fake_quant_learnable_per_channel_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_stub,
quantize_tensor_per_tensor_affine_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_affine_stub,
quantize_tensor_per_channel_affine_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_float_qparams_stub,
quantize_tensor_per_channel_float_qparams_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(quantized_normalize_stub, qnormalize_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(qupsample_bilinear2d_nhwc_stub, qupsample_bilinear2d_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub,
quantize_tensor_per_tensor_affine_sub_byte_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub,
dequantize_tensor_per_tensor_affine_sub_byte_fn);
#else
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub, REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
&dequantize_tensor_per_channel_affine_cpu); &dequantize_tensor_per_channel_affine_cpu);
@ -3174,7 +3445,8 @@ REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel); REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &fake_quantize_learnable_channel_grad_kernel_cpu); REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
&fake_quantize_learnable_channel_grad_kernel_cpu);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH( REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_stub, quantize_tensor_per_tensor_affine_stub,
@ -3200,7 +3472,7 @@ REGISTER_DISPATCH(
REGISTER_DISPATCH( REGISTER_DISPATCH(
dequantize_tensor_per_tensor_affine_sub_byte_stub, dequantize_tensor_per_tensor_affine_sub_byte_stub,
&dequantize_tensor_per_tensor_affine_sub_byte_cpu); &dequantize_tensor_per_tensor_affine_sub_byte_cpu);
#endif // CPU_CAPABILITY_AVX512 && _WIN32
} // namespace native } // namespace native
} // namespace at } // namespace at

View File

@ -1071,13 +1071,17 @@ namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(ComplexTests, TestComplexFloatImagRealConj) { TEST(ComplexTests, TestComplexFloatImagRealConj) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28 }; float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28,
9.5488e-28,10.5488e-28,11.5488e-28,12.5488e-28,13.5488e-28,14.5488e-28,15.5488e-28,16.5488e-28};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0 }; float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0,aa[8],0,aa[10],0,aa[12],0,aa[14],0 };
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0 }; float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0,aa[9],0,aa[11],0,aa[13],0,aa[15],0 };
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28 }; float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,
5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28,
9.5488e-28,-10.5488e-28,11.5488e-28,-12.5488e-28,
13.5488e-28,-14.5488e-28,15.5488e-28,-16.5488e-28 };
auto a = vcomplex::loadu(aa); auto a = vcomplex::loadu(aa);
auto actual1 = a.real(); auto actual1 = a.real();
auto actual3 = a.imag(); auto actual3 = a.imag();
@ -1304,6 +1308,7 @@ namespace {
}, },
test_case); test_case);
} }
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TYPED_TEST(FunctionalTests, Map) { TYPED_TEST(FunctionalTests, Map) {
using vec = TypeParam; using vec = TypeParam;
using VT = ValueType<TypeParam>; using VT = ValueType<TypeParam>;
@ -1339,15 +1344,16 @@ namespace {
at::vec::map3<VT>([](vec x1, vec x2, vec x3) { return x1 + x2 + x3; }, y, x1, x2, x3, N); at::vec::map3<VT>([](vec x1, vec x2, vec x3) { return x1 + x2 + x3; }, y, x1, x2, x3, N);
for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i]; } for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i]; }
cmp(y, ref_y); cmp(y, ref_y);
// test map3: y = x1 + x2 + x3 + x4 // test map4: y = x1 + x2 + x3 + x4
at::vec::map4<VT>([](vec x1, vec x2, vec x3, vec x4) { return x1 + x2 + x3 + x4; }, y, x1, x2, x3, x4, N); at::vec::map4<VT>([](vec x1, vec x2, vec x3, vec x4) { return x1 + x2 + x3 + x4; }, y, x1, x2, x3, x4, N);
for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i] + x4[i]; } for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i] + x4[i]; }
cmp(y, ref_y); cmp(y, ref_y);
} }
TYPED_TEST(FunctionalBF16Tests, Reduce) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TYPED_TEST(FunctionalBF16Tests, Reduce) {
using vec = TypeParam; using vec = TypeParam;
// Can't use ValueType<TypeParam> here: // Can't use ValueType<TypeParam> here:
// Vectorized<BFloat16>::value_type returns uint16_t on AVX2 // Vectorized<BFloat16>::value_type returns uint16_t on AVX2/AVX512
using VT = c10::BFloat16; using VT = c10::BFloat16;
using RT = float; // reference using RT = float; // reference
constexpr auto R = 2LL; // residual constexpr auto R = 2LL; // residual
@ -1394,7 +1400,6 @@ namespace {
auto y2 = at::vec::map_reduce_all<VT>([](auto x) { return x - x.exp(); }, sum, x_b1, len); auto y2 = at::vec::map_reduce_all<VT>([](auto x) { return x - x.exp(); }, sum, x_b1, len);
ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed
<< "\nmap_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2); << "\nmap_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2);
} }
// Map2ReduceAll // Map2ReduceAll
for (int64_t len = 1; len <= N; len++) { for (int64_t len = 1; len <= N; len++) {

View File

@ -13,7 +13,13 @@
#include <math.h> #include <math.h>
#include <float.h> #include <float.h>
#include <algorithm> #include <algorithm>
#if defined(CPU_CAPABILITY_AVX512)
#define CACHE_LINE 64
#else
#define CACHE_LINE 32 #define CACHE_LINE 32
#endif
#if defined(__GNUC__) #if defined(__GNUC__)
#define CACHE_ALIGN __attribute__((aligned(CACHE_LINE))) #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE)))
#define not_inline __attribute__((noinline)) #define not_inline __attribute__((noinline))
@ -26,7 +32,7 @@ CACHE_ALIGN #define
#endif #endif
#if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) #if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER)
#define TEST_AGAINST_DEFAULT 1 #define TEST_AGAINST_DEFAULT 1
#elif !defined(CPU_CAPABILITY_AVX) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX) #elif !defined(CPU_CAPABILITY_AVX512) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX)
#define TEST_AGAINST_DEFAULT 1 #define TEST_AGAINST_DEFAULT 1
#else #else
#undef TEST_AGAINST_DEFAULT #undef TEST_AGAINST_DEFAULT
@ -41,7 +47,8 @@ CACHE_ALIGN #define
return __VA_ARGS__(std::forward<decltype(args)>(args)...); \ return __VA_ARGS__(std::forward<decltype(args)>(args)...); \
} }
#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) && (defined(__GNUC__) || defined(__GNUG__)) #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) || \
defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__))
#undef CHECK_DEQUANT_WITH_LOW_PRECISION #undef CHECK_DEQUANT_WITH_LOW_PRECISION
#define CHECK_WITH_FMA 1 #define CHECK_WITH_FMA 1
#elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2)

View File

@ -722,44 +722,43 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS}) list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS})
endif() endif()
# NOTE [ Linking AVX and non-AVX files ] # NOTE [ Linking AVX-n and non-AVX-n files ]
# #
# Regardless of the CPU capabilities, we build some files with AVX and AVX2 # Regardless of the CPU capabilities, we build some files with AVX2, and AVX512
# instruction set. If the host CPU doesn't support those, we simply ignore their # instruction set. If the host CPU doesn't support those, we simply ignore their
# functions at runtime during dispatch. # functions at runtime during dispatch.
# #
# We must make sure that those files are at the end of the input list when # We must make sure that those files are at the end of the input list when
# linking the torch_cpu library. Otherwise, the following error scenario might # linking the torch_cpu library. Otherwise, the following error scenario might
# occur: # occur:
# 1. A non-AVX and an AVX file both call a function defined with the `inline` # 1. A non-AVX2 and an AVX2 file both call a function defined with the `inline`
# keyword # keyword
# 2. The compiler decides not to inline this function # 2. The compiler decides not to inline this function
# 3. Two different versions of the machine code are generated for this function: # 3. Two different versions of the machine code are generated for this function:
# one without AVX instructions and one with AVX. # one without AVX2 instructions and one with AVX2.
# 4. When linking, the AVX version is found earlier in the input object files, # 4. When linking, the AVX2 version is found earlier in the input object files,
# so the linker makes the entire library use it, even in code not guarded by # so the linker makes the entire library use it, even in code not guarded by
# the dispatcher. # the dispatcher.
# 5. A CPU without AVX support executes this function, encounters an AVX # 5. A CPU without AVX2 support executes this function, encounters an AVX2
# instruction and crashes. # instruction and crashes.
# #
# Thus we organize the input files in the following order: # Thus we organize the input files in the following order:
# 1. All files with no AVX support # 1. All files with no AVX-n support
# 2. All files with AVX support (conveniently, they all have names ending with # 2. All files with AVX2 support ('*AVX2.cpp')
# 'AVX.cpp') # 3. All files with AVX512 support ('*AVX512.cpp')
# 3. All files with AVX2 support ('*AVX2.cpp')
set(Caffe2_CPU_SRCS_NON_AVX) set(Caffe2_CPU_SRCS_NON_AVX)
set(Caffe2_CPU_SRCS_AVX)
set(Caffe2_CPU_SRCS_AVX2) set(Caffe2_CPU_SRCS_AVX2)
set(Caffe2_CPU_SRCS_AVX512)
foreach(input_filename ${Caffe2_CPU_SRCS}) foreach(input_filename ${Caffe2_CPU_SRCS})
if(${input_filename} MATCHES "AVX\\.cpp") if(${input_filename} MATCHES "AVX2\\.cpp")
list(APPEND Caffe2_CPU_SRCS_AVX ${input_filename})
elseif(${input_filename} MATCHES "AVX2\\.cpp")
list(APPEND Caffe2_CPU_SRCS_AVX2 ${input_filename}) list(APPEND Caffe2_CPU_SRCS_AVX2 ${input_filename})
elseif(${input_filename} MATCHES "AVX512\\.cpp")
list(APPEND Caffe2_CPU_SRCS_AVX512 ${input_filename})
else() else()
list(APPEND Caffe2_CPU_SRCS_NON_AVX ${input_filename}) list(APPEND Caffe2_CPU_SRCS_NON_AVX ${input_filename})
endif() endif()
endforeach(input_filename) endforeach(input_filename)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX} ${Caffe2_CPU_SRCS_AVX2}) set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX2} ${Caffe2_CPU_SRCS_AVX512})
# ========================================================== # ==========================================================
# END formerly-libtorch sources # END formerly-libtorch sources

View File

@ -63,14 +63,6 @@ if(INTERN_BUILD_ATEN_OPS)
endif() endif()
endif(MSVC) endif(MSVC)
if(C_AVX_FOUND)
if(MSVC)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG}/arch:AVX ${CXX_AVX_FLAGS}")
else(MSVC)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG} ${CXX_AVX_FLAGS}")
endif(MSVC)
endif(C_AVX_FOUND)
if(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang") if(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/MapAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp") set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/MapAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp")
endif() endif()
@ -80,15 +72,16 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND CPU_CAPABILITY_NAMES "DEFAULT") list(APPEND CPU_CAPABILITY_NAMES "DEFAULT")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}") list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}")
if(CXX_AVX_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX_CPU_DEFINITION") if(CXX_AVX512_FOUND)
list(APPEND CPU_CAPABILITY_NAMES "AVX") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_CPU_DEFINITION")
list(APPEND CPU_CAPABILITY_NAMES "AVX512")
if(MSVC) if(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX") list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
else(MSVC) else(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx") list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma")
endif(MSVC) endif(MSVC)
endif(CXX_AVX_FOUND) endif(CXX_AVX512_FOUND)
if(CXX_AVX2_FOUND) if(CXX_AVX2_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX2_CPU_DEFINITION") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX2_CPU_DEFINITION")
@ -103,11 +96,24 @@ if(INTERN_BUILD_ATEN_OPS)
endif(COMPILER_SUPPORTS_NO_AVX256_SPLIT) endif(COMPILER_SUPPORTS_NO_AVX256_SPLIT)
list(APPEND CPU_CAPABILITY_NAMES "AVX2") list(APPEND CPU_CAPABILITY_NAMES "AVX2")
if(MSVC) if(DEFINED ENV{ATEN_AVX512_256})
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2") if($ENV{ATEN_AVX512_256} MATCHES "TRUE")
else(MSVC) if(CXX_AVX512_FOUND)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}") message("-- ATen AVX2 kernels will use 32 ymm registers")
endif(MSVC) if(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
else(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=native ${CPU_NO_AVX256_SPLIT_FLAGS}")
endif(MSVC)
endif(CXX_AVX512_FOUND)
endif()
else()
if(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2")
else(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}")
endif(MSVC)
endif()
endif(CXX_AVX2_FOUND) endif(CXX_AVX2_FOUND)
if(CXX_VSX_FOUND) if(CXX_VSX_FOUND)

View File

@ -12,6 +12,25 @@ SET(AVX_CODE "
} }
") ")
SET(AVX512_CODE "
#include <immintrin.h>
int main()
{
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0);
__m512i b = a;
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
return 0;
}
")
SET(AVX2_CODE " SET(AVX2_CODE "
#include <immintrin.h> #include <immintrin.h>
@ -56,6 +75,8 @@ ENDMACRO()
CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX") CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX")
CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2") CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX") CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX")
CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2") CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")

View File

@ -99,6 +99,12 @@
# BUILD_BINARY # BUILD_BINARY
# enables the additional binaries/ build # enables the additional binaries/ build
# #
# ATEN_AVX512_256=TRUE
# ATen AVX2 kernels can use 32 ymm registers, instead of the default 16.
# This option can be used if AVX512 doesn't perform well on a machine.
# The FBGEMM library also uses AVX512_256 kernels on Xeon D processors,
# but it also has some (optimized) assembly code.
#
# PYTORCH_BUILD_VERSION # PYTORCH_BUILD_VERSION
# PYTORCH_BUILD_NUMBER # PYTORCH_BUILD_NUMBER
# specify the version of PyTorch, rather than the hard-coded version # specify the version of PyTorch, rather than the hard-coded version
@ -928,6 +934,7 @@ if __name__ == '__main__':
'include/ATen/*.h', 'include/ATen/*.h',
'include/ATen/cpu/*.h', 'include/ATen/cpu/*.h',
'include/ATen/cpu/vec/vec256/*.h', 'include/ATen/cpu/vec/vec256/*.h',
'include/ATen/cpu/vec/vec512/*.h',
'include/ATen/cpu/vec/*.h', 'include/ATen/cpu/vec/*.h',
'include/ATen/core/*.h', 'include/ATen/core/*.h',
'include/ATen/cuda/*.cuh', 'include/ATen/cuda/*.cuh',

View File

@ -29,19 +29,19 @@ TEST_F(DispatchTest, TestAVX2) {
} }
} }
TEST_F(DispatchTest, TestAVX) { TEST_F(DispatchTest, TestAVX512) {
const std::vector<int> ints {1, 2, 3, 4}; const std::vector<int> ints {1, 2, 3, 4};
const std::vector<int> result {1, 4, 27, 256}; const std::vector<int> result {1, 4, 27, 256};
const auto vals_tensor = torch::tensor(ints); const auto vals_tensor = torch::tensor(ints);
const auto pows_tensor = torch::tensor(ints); const auto pows_tensor = torch::tensor(ints);
#ifdef _WIN32 #ifdef _WIN32
_putenv("ATEN_CPU_CAPABILITY=avx"); _putenv("ATEN_CPU_CAPABILITY=avx512");
#else #else
setenv("ATEN_CPU_CAPABILITY", "avx", 1); setenv("ATEN_CPU_CAPABILITY", "avx512", 1);
#endif #endif
const auto actual_pow_avx = vals_tensor.pow(pows_tensor); const auto actual_pow_avx512 = vals_tensor.pow(pows_tensor);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
ASSERT_EQ(result[i], actual_pow_avx[i].item<int>()); ASSERT_EQ(result[i], actual_pow_avx512[i].item<int>());
} }
} }

View File

@ -2,7 +2,7 @@
import sys import sys
import os import os
import unittest
# torch # torch
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -11,7 +11,7 @@ import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic.quantized as nniq import torch.nn.intrinsic.quantized as nniq
# Testing utils # Testing utils
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED
from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm
def remove_prefix(text, prefix): def remove_prefix(text, prefix):
@ -238,6 +238,7 @@ class TestSerialization(TestCase):
# TODO: graph mode quantized conv3d module # TODO: graph mode quantized conv3d module
@override_qengines @override_qengines
@unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098")
def test_lstm(self): def test_lstm(self):
class LSTMModule(torch.nn.Module): class LSTMModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -339,6 +339,15 @@ IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin" IS_MACOS = sys.platform == "darwin"
IS_PPC = platform.machine() == "ppc64le" IS_PPC = platform.machine() == "ppc64le"
def is_avx512_vnni_supported():
if sys.platform != 'linux':
return False
with open("/proc/cpuinfo", encoding="ascii") as f:
lines = f.read()
return "avx512vnni" in lines
IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()
if IS_WINDOWS: if IS_WINDOWS:
@contextmanager @contextmanager
def TemporaryFileName(*args, **kwargs): def TemporaryFileName(*args, **kwargs):